# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
from transformers.models.gemma2.modeling_gemma2 import Gemma2Model

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import VQModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import BaseOutput, replace_example_docstring

from src.scheduler import Scheduler
from src.transformer import SymmetricTransformer2DModel


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
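        >>> # Illustrative setup: the module path and checkpoint location below are placeholders,
        >>> # not pinned by this file.
        >>> import torch
        >>> from src.pipeline import UnifiedPipeline
        >>> pipe = UnifiedPipeline.from_pretrained("path/to/checkpoint", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> prompt = "a photo of an astronaut riding a horse on the moon"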
        >>> image = pipe(prompt).images[0]
        ```
"""


def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
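    """
    Build per-token position ids for the image stream.

    Illustrative example: for ``height == width == 64`` the result has shape ``(32 * 32, 3)``, with
    the last two channels holding the row and column indices of the ``(height // 2, width // 2)``
    grid; the ``batch_size`` argument is not used.
    """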
    latent_image_ids = torch.zeros(height // 2, width // 2, 3)
    latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
    latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

    latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

    latent_image_ids = latent_image_ids.reshape(
        latent_image_id_height * latent_image_id_width, latent_image_id_channels
    )

    return latent_image_ids.to(device=device, dtype=dtype)


def dedup_consecutive_words(text: str) -> str:
    """
    >>> dedup_consecutive_words("hello hello world world world")
    'hello world'
    """
    words = text.split()
    if not words:
        return text

    out = [words[0]]
    for w in words[1:]:
        if w != out[-1]:
            out.append(w)
    return " ".join(out)


def keep_upto_last_period(text: str) -> str:
    """
    Return the substring up to (and including) the last period mark.

    The function searches first for the Chinese full stop "。"; if none is found, it falls back to
    the ASCII dot ".".

    Parameters
    ----------
    text : str
        Input string.

    Returns
    -------
    str
        Substring ending at the final period mark. If no period is present, the original string is
        returned unchanged.
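
    Examples
    --------
    Illustrative only:

    >>> keep_upto_last_period("A cat sits. And then")
    'A cat sits.'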
| """ | |
| # Weired problem | |
| text = text.replace("such is such", "").replace("is such is", "").replace("such is", "").replace("is such", "") | |
| # Fallback to the ASCII period | |
| idx = -1 | |
| if idx == -1: | |
| idx = text.rfind(".") | |
| # If still not found, return original text | |
| if idx == -1: | |
| return text | |
| # Keep everything up to (and including) the last period | |
| return text[:idx + 1] | |


@dataclass
class UnifiedPipelineOutput(BaseOutput):
    """
    Output class for the unified (text-to-image / image-to-text) pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
            num_channels)`.
        prompts (`List[str]`):
            Decoded text outputs, one per image. For text-to-image generation these are placeholder strings.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
    prompts: List[str]


class UnifiedPipeline(DiffusionPipeline):
    image_processor: VaeImageProcessor
    vqvae: VQModel
    tokenizer: CLIPTokenizer
    tokenizer_2: T5Tokenizer
    text_encoder: CLIPTextModelWithProjection
    text_encoder_2: T5EncoderModel
    transformer: SymmetricTransformer2DModel
    scheduler: Scheduler

    model_cpu_offload_seq = "text_encoder->transformer->vqvae"

    def __init__(
        self,
        vqvae: VQModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        transformer: SymmetricTransformer2DModel,
        scheduler: Scheduler,
        tokenizer_2: Optional[T5Tokenizer] = None,
        text_encoder_2: Optional[T5EncoderModel] = None,
    ):
        super().__init__()

        self.register_modules(
            vqvae=vqvae,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            transformer=transformer,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)

    def __call__(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
        num_inference_steps: int = 48,
        guidance_scale: float = 9.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.IntTensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_encoder_hidden_states: Optional[torch.Tensor] = None,
        output_type="pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
        callback_steps: int = 1,
        micro_conditioning_aesthetic_score: int = 6,
        micro_conditioning_crop_coord: Tuple[int, int] = (0, 0),
        temperature: Union[int, Tuple[int, int], List[int]] = (2, 0),
        mask_token_embedding: Optional[str] = None,
    ):
| """ | |
| The call function to the pipeline for generation. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. | |
| height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): | |
| The height in pixels of the generated image. | |
| width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | |
| The width in pixels of the generated image. | |
| num_inference_steps (`int`, *optional*, defaults to 16): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| guidance_scale (`float`, *optional*, defaults to 10.0): | |
| A higher guidance scale value encourages the model to generate images closely linked to the text | |
| `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide what to not include in image generation. If not defined, you need to | |
| pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). | |
| num_images_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| generator (`torch.Generator`, *optional*): | |
| A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | |
| generation deterministic. | |
| latents (`torch.IntTensor`, *optional*): | |
| Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image | |
| gneration. If not provided, the starting latents will be completely masked. | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not | |
| provided, text embeddings are generated from the `prompt` input argument. A single vector from the | |
| pooled and projected final hidden states. | |
| encoder_hidden_states (`torch.Tensor`, *optional*): | |
| Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. | |
| negative_prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If | |
| not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. | |
| negative_encoder_hidden_states (`torch.Tensor`, *optional*): | |
| Analogous to `encoder_hidden_states` for the positive prompt. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | |
| plain tuple. | |
| callback (`Callable`, *optional*): | |
| A function that calls every `callback_steps` steps during inference. The function is called with the | |
| following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. | |
| callback_steps (`int`, *optional*, defaults to 1): | |
| The frequency at which the `callback` function is called. If not specified, the callback is called at | |
| every step. | |
| cross_attention_kwargs (`dict`, *optional*): | |
| A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in | |
| [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). | |
| micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): | |
| The targeted aesthetic score according to the laion aesthetic classifier. See | |
| https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of | |
| https://arxiv.org/abs/2307.01952. | |
| micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): | |
| The targeted height, width crop coordinates. See the micro-conditioning section of | |
| https://arxiv.org/abs/2307.01952. | |
| temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): | |
| Configures the temperature scheduler on `self.scheduler` see `Scheduler#set_timesteps`. | |
| Examples: | |
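
            A minimal illustrative call (assumes `pipe` is an already-loaded `UnifiedPipeline`):

            ```py
            >>> out = pipe("a watercolor painting of a lighthouse at dusk")
            >>> image = out.images[0]
            ```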

        Returns:
            [`UnifiedPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`UnifiedPipelineOutput`] is returned, otherwise a `tuple` is returned
                where the first element is a list with the generated images.
        """
        if (prompt_embeds is not None and encoder_hidden_states is None) or (
            prompt_embeds is None and encoder_hidden_states is not None
        ):
            raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither")

        if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or (
            negative_prompt_embeds is None and negative_encoder_hidden_states is not None
        ):
            raise ValueError(
                "pass either both `negative_prompt_embeds` and `negative_encoder_hidden_states` or neither"
            )

        if self.text_encoder_2 is not None:
            self.text_encoder_2.to(self._execution_device)

        text2image = image is None
        image2text = image is not None

        if image2text:
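            # Image-to-text: build a fully masked text sequence (optionally prefixed with the question
            # for VQA) that the scheduler progressively unmasks into a caption or an answer.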
            if self.text_encoder_2 is not None:
                prompt = "<extra_id_0>" * 256
                prompt = [prompt] * len(image)

                t5_mask_id = self.tokenizer_2.convert_tokens_to_ids("<extra_id_0>")
                self.scheduler.config.mask_token_id = t5_mask_id
            else:
                mask_token = "<mask>"
                self.tokenizer.add_tokens(mask_token, special_tokens=False)
                clip_mask_id = self.tokenizer.convert_tokens_to_ids(mask_token)

                self.text_encoder.resize_token_embeddings(len(self.tokenizer))

                if mask_token_embedding is not None:
                    if mask_token_embedding.endswith(".pth"):
                        mask_token_embedding = torch.load(mask_token_embedding)
                    else:
                        mask_token_embedding = os.path.dirname(mask_token_embedding)
                        mask_token_embedding_path = os.path.join(mask_token_embedding, "mask_token_embedding.pth")
                        assert os.path.exists(mask_token_embedding_path), f"{mask_token_embedding_path} doesn't exist!"
                        mask_token_embedding = torch.load(mask_token_embedding_path)

                    mask_token_embedding = mask_token_embedding.to(self._execution_device, dtype=self.text_encoder.dtype)
                    self.text_encoder.get_input_embeddings().weight.data[clip_mask_id].copy_(mask_token_embedding)

                self.scheduler.config.mask_token_id = clip_mask_id

                input_ids = torch.ones(
                    size=(len(image), self.tokenizer.model_max_length),
                    dtype=torch.int64,
                    device=self._execution_device,
                )
                input_ids = input_ids * clip_mask_id

                question_len = []
                if prompt is None:
                    question_len = [0] * len(image)
                elif isinstance(prompt, str):
                    question_ids = torch.LongTensor([self.tokenizer.encode(prompt)])
                    question_ids = question_ids.repeat(len(image), 1)

                    q_len = len(question_ids[0]) - 1  # drop the <eos> token
                    question_len = [q_len] * len(image)
                    input_ids[:, :q_len] = question_ids[:, :-1]
                else:
                    assert isinstance(prompt, list), "prompt must be None, str, or list!"
                    assert len(prompt) == len(image), "VQA requires an equal number of images and prompts!"
                    for i, p in enumerate(prompt):
                        question_ids = torch.LongTensor([self.tokenizer.encode(p)])

                        q_len = len(question_ids[0]) - 1
                        question_len.append(q_len)
                        input_ids[i, :q_len] = question_ids[0, :-1]
        else:
            self.scheduler.config.mask_token_id = self.transformer.config.vocab_size - 1

        if isinstance(prompt, str):
            prompt = [prompt]

        if image is not None:
            batch_size = len(image)
        else:
            batch_size = len(prompt)

        if height is None:
            height = self.transformer.config.sample_size * self.vae_scale_factor

        if width is None:
            width = self.transformer.config.sample_size * self.vae_scale_factor

        if isinstance(self.text_encoder, CLIPTextModelWithProjection):
            text_encoder_type = "open_clip"
        if isinstance(self.text_encoder_2, Gemma2Model):
            text_encoder_type = "gemma"

        if prompt_embeds is None:
            if text_encoder_type == "t5_clip":
                if text2image:
                    input_ids_clip = self.tokenizer(
                        prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                    outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
                    prompt_embeds = outputs.text_embeds

                    input_ids_t5 = self.tokenizer_2(
                        prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=256,
                    ).input_ids.to(self._execution_device)
                    outputs_2 = self.text_encoder_2(input_ids_t5, return_dict=True, output_hidden_states=True)
                    encoder_hidden_states = outputs_2.last_hidden_state
            elif text_encoder_type == "open_clip":
                if text2image:
                    input_ids = self.tokenizer(
                        prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                prompt_embeds = outputs.text_embeds
                encoder_hidden_states = outputs.hidden_states[-2]
            elif text_encoder_type == "gemma":
                if text2image:
                    input_ids_clip = self.tokenizer(
                        prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                    outputs = self.text_encoder(input_ids_clip, return_dict=True, output_hidden_states=True)
                    prompt_embeds = outputs.text_embeds

                input_ids_2 = self.tokenizer_2(
                    prompt,
                    truncation=True,
                    padding="max_length",
                    max_length=256,
                    return_tensors="pt",
                ).input_ids.to(self._execution_device)
                outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
                encoder_hidden_states = outputs_2.last_hidden_state

        if prompt_embeds is not None:
            prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

        if guidance_scale > 1.0 and text2image:
            if negative_prompt_embeds is None:
                if negative_prompt is None:
                    negative_prompt = [""] * len(prompt)

                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt] * len(prompt)

                if text_encoder_type == "t5_clip":
                    input_ids = self.tokenizer(
                        negative_prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                    outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                    negative_prompt_embeds = outputs.text_embeds

                    input_ids_2 = self.tokenizer_2(
                        negative_prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=256,
                    ).input_ids.to(self._execution_device)
                    outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
                    negative_encoder_hidden_states = outputs_2.last_hidden_state
                elif text_encoder_type == "open_clip":
                    input_ids = self.tokenizer(
                        negative_prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                    outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                    negative_prompt_embeds = outputs.text_embeds
                    negative_encoder_hidden_states = outputs.hidden_states[-2]
                elif text_encoder_type == "gemma":
                    input_ids = self.tokenizer(
                        negative_prompt,
                        return_tensors="pt",
                        padding="max_length",
                        truncation=True,
                        add_special_tokens=True,
                        max_length=77,
                    ).input_ids.to(self._execution_device)
                    outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
                    negative_prompt_embeds = outputs.text_embeds

                    input_ids_2 = self.tokenizer_2(
                        negative_prompt,
                        truncation=True,
                        padding="max_length",
                        max_length=256,
                        return_tensors="pt",
                    ).input_ids.to(self._execution_device)
                    outputs_2 = self.text_encoder_2(input_ids_2, return_dict=True, output_hidden_states=True)
                    negative_encoder_hidden_states = outputs_2.last_hidden_state

            negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1)
            negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1)

            prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds])
            encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states])

        # Note that the micro conditionings _do_ flip the order of width, height for the original size
        # and the crop coordinates. This is how it was done in the original code base.
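        # With the default arguments this packs [width, height, crop_x, crop_y, aesthetic_score],
        # e.g. [1024, 1024, 0, 0, 6], as one extra conditioning vector per (possibly doubled) batch entry.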
        micro_conds = torch.tensor(
            [
                width,
                height,
                micro_conditioning_crop_coord[0],
                micro_conditioning_crop_coord[1],
                micro_conditioning_aesthetic_score,
            ],
            device=self._execution_device,
            dtype=encoder_hidden_states.dtype,
        )
        micro_conds = micro_conds.unsqueeze(0)
        micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 and text2image else batch_size, -1)
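
        # Latent token grid: each spatial dimension is the pixel size divided by the VQ-VAE scale
        # factor (e.g. 1024 // 8 = 128 for a typical 4-level VQ-VAE).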
        shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if latents is None and text2image:
            latents = torch.full(
                shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device
            )
        elif image2text:
            if text_encoder_type in ("t5_clip", "gemma"):
                latents = input_ids_2  # [b, l]
            else:
                latents = input_ids

        model_input = None
        step_by_step = []

        self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device)
        num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, timestep in enumerate(self.scheduler.timesteps):
                if guidance_scale > 1.0 and text2image:
                    model_input = torch.cat([latents] * 2)
                elif image2text:
                    if model_input is None:
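                        # Encode the input image once and keep its discrete VQ code indices (the
                        # `[2][2]` indexing picks the code-index tensor out of the tuple returned by
                        # `quantize`), reshaped to the latent grid; these stay fixed while the text
                        # side is unmasked step by step.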
                        model_input = self.vqvae.quantize(
                            self.vqvae.encode(image.to(self._execution_device, dtype=self.vqvae.dtype)).latents
                        )[2][2].reshape(batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor)

                    if text_encoder_type in ("t5_clip", "gemma"):
                        outputs_t5 = self.text_encoder_2(latents, return_dict=True)
                        encoder_hidden_states = outputs_t5.last_hidden_state

                        batch_prompt = []
                        for j in range(latents.size(0)):  # use `j` so the step index `i` is not clobbered
                            masked_prompt_input_id = latents[j].tolist()
                            decoded_prompt = self.tokenizer_2.decode(masked_prompt_input_id, skip_special_tokens=True)
                            batch_prompt.append(decoded_prompt)

                        masked_prompt_input_ids_clip = self.tokenizer(
                            batch_prompt,
                            truncation=True,
                            padding="max_length",
                            max_length=77,
                            return_tensors="pt",
                        ).input_ids
                        masked_prompt_input_ids_clip = masked_prompt_input_ids_clip.to(self._execution_device)

                        outputs_clip = self.text_encoder(input_ids=masked_prompt_input_ids_clip, return_dict=True)
                        prompt_embeds = outputs_clip.text_embeds
                    else:
                        outputs = self.text_encoder(latents, return_dict=True, output_hidden_states=True)
                        prompt_embeds = outputs.text_embeds
                        encoder_hidden_states = outputs.hidden_states[-2]
                else:
                    model_input = latents
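
                # Per-token position ids for the two streams: image tokens get (row, column) indices
                # from `_prepare_latent_image_ids`, while text tokens all share zero ids.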
                img_ids = _prepare_latent_image_ids(
                    model_input.shape[0],
                    model_input.shape[-2],
                    model_input.shape[-1],
                    model_input.device,
                    model_input.dtype,
                )
                txt_ids = torch.zeros(encoder_hidden_states.shape[1], 3).to(
                    device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype
                )

                # timestep_ = int(timestep / num_inference_steps * 1000)
                model_output, encoder_hidden_states_tmp = self.transformer(
                    hidden_states=model_input,
                    micro_conds=micro_conds,
                    pooled_projections=prompt_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    img_ids=img_ids,
                    txt_ids=txt_ids,
                    timestep=torch.tensor([timestep / num_inference_steps], device=model_input.device),
                )

                if image2text:
                    encoder_hidden_states = encoder_hidden_states_tmp.clone()

                if guidance_scale > 1.0 and text2image:
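                    # Classifier-free guidance: the duplicated batch holds [unconditional, conditional]
                    # halves, so split the logits and push the prediction away from the unconditional half.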
                    uncond_logits, cond_logits = model_output.chunk(2)
                    to_scheduler = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
                elif image2text:
                    to_scheduler = encoder_hidden_states
                else:
                    to_scheduler = model_output
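
                # One reverse step: the scheduler keeps confident token predictions and leaves the
                # remaining positions masked for later steps.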
                latents = self.scheduler.step(
                    model_output=to_scheduler,
                    timestep=timestep,
                    sample=latents,
                    generator=generator,
                ).prev_sample

                # Uncomment to record the intermediate results of image-to-text generation:
                # step_by_step.append(self.tokenizer.decode(latents[0].tolist(), skip_special_tokens=True))

                # Uncomment to decode the intermediate results of text-to-image generation:
                # output = self.vqvae.decode(
                #     latents,
                #     force_not_quantize=True,
                #     shape=(
                #         batch_size,
                #         height // self.vae_scale_factor,
                #         width // self.vae_scale_factor,
                #         self.vqvae.config.latent_channels,
                #     ),
                # ).sample.clip(0, 1)
                # output = self.image_processor.postprocess(output, output_type)  # a list of PIL images to save

                if i == len(self.scheduler.timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, timestep, latents)

        # with open("step_by_step.txt", "w") as file:
        #     for prompt in step_by_step:
        #         file.write(prompt + "\n")
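
        # Decode the text stream: treat the final text hidden states as token logits (keeping only the
        # conditional half when guidance doubled the batch) and greedily argmax-decode them.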
        if guidance_scale > 1.0 and text2image:
            decoded_input_ids = encoder_hidden_states[encoder_hidden_states.shape[0] // 2 :].argmax(-1)
        else:
            decoded_input_ids = encoder_hidden_states.argmax(-1)

        prompts = []
        for i, token_ids in enumerate(decoded_input_ids):
            if image2text:
                q_len = question_len[i]
                decoded = self.tokenizer.decode(token_ids.tolist()[q_len:], skip_special_tokens=True)
                prompts.append(keep_upto_last_period(dedup_consecutive_words(decoded)))
            else:
                prompts.append("Placeholder")
| if output_type == "latent": | |
| output = latents | |
| else: | |
| needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast | |
| if needs_upcasting: | |
| self.vqvae.float() | |
| if text2image: | |
| to_vqvae = latents | |
| else: | |
| to_vqvae = model_input | |
| output = self.vqvae.decode( | |
| to_vqvae, | |
| force_not_quantize=True, | |
| shape=( | |
| batch_size, | |
| height // self.vae_scale_factor, | |
| width // self.vae_scale_factor, | |
| self.vqvae.config.latent_channels, | |
| ), | |
| ).sample.clip(0, 1) | |
| output = self.image_processor.postprocess(output, output_type) | |
| if needs_upcasting: | |
| self.vqvae.half() | |
| self.maybe_free_model_hooks() | |
| if not return_dict: | |
| return (output,) | |
| return UnifiedPipelineOutput(images=output, prompts=prompts) |