from diffsynth import ModelManager
from diffsynth.pipelines.base import BasePipeline
from diffsynth.vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
from model.dit import WanModel, WanLayerNorm, WanRMSNorm, WanSelfAttention
from model.text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm
from model.vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
from model.image_encoder import WanImageEncoder
from model.prompter import WanPrompter
from scheduler.flow_match import FlowMatchScheduler
import os
import torch
from einops import rearrange, repeat
import numpy as np
import PIL.Image
from tqdm import tqdm
from safetensors import safe_open


def binary_tensor_to_indices(tensor):
    assert tensor.dim() == 2, "Input tensor must be in [b, t]"
    indices = [(row == 1).nonzero(as_tuple=True)[0] for row in tensor]
    return indices
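
# Example (illustrative): for condition_preserved_mask = torch.tensor([[1, 0, 0, 1]]),
# binary_tensor_to_indices returns [tensor([0, 3])] -- one index tensor per batch row,
# listing the frame positions marked as preserved conditioning.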


def propagate_visualize_attention_arg(model, visualize_attention=False):
    """
    Recursively set the `visualize_attention` flag to the given value on selected
    WanSelfAttention modules (blocks 0, 19, and 39). Only used in inference/test mode.
    """
    for name, module in model.named_modules():
        if isinstance(module, WanSelfAttention):
            if "blocks.0.self_attn" in name or "blocks.19.self_attn" in name or "blocks.39.self_attn" in name:
                print(f"Set `visualize_attention` to {visualize_attention} for {name}")
                module.visualize_attention = visualize_attention
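
# Note: the filter above targets blocks 0, 19, and 39, i.e. presumably the first,
# middle, and last transformer blocks of a 40-block DiT; adjust the names if the
# model has a different depth.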


class WanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16, tokenizer_path=None):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
        self.prompter = WanPrompter(tokenizer_path=tokenizer_path)
        self.text_encoder: WanTextEncoder = None
        self.image_encoder: WanImageEncoder = None
        self.dit: WanModel = None
        self.vae: WanVideoVAE = None
        self.model_names = ['text_encoder', 'dit', 'vae']
        self.height_division_factor = 16
        self.width_division_factor = 16

    def enable_vram_management(self, num_persistent_param_in_dit=None):
        dtype = next(iter(self.text_encoder.parameters())).dtype
        enable_vram_management(
            self.text_encoder,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Embedding: AutoWrappedModule,
                T5RelativeEmbedding: AutoWrappedModule,
                T5LayerNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.dit.parameters())).dtype
        enable_vram_management(
            self.dit,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv3d: AutoWrappedModule,
                torch.nn.LayerNorm: AutoWrappedModule,
                WanLayerNorm: AutoWrappedModule,
                WanRMSNorm: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
            max_num_param=num_persistent_param_in_dit,
            overflow_module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device="cpu",
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        dtype = next(iter(self.vae.parameters())).dtype
        enable_vram_management(
            self.vae,
            module_map={
                torch.nn.Linear: AutoWrappedLinear,
                torch.nn.Conv2d: AutoWrappedModule,
                RMS_norm: AutoWrappedModule,
                CausalConv3d: AutoWrappedModule,
                Upsample: AutoWrappedModule,
                torch.nn.SiLU: AutoWrappedModule,
                torch.nn.Dropout: AutoWrappedModule,
            },
            module_config=dict(
                offload_dtype=dtype,
                offload_device="cpu",
                onload_dtype=dtype,
                onload_device=self.device,
                computation_dtype=self.torch_dtype,
                computation_device=self.device,
            ),
        )
        if self.image_encoder is not None:
            dtype = next(iter(self.image_encoder.parameters())).dtype
            enable_vram_management(
                self.image_encoder,
                module_map={
                    torch.nn.Linear: AutoWrappedLinear,
                    torch.nn.Conv2d: AutoWrappedModule,
                    torch.nn.LayerNorm: AutoWrappedModule,
                },
                module_config=dict(
                    offload_dtype=dtype,
                    offload_device="cpu",
                    onload_dtype=dtype,
                    onload_device="cpu",
                    computation_dtype=self.torch_dtype,
                    computation_device=self.device,
                ),
            )
        self.enable_cpu_offload()
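
    # Illustrative call (the value is a placeholder, not a project default): keep roughly
    # 6B DiT parameters resident on the GPU and treat the rest as overflow kept on CPU:
    #   pipe.enable_vram_management(num_persistent_param_in_dit=6 * 10**9)
    # Passing None appears to apply the primary (GPU-onload) config to all DiT modules.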

    def fetch_models_from_model_manager(self, model_manager: ModelManager):
        text_encoder_model_and_path = model_manager.fetch_model("wan_video_text_encoder", require_model_path=True)
        if text_encoder_model_and_path is not None:
            self.text_encoder, tokenizer_path = text_encoder_model_and_path
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(tokenizer_path), "google/umt5-xxl"))
        self.dit = model_manager.fetch_model("wan_video_dit")
        self.vae = model_manager.fetch_model("wan_video_vae")
        self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")

    def _init_component_from_checkpoint_path(self, model_cls, state_dict_path, strict=True, config_dict=None):
        config = {}
        state_dict = self._load_state_dict(state_dict_path)
        if hasattr(model_cls, "state_dict_converter"):
            state_dict_converter = model_cls.state_dict_converter()
            state_dict = state_dict_converter.from_civitai(state_dict)
            if isinstance(state_dict, tuple):
                state_dict, config = state_dict
        config.update(config_dict or {})
        model = model_cls(**config)
        # Adapter variants (local LoRA / DeRA) add parameters that are not present in the
        # base checkpoint, so the state dict cannot be loaded strictly in that case.
        if config_dict and ("use_local_lora" in config_dict or "use_dera" in config_dict):
            strict = False
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)
        print(f"Missing keys: {missing_keys}")
        print(f"Unexpected keys: {unexpected_keys}")
        return model

    def _load_state_dict(self, state_dict_paths):
        if isinstance(state_dict_paths, str):
            state_dict_paths = [state_dict_paths]
        state_dict = {}
        for state_dict_path in tqdm(state_dict_paths, desc="Reading file(s) from disk"):
            state_dict.update(self._load_single_file(state_dict_path))
        return state_dict

    def _load_single_file(self, file_path):
        if file_path.endswith(".safetensors"):
            return self._load_state_dict_from_safetensors(file_path)
        else:
            return torch.load(file_path, map_location='cpu')

    def _load_state_dict_from_safetensors(self, file_path, torch_dtype=None):
        state_dict = {}
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)
                if torch_dtype is not None:
                    state_dict[k] = state_dict[k].to(torch_dtype)
        return state_dict

    def initialize_dummy_dit(self, config):
        print("Initializing a dummy DIT model.")
        self.dit = WanModel(**config)
        print("Dummy DIT model is initialized.")

    def fetch_models_from_checkpoints(self, path_dict, config_dict=None):
        default_config = {"text_encoder": {}, "dit": {}, "vae": {}, "image_encoder": {}}
        config_dict = {**default_config, **(config_dict or {})}
        components = {
            "text_encoder": WanTextEncoder,
            "dit": WanModel,
            "vae": WanVideoVAE,
            "image_encoder": WanImageEncoder,
        }
        for name, model_cls in components.items():
            if name not in path_dict:
                print(f"Component {name} is not found in the checkpoint path dict. Skipping.")
                continue
            path = path_dict[name]
            config = config_dict.get(name, {})
            print(f"Loading {name} from {path} with config {config}.")
            setattr(self, name, self._init_component_from_checkpoint_path(model_cls, path, config_dict=config))
            print(f"Initialized {name} from checkpoint.")
        if "text_encoder" in path_dict:
            self.prompter.fetch_models(self.text_encoder)
            self.prompter.fetch_tokenizer(os.path.join(os.path.dirname(path_dict["text_encoder"]), "google/umt5-xxl"))
            print("Initialized prompter from checkpoint.")
        print("All components are initialized from checkpoints.")

    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None):
        # Factory helper: builds a pipeline from models registered in a ModelManager.
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models_from_model_manager(model_manager)
        return pipe
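
    # Illustrative usage (paths are placeholders), assuming the checkpoints were first
    # registered with DiffSynth's ModelManager:
    #   model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    #   model_manager.load_models(["checkpoints/dit.safetensors", "checkpoints/vae.safetensors", ...])
    #   pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")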

    def denoising_model(self):
        return self.dit

    def encode_prompt(self, prompt, positive=True):
        prompt_emb = self.prompter.encode_prompt(prompt, positive=positive)
        return {"context": prompt_emb}

    def encode_image(self, image, num_frames, height, width):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            image = self.preprocess_image(image.resize((width, height))).to(self.device)
            clip_context = self.image_encoder.encode_image([image])
            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
            msk[:, 1:] = 0
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
            msk = msk.transpose(1, 2)[0]
            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
            y = torch.concat([msk, y])
        return {"clip_fea": clip_context, "y": [y]}

    def check_and_fix_image_or_video_tensor_input(self, _tensor):
        assert isinstance(_tensor, torch.Tensor), "Input must be a tensor."
        if _tensor.max() <= 255 and _tensor.max() > 1.0:
            _tensor = _tensor.to(self.device) / 127.5 - 1
            print("Input tensor is converted from [0, 255] to [-1, 1].")
        elif _tensor.min() >= 0 and _tensor.max() <= 1:
            _tensor = _tensor.to(self.device) * 2 - 1
            print("Input tensor is converted from [0, 1] to [-1, 1].")
        return _tensor
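
    # Example for check_and_fix_image_or_video_tensor_input (illustrative): a uint8-style
    # value of 255 maps to 255 / 127.5 - 1 = 1.0 and 0 maps to -1.0; a [0, 1] float input
    # is mapped via x * 2 - 1; inputs already in [-1, 1] (containing negative values)
    # pass through unchanged.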

    def encode_video_with_mask(self, video, num_frames, height, width, condition_preserved_mask):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            video = video.to(self.device)
            y = self.vae.encode(video, device=self.device)
            msk = condition_preserved_mask
            assert msk is not None, "The mask must be provided for the masked video input."
            assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
            assert msk.shape[0] == video.shape[0], "The batch size of the mask must be the same as the input video."
            assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
            msk = msk.to(self.device)
            msk = msk.unsqueeze(-1).unsqueeze(-1)
            msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(video.shape[0], msk.shape[1] // 4, 4, height//8, width//8)  # b, t, c, h, w
            msk = msk.transpose(1, 2)  # b, c, t, h, w
            y = torch.concat([msk, y], dim=1)
        return y
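
    # In this "full" variant the mask keeps every latent frame position: the result is
    # [b, 4 + 16, t_latent, h//8, w//8], i.e. 4 mask channels concatenated with the
    # 16 VAE latent channels of the whole (masked) video.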

    def encode_video_with_mask_sparse(self, video, height, width, condition_preserved_mask, sketch_local_mask=None):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            batch_size = video.shape[0]
            cond_indices = binary_tensor_to_indices(condition_preserved_mask)
            sequence_cond_compressed_indices = [(cond_index + 3) // 4 for cond_index in cond_indices]
            video = video.to(self.device)
            video_latent = self.vae.encode(video, device=self.device)
            video_latent = video_latent[:, :, sequence_cond_compressed_indices[0], :, :]
            msk = condition_preserved_mask.to(self.device)
            msk = msk.unsqueeze(-1).unsqueeze(-1)  # b, t, 1, 1
            msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
            msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
            msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)  # b, t, 4, h//8, w//8
            msk = msk.transpose(1, 2)  # b, 4, t, h//8, w//8
            msk = msk[:, :, sequence_cond_compressed_indices[0], :, :]
            if sketch_local_mask is not None:
                sketch_local_mask = sketch_local_mask.to(self.device)
                if sketch_local_mask.shape[-2:] != (height//8, width//8):
                    sk_batch_t = sketch_local_mask.shape[0] * sketch_local_mask.shape[2]
                    sketch_local_mask_reshaped = sketch_local_mask.reshape(sk_batch_t, 1, sketch_local_mask.shape[3], sketch_local_mask.shape[4])
                    sketch_local_mask_resized = torch.nn.functional.interpolate(
                        sketch_local_mask_reshaped,
                        size=(height//8, width//8),
                        mode='nearest'
                    )
                    sketch_local_mask_resized = sketch_local_mask_resized.reshape(
                        sketch_local_mask.shape[0],
                        sketch_local_mask.shape[1],
                        sketch_local_mask.shape[2],
                        height//8, width//8
                    )
                else:
                    sketch_local_mask_resized = sketch_local_mask
                sketch_mask = sketch_local_mask_resized
                sketch_mask = torch.concat([torch.repeat_interleave(sketch_mask[:, :, 0:1], repeats=4, dim=2), sketch_mask[:, :, 1:]], dim=2)
                sketch_mask = sketch_mask.view(batch_size, sketch_mask.shape[1], sketch_mask.shape[2] // 4, 4, height//8, width//8)
                sketch_mask = sketch_mask.permute(0, 1, 3, 2, 4, 5)  # [b, 1, 4, t//4, h//8, w//8]
                sketch_mask = sketch_mask.view(batch_size, 4, sketch_mask.shape[3], height//8, width//8)  # [b, 4, t//4, h//8, w//8]
                sketch_mask = sketch_mask[:, :, sequence_cond_compressed_indices[0], :, :]  # [b, 4, len(indices), h//8, w//8]
                combined_latent = torch.cat([msk, video_latent, sketch_mask], dim=1)
            else:
                combined_latent = torch.concat([msk, video_latent], dim=1)
        return combined_latent, sequence_cond_compressed_indices  # b, c=(4+16+4=24), t, h, w when sketch_local_mask is provided
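
    # Channel layout of the sparse conditioning (as noted in the return comment above):
    # 4 mask channels + 16 VAE latent channels, plus 4 more sketch-mask channels when
    # `sketch_local_mask` is given; only the temporally compressed indices of the
    # conditioned frames are kept along the time axis.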

    def encode_image_or_masked_video(self, image_or_masked_video, num_frames, height, width, condition_preserved_mask=None):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            # PIL images carry no .shape attribute; treat them as a single sample.
            batch_size = 1 if isinstance(image_or_masked_video, PIL.Image.Image) else image_or_masked_video.shape[0]
            if isinstance(image_or_masked_video, PIL.Image.Image) or (isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() <= 4):
                if isinstance(image_or_masked_video, PIL.Image.Image):
                    image_or_masked_video = self.preprocess_image(image_or_masked_video.resize((width, height))).to(self.device)
                else:
                    if image_or_masked_video.dim() == 3:
                        image_or_masked_video = image_or_masked_video.unsqueeze(0)  # b=1, c, h, w
                    image_or_masked_video = image_or_masked_video.to(self.device)
                y = self.vae.encode([torch.concat([image_or_masked_video.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image_or_masked_video.device)], dim=1)], device=self.device)
                msk_idx_to_be_zero = range(1, num_frames)
                clip_context = self.image_encoder.encode_image(image_or_masked_video.unsqueeze(1))  # needs to be [b, 1, c, h, w]
                msk = torch.ones(batch_size, num_frames, height//8, width//8, device=self.device)
                msk[:, msk_idx_to_be_zero] = 0
                msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
                msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)
                msk = msk.transpose(1, 2)
            elif isinstance(image_or_masked_video, torch.Tensor) and image_or_masked_video.dim() == 5:
                image_or_masked_video = image_or_masked_video.to(self.device)
                first_image = image_or_masked_video[:, :, 0, :, :].unsqueeze(1)
                clip_context = self.image_encoder.encode_image(first_image)
                y = self.vae.encode(image_or_masked_video, device=self.device)
                msk = condition_preserved_mask  # b, t
                assert msk is not None, "The mask must be provided for the masked video input."
                assert msk.dim() == 2, "The mask must be a 2D tensor in [b, t]."
                assert msk.shape[0] == batch_size, "The batch size of the mask must be the same as the input video."
                assert msk.shape[1] == num_frames, "The number of frames in the mask must be the same as the input video."
                msk = msk.to(self.device)
                msk = msk.unsqueeze(-1).unsqueeze(-1)  # b, t, 1, 1
                msk = repeat(msk, 'b t 1 1 -> b t h w', h=height//8, w=width//8)
                msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
                msk = msk.view(batch_size, msk.shape[1] // 4, 4, height//8, width//8)  # b, t, 4, h//8, w//8
                msk = msk.transpose(1, 2)  # b, 4, t, h//8, w//8
            else:
                raise ValueError("Input must be an image (PIL/Tensor in [b, c, h, w]) or a masked video (Tensor in [b, c, t, h, w]).")
            y = torch.concat([msk, y], dim=1)
        return {"clip_fea": clip_context, "y": y}

    def tensor2video(self, frames):
        frames = rearrange(frames, "C T H W -> T H W C")
        frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
        frames = [PIL.Image.fromarray(frame) for frame in frames]
        return frames

    def prepare_extra_input(self, latents=None):
        return {"seq_len": latents.shape[2] * latents.shape[3] * latents.shape[4] // 4}

    def encode_video(self, input_video, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            latents = self.vae.encode(input_video, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return latents

    def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
        return frames

    def __call__(
        self,
        prompt,
        negative_prompt="",
        input_image=None,
        input_video=None,
        denoising_strength=1.0,
        seed=None,
        rand_device="cpu",
        height=480,
        width=832,
        num_frames=81,
        cfg_scale=5.0,
        num_inference_steps=50,
        sigma_shift=5.0,
        tiled=True,
        tile_size=(30, 52),
        tile_stride=(15, 26),
        progress_bar_cmd=tqdm,
        # progress_bar_st=None,
        input_condition_video=None,
        input_condition_preserved_mask=None,
        input_condition_video_sketch=None,
        input_condition_preserved_mask_sketch=None,
        sketch_local_mask=None,
        visualize_attention=False,
        output_path=None,
        batch_idx=None,
        sequence_cond_residual_scale=1.0,
    ):
        height, width = self.check_resize_height_width(height, width)
        if num_frames % 4 != 1:
            num_frames = (num_frames + 2) // 4 * 4 + 1
            print(f"Only `num_frames % 4 == 1` is acceptable. We round it up to {num_frames}.")

        tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

        # Scheduler and initial latents
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength, shift=sigma_shift)
        noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=torch.float32).to(self.device)
        if input_video is not None:
            self.load_models_to_device(['vae'])
            input_video = self.preprocess_images(input_video)
            input_video = torch.stack(input_video, dim=2)
            latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=noise.dtype, device=noise.device)
            latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
        else:
            latents = noise

        # Encode prompts
        self.load_models_to_device(["text_encoder"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        # Image / masked-video conditioning
        self.load_models_to_device(["image_encoder", "vae"])
        if input_image is not None and self.image_encoder is not None:
            image_emb = self.encode_image(input_image, num_frames, height, width)
        elif input_condition_video is not None and self.image_encoder is not None:
            assert input_condition_preserved_mask is not None, "`input_condition_preserved_mask` must not be None when `input_condition_video` is given."
            image_emb = self.encode_image_or_masked_video(input_condition_video, num_frames, height, width, input_condition_preserved_mask)
        else:
            image_emb = {}

        # Extra input
        extra_input = self.prepare_extra_input(latents)
        if self.dit.use_sequence_cond:
            assert input_condition_video_sketch is not None, "`input_condition_video_sketch` must not be None when `use_sequence_cond` is True."
            assert input_condition_preserved_mask_sketch is not None, "`input_condition_preserved_mask_sketch` must not be None when `input_condition_video_sketch` is given."
            if self.dit.sequence_cond_mode == "sparse":
                sequence_cond, sequence_cond_compressed_indices = self.encode_video_with_mask_sparse(input_condition_video_sketch, height, width, input_condition_preserved_mask_sketch, sketch_local_mask)
                extra_input.update({"sequence_cond": sequence_cond,
                                    "sequence_cond_compressed_indices": sequence_cond_compressed_indices})
            elif self.dit.sequence_cond_mode == "full":
                sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
                extra_input.update({"sequence_cond": sequence_cond})
            else:
                raise ValueError(f"Invalid `sequence_cond_mode`={self.dit.sequence_cond_mode} in the DIT model.")
        elif self.dit.use_channel_cond:
            sequence_cond = self.encode_video_with_mask(input_condition_video_sketch, num_frames, height, width, input_condition_preserved_mask_sketch)
            extra_input.update({"channel_cond": sequence_cond})
        self.load_models_to_device([])
        if sequence_cond_residual_scale != 1.0:
            extra_input.update({"sequence_cond_residual_scale": sequence_cond_residual_scale})

        # Denoise
        self.load_models_to_device(["dit"])
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
                timestep = timestep.unsqueeze(0).to(dtype=torch.float32, device=self.device)

                _should_visualize_attention = visualize_attention and (progress_id == len(self.scheduler.timesteps) - 1)
                if _should_visualize_attention:
                    print(f"Visualizing attention maps (Step {progress_id + 1}/{len(self.scheduler.timesteps)}).")
                    propagate_visualize_attention_arg(self.dit, True)

                # Inference
                noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
                if isinstance(noise_pred_posi, tuple):
                    noise_pred_posi = noise_pred_posi[0]
                if cfg_scale != 1.0:
                    noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
                    if isinstance(noise_pred_nega, tuple):
                        noise_pred_nega = noise_pred_nega[0]
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    noise_pred = noise_pred_posi

                # Scheduler
                latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

                # If visualization is enabled, save the attention maps
                if _should_visualize_attention:
                    print("Saving attention maps...")
                    from util.model_util import save_attention_maps
                    save_attention_maps(self.dit, output_path, batch_idx, timestep.squeeze().cpu().numpy().item())
                    propagate_visualize_attention_arg(self.dit, False)

        # Decode
        self.load_models_to_device(['vae'])
        frames = self.decode_video(latents, **tiler_kwargs)
        self.load_models_to_device([])

        return frames
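

# Minimal end-to-end sketch (illustrative only; checkpoint paths, prompt, and config
# values are placeholders, not the repository's actual filenames or defaults):
#
#   pipe = WanVideoPipeline(device="cuda", torch_dtype=torch.bfloat16)
#   pipe.fetch_models_from_checkpoints({
#       "text_encoder": "checkpoints/text_encoder.safetensors",
#       "dit": "checkpoints/dit.safetensors",
#       "vae": "checkpoints/vae.safetensors",
#       "image_encoder": "checkpoints/image_encoder.safetensors",
#   })
#   pipe.enable_vram_management()
#   video = pipe(
#       prompt="a corgi running on the beach",
#       input_image=PIL.Image.open("first_frame.png"),
#       num_frames=81, height=480, width=832,
#       num_inference_steps=50, cfg_scale=5.0, seed=0,
#   )
#   frames = pipe.tensor2video(video[0])  # assuming the decoder output is batched [b, c, t, h, w]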