bubbliiiing committed
Commit: c2a6cd2
Parent(s): b0f1243

Update V5.1
Files changed:
- app.py +9 -18
- config/easyanimate_video_v5.1_magvit_qwen.yaml +21 -0
- easyanimate/api/api.py +1 -1
- easyanimate/api/post_infer.py +2 -2
- easyanimate/data/dataset_image_video.py +220 -32
- easyanimate/models/__init__.py +3 -4
- easyanimate/models/attention.py +60 -31
- easyanimate/models/autoencoder_magvit.py +15 -117
- easyanimate/models/embeddings.py +3 -2
- easyanimate/models/norm.py +16 -0
- easyanimate/models/processor.py +146 -0
- easyanimate/models/transformer3d.py +280 -43
- easyanimate/pipeline/pipeline_easyanimate.py +730 -486
- easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py} +448 -229
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +0 -0
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +0 -925
- easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +0 -1334
- easyanimate/ui/ui.py +237 -179
- easyanimate/utils/lora_utils.py +42 -30
- easyanimate/utils/utils.py +53 -33
- easyanimate/vae/ldm/models/autoencoder.py +4 -4
- easyanimate/vae/ldm/models/casual3dcnn.py +5 -5
- easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +5 -5
- easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +13 -9
- easyanimate/vae/ldm/models/omnigen_enc_dec.py +6 -2
- easyanimate/vae/ldm/modules/losses/contperceptual.py +20 -3
- easyanimate/vae/ldm/modules/vaemodules/__init__.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/activations.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/common.py +39 -5
- easyanimate/vae/ldm/modules/vaemodules/down_blocks.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py +0 -0
- easyanimate/vae/ldm/modules/vaemodules/up_blocks.py +0 -0
- requirements.txt +2 -5
    	
app.py
CHANGED

@@ -19,6 +19,9 @@ if __name__ == "__main__":
     # 
     # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use, 
     # resulting in slower speeds but saving a large amount of GPU memory.
+    # 
+    # EasyAnimateV1, V2 and V3 support "model_cpu_offload" "sequential_cpu_offload"
+    # EasyAnimateV4, V5 and V5.1 support "model_cpu_offload" "model_cpu_offload_and_qfloat8" "sequential_cpu_offload"
     GPU_memory_mode = "model_cpu_offload_and_qfloat8"
     # Use torch.float16 if GPU does not support torch.bfloat16
     # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
@@ -29,11 +32,11 @@ if __name__ == "__main__":
     server_port = 7860

     # Params below is used when ui_mode = "modelscope"
-    edition = "v5"
+    edition = "v5.1"
     # Config
-    config_path = "config/
+    config_path = "config/easyanimate_video_v5.1_magvit_qwen.yaml"
     # Model path of the pretrained model
-    model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    model_name = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
     # "Inpaint" or "Control"
     model_type = "Inpaint"
     # Save dir
@@ -46,18 +49,6 @@ if __name__ == "__main__":
     else:
         demo, controller = ui(GPU_memory_mode, weight_dtype)

-
-
-
-        server_port=server_port,
-        prevent_thread_lock=True
-    )
-
-    # launch api
-    infer_forward_api(None, app, controller)
-    update_diffusion_transformer_api(None, app, controller)
-    update_edition_api(None, app, controller)
-
-    # not close the python
-    while True:
-        time.sleep(5)
+    demo.launch(
+        server_name=server_name, server_port=server_port
+    )
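For anyone editing app.py by hand, the snippet below restates the offload options from the new comments as a small validation helper. It is only an illustrative sketch, not code from this commit; the helper name, the dict, and the error message are invented for the example.

    # Hypothetical helper restating the offload modes documented in the new app.py comments.
    SUPPORTED_GPU_MEMORY_MODES = {
        # EasyAnimateV1, V2 and V3
        "v1": {"model_cpu_offload", "sequential_cpu_offload"},
        "v2": {"model_cpu_offload", "sequential_cpu_offload"},
        "v3": {"model_cpu_offload", "sequential_cpu_offload"},
        # EasyAnimateV4, V5 and V5.1
        "v4": {"model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"},
        "v5": {"model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"},
        "v5.1": {"model_cpu_offload", "model_cpu_offload_and_qfloat8", "sequential_cpu_offload"},
    }

    def check_gpu_memory_mode(edition: str, gpu_memory_mode: str) -> None:
        """Raise if the requested offload mode is not listed for this edition."""
        allowed = SUPPORTED_GPU_MEMORY_MODES[edition]
        if gpu_memory_mode not in allowed:
            raise ValueError(
                f"{gpu_memory_mode!r} is not supported for edition {edition}; "
                f"choose one of {sorted(allowed)}"
            )

    check_gpu_memory_mode("v5.1", "model_cpu_offload_and_qfloat8")  # passes silently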
    	
config/easyanimate_video_v5.1_magvit_qwen.yaml
ADDED

@@ -0,0 +1,21 @@
+transformer_additional_kwargs:
+  transformer_type:                           "EasyAnimateTransformer3DModel"
+  after_norm:                                 false
+  time_position_encoding_type:                "3d_rope"
+  resize_inpaint_mask_directly:               true
+  enable_text_attention_mask:                 true
+  enable_clip_in_inpaint:                     false
+  add_ref_latent_in_control_model:            true
+
+vae_kwargs:
+  vae_type: "AutoencoderKLMagvit"
+  mini_batch_encoder: 4
+  mini_batch_decoder: 1
+  slice_mag_vae: false
+  slice_compression_vae: false
+  cache_compression_vae: false
+  cache_mag_vae: true
+
+text_encoder_kwargs:
+  enable_multi_text_encoder: false
+  replace_t5_to_llm: true
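For orientation, the new YAML file can be inspected with a generic loader as sketched below. This is not how EasyAnimate itself consumes the config (that happens inside the repo's own loading code); it only shows the keys added above.

    # Minimal sketch: inspect the new V5.1 config with a generic YAML loader.
    # The keys printed here come straight from the file added in this commit.
    import yaml

    with open("config/easyanimate_video_v5.1_magvit_qwen.yaml", "r") as f:
        config = yaml.safe_load(f)

    print(config["transformer_additional_kwargs"]["transformer_type"])  # "EasyAnimateTransformer3DModel"
    print(config["vae_kwargs"]["vae_type"])                             # "AutoencoderKLMagvit"
    print(config["text_encoder_kwargs"]["replace_t5_to_llm"])           # True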
    	
easyanimate/api/api.py
CHANGED

@@ -93,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics.')
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art.')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
         resize_method = datas.get('resize_method', "Generate by")
    	
easyanimate/api/post_infer.py
CHANGED

@@ -54,14 +54,14 @@ if __name__ == '__main__':
     # -------------------------- #
     #  Step 1: update edition
     # -------------------------- #
-    edition = "v5"
+    edition = "v5.1"
     outputs = post_update_edition(edition)
     print('Output update edition: ', outputs)

     # -------------------------- #
     #  Step 2: update edition
     # -------------------------- #
-    diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP"
+    diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP"
     outputs = post_diffusion_transformer(diffusion_transformer_path)
     print('Output update edition: ', outputs)

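post_infer.py drives the HTTP API through its post_update_edition / post_diffusion_transformer helpers. Purely as an illustration of the request shape, a hand-rolled call might look like the sketch below; the endpoint URL and port are assumptions for the example (use the helpers in easyanimate/api/post_infer.py in practice), while the field names mirror the datas.get(...) keys visible in the api.py hunk above.

    # Illustrative client payload; field names follow the datas.get(...) keys in api.py.
    # The URL below is an assumption, not taken from this commit.
    import json
    from urllib import request

    payload = {
        "prompt_textbox": "A panda playing a guitar on a mountain top.",
        # Omit this key to fall back to the new default negative prompt
        # ("... comics, text subtitles, line art.").
        "negative_prompt_textbox": "Blurring, mutation, deformation, distortion.",
        "sampler_dropdown": "Euler",
        "sample_step_slider": 30,
        "lora_model_path": "none",
        "lora_alpha_slider": 0.55,
    }

    req = request.Request(
        "http://127.0.0.1:7860/easyanimate/infer_forward",  # hypothetical endpoint path
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with request.urlopen(req) as resp:
        print(resp.read().decode("utf-8"))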
    	
easyanimate/data/dataset_image_video.py
CHANGED

@@ -12,9 +12,12 @@ import albumentations
 import cv2
 import numpy as np
 import torch
+import torch.nn.functional as F
 import torchvision.transforms as transforms
 from decord import VideoReader
+from einops import rearrange
 from func_timeout import FunctionTimedOut, func_timeout
+from packaging import version as pver
 from PIL import Image
 from torch.utils.data import BatchSampler, Sampler
 from torch.utils.data.dataset import Dataset
@@ -100,6 +103,152 @@ def get_random_mask(shape):
     else:
         raise ValueError(f"The mask_index {mask_index} is not define")
     return mask
+
+class Camera(object):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    def __init__(self, entry):
+        fx, fy, cx, cy = entry[1:5]
+        self.fx = fx
+        self.fy = fy
+        self.cx = cx
+        self.cy = cy
+        w2c_mat = np.array(entry[7:]).reshape(3, 4)
+        w2c_mat_4x4 = np.eye(4)
+        w2c_mat_4x4[:3, :] = w2c_mat
+        self.w2c_mat = w2c_mat_4x4
+        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
+
+def custom_meshgrid(*args):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
+    if pver.parse(torch.__version__) < pver.parse('1.10'):
+        return torch.meshgrid(*args)
+    else:
+        return torch.meshgrid(*args, indexing='ij')
+
+def get_relative_pose(cam_params):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
+    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
+    cam_to_origin = 0
+    target_cam_c2w = np.array([
+        [1, 0, 0, 0],
+        [0, 1, 0, -cam_to_origin],
+        [0, 0, 1, 0],
+        [0, 0, 0, 1]
+    ])
+    abs2rel = target_cam_c2w @ abs_w2cs[0]
+    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
+    ret_poses = np.array(ret_poses, dtype=np.float32)
+    return ret_poses
+
+def ray_condition(K, c2w, H, W, device):
+    """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    # c2w: B, V, 4, 4
+    # K: B, V, 4
+
+    B = K.shape[0]
+
+    j, i = custom_meshgrid(
+        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+    )
+    i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
+    j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5  # [B, HxW]
+
+    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B,V, 1
+
+    zs = torch.ones_like(i)  # [B, HxW]
+    xs = (i - cx) / fx * zs
+    ys = (j - cy) / fy * zs
+    zs = zs.expand_as(ys)
+
+    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
+    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
+
+    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, 3, HW
+    rays_o = c2w[..., :3, 3]  # B, V, 3
+    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, 3, HW
+    # c2w @ dirctions
+    rays_dxo = torch.cross(rays_o, rays_d)
+    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
+    # plucker = plucker.permute(0, 1, 4, 2, 3)
+    return plucker
+
+def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
+    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    with open(pose_file_path, 'r') as f:
+        poses = f.readlines()
+
+    poses = [pose.strip().split(' ') for pose in poses[1:]]
+    cam_params = [[float(x) for x in pose] for pose in poses]
+    if return_poses:
+        return cam_params
+    else:
+        cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+        sample_wh_ratio = width / height
+        pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
+
+        if pose_wh_ratio > sample_wh_ratio:
+            resized_ori_w = height * pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fx = resized_ori_w * cam_param.fx / width
+        else:
+            resized_ori_h = width / pose_wh_ratio
+            for cam_param in cam_params:
+                cam_param.fy = resized_ori_h * cam_param.fy / height
+
+        intrinsic = np.asarray([[cam_param.fx * width,
+                                cam_param.fy * height,
+                                cam_param.cx * width,
+                                cam_param.cy * height]
+                                for cam_param in cam_params], dtype=np.float32)
+
+        K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
+        c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
+        c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
+        plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+        plucker_embedding = plucker_embedding[None]
+        plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+        return plucker_embedding
+
+def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
+    """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+    """
+    cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+    sample_wh_ratio = width / height
+    pose_wh_ratio = original_pose_width / original_pose_height  # Assuming placeholder ratios, change as needed
+
+    if pose_wh_ratio > sample_wh_ratio:
+        resized_ori_w = height * pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fx = resized_ori_w * cam_param.fx / width
+    else:
+        resized_ori_h = width / pose_wh_ratio
+        for cam_param in cam_params:
+            cam_param.fy = resized_ori_h * cam_param.fy / height
+
+    intrinsic = np.asarray([[cam_param.fx * width,
+                            cam_param.fy * height,
+                            cam_param.cx * width,
+                            cam_param.cy * height]
+                            for cam_param in cam_params], dtype=np.float32)
+
+    K = torch.as_tensor(intrinsic)[None]  # [1, 1, 4]
+    c2ws = get_relative_pose(cam_params)  # Assuming this function is defined elsewhere
+    c2ws = torch.as_tensor(c2ws)[None]  # [1, n_frame, 4, 4]
+    plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
+    plucker_embedding = plucker_embedding[None]
+    plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+    return plucker_embedding

 class ImageVideoSampler(BatchSampler):
     """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
@@ -184,7 +333,7 @@ class ImageVideoDataset(Dataset):
         video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
         image_sample_size=512,
         video_repeat=0,
-        text_drop_ratio
+        text_drop_ratio=0.1,
         enable_bucket=False,
         video_length_drop_start=0.1, 
         video_length_drop_end=0.9,
@@ -355,7 +504,6 @@ class ImageVideoDataset(Dataset):

         return sample

-
 class ImageVideoControlDataset(Dataset):
     def __init__(
         self,
@@ -363,11 +511,12 @@ class ImageVideoControlDataset(Dataset):
         video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
         image_sample_size=512,
         video_repeat=0,
-        text_drop_ratio
+        text_drop_ratio=0.1,
         enable_bucket=False,
         video_length_drop_start=0.1, 
         video_length_drop_end=0.9,
         enable_inpaint=False,
+        enable_camera_info=False,
     ):
         # Loading annotations from files
         print(f"loading annotations from {ann_path} ...")
@@ -397,6 +546,7 @@ class ImageVideoControlDataset(Dataset):
         self.enable_bucket = enable_bucket
         self.text_drop_ratio = text_drop_ratio
         self.enable_inpaint  = enable_inpaint
+        self.enable_camera_info = enable_camera_info

         self.video_length_drop_start = video_length_drop_start
         self.video_length_drop_end = video_length_drop_end
@@ -412,6 +562,13 @@ class ImageVideoControlDataset(Dataset):
                 transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
             ]
         )
+        if self.enable_camera_info:
+            self.video_transforms_camera = transforms.Compose(
+                [
+                    transforms.Resize(min(self.video_sample_size)),
+                    transforms.CenterCrop(self.video_sample_size)
+                ]
+            )

         # Image params
         self.image_sample_size  = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
@@ -484,33 +641,59 @@ class ImageVideoControlDataset(Dataset):
             else:
                control_video_id = os.path.join(self.data_root, control_video_id)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                if not self.enable_bucket:
-                    control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
-                    control_pixel_values = control_pixel_values / 255.
-                    del control_video_reader
+            if self.enable_camera_info:
+                if control_video_id.lower().endswith('.txt'):
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.zeros_like(pixel_values)
+
+                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
+                        control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
+                        control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
+                        control_camera_values = self.video_transforms_camera(control_camera_values)
+                    else:
+                        control_pixel_values = np.zeros_like(pixel_values)
+
+                        control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
+                        control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
+                        control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
+                        control_camera_values = np.array([control_camera_values[index] for index in batch_index])
                 else:
-
-
-
-
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.zeros_like(pixel_values)
+                        control_camera_values = None
+                    else:
+                        control_pixel_values = np.zeros_like(pixel_values)
+                        control_camera_values = None
+            else:
+                with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
+                    try:
+                        sample_args = (control_video_reader, batch_index)
+                        control_pixel_values = func_timeout(
+                            VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
+                        )
+                        resized_frames = []
+                        for i in range(len(control_pixel_values)):
+                            frame = control_pixel_values[i]
+                            resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
+                            resized_frames.append(resized_frame)
+                        control_pixel_values = np.array(resized_frames)
+                    except FunctionTimedOut:
+                        raise ValueError(f"Read {idx} timeout.")
+                    except Exception as e:
+                        raise ValueError(f"Failed to extract frames from video. Error is {e}.")
+
+                    if not self.enable_bucket:
+                        control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
+                        control_pixel_values = control_pixel_values / 255.
+                        del control_video_reader
+                    else:
+                        control_pixel_values = control_pixel_values

+                    if not self.enable_bucket:
+                        control_pixel_values = self.video_transforms(control_pixel_values)
+                control_camera_values = None
+
+            return pixel_values, control_pixel_values, control_camera_values, text, "video"
         else:
             image_path, text = data_info['file_path'], data_info['text']
             if self.data_root is not None:
@@ -536,7 +719,8 @@ class ImageVideoControlDataset(Dataset):
                 control_image = self.image_transforms(control_image).unsqueeze(0)
             else:
                 control_image = np.expand_dims(np.array(control_image), 0)
-
+
+            return image, control_image, None, text, 'image'

     def __len__(self):
         return self.length
@@ -552,13 +736,17 @@ class ImageVideoControlDataset(Dataset):
                 if data_type_local != data_type:
                     raise ValueError("data_type_local != data_type")

-                pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
+                pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx)
+
                 sample["pixel_values"] = pixel_values
                 sample["control_pixel_values"] = control_pixel_values
                 sample["text"] = name
                 sample["data_type"] = data_type
                 sample["idx"] = idx
-
+
+                if self.enable_camera_info:
+                    sample["control_camera_values"] = control_camera_values
+
                 if len(sample) > 0:
                     break
             except Exception as e:
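Most of the dataset change is the CameraCtrl-derived pose utilities added above, which convert a camera trajectory file into per-frame Plücker embeddings. A standalone sketch of exercising them outside the dataset class follows; the pose file path is a placeholder, and the expected .txt layout (a header line, then per-frame rows with intrinsics in columns 1:5 and a 3x4 world-to-camera matrix from column 7 onward) is inferred from Camera.__init__ and process_pose_file.

    # Minimal sketch: turn a CameraCtrl-style pose .txt into Plücker embeddings
    # using the functions added in easyanimate/data/dataset_image_video.py.
    # "poses/example_traj.txt" is a placeholder path, not a file from this commit.
    from easyanimate.data.dataset_image_video import process_pose_file

    plucker = process_pose_file(
        "poses/example_traj.txt",              # placeholder: header line, then per-frame pose rows
        width=672, height=384,                 # target sample resolution
        original_pose_width=1280, original_pose_height=720,
    )
    # ray_condition packs (rays_o x rays_d, rays_d) per pixel, i.e. 6 channels,
    # and process_pose_file rearranges the result to frames x H x W x 6.
    print(plucker.shape)                       # e.g. (n_frames, 384, 672, 6)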
    	
easyanimate/models/__init__.py
CHANGED

@@ -1,8 +1,7 @@
-from .autoencoder_magvit import (
+from .autoencoder_magvit import (AutoencoderKL, AutoencoderKLCogVideoX,
+                                 AutoencoderKLMagvit)
 from .transformer3d import (EasyAnimateTransformer3DModel,
-
-                                             Transformer3DModel)
-
+                            HunyuanTransformer3DModel, Transformer3DModel)

 name_to_transformer3d = {
     "Transformer3DModel": Transformer3DModel,
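This __init__.py hunk only reorganizes imports, but the module also exposes the name_to_transformer3d registry that pairs with the transformer_type key in the new V5.1 config. A minimal lookup sketch follows; it assumes "EasyAnimateTransformer3DModel" is registered in that dict, since only the first entry is visible in the hunk above.

    # Sketch: resolve the transformer class named in the V5.1 config via the
    # registry exported from easyanimate/models/__init__.py. Assumes the
    # "EasyAnimateTransformer3DModel" key exists; only the first entry of the
    # dict is shown in the diff.
    from easyanimate.models import name_to_transformer3d

    transformer_type = "EasyAnimateTransformer3DModel"  # from transformer_additional_kwargs
    transformer_cls = name_to_transformer3d[transformer_type]
    print(transformer_cls.__name__)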
    	
        easyanimate/models/attention.py
    CHANGED
    
    | @@ -29,7 +29,7 @@ from diffusers.models.embeddings import (SinusoidalPositionalEmbedding, | |
| 29 | 
             
                                                     get_3d_sincos_pos_embed)
         | 
| 30 | 
             
            from diffusers.models.modeling_outputs import Transformer2DModelOutput
         | 
| 31 | 
             
            from diffusers.models.modeling_utils import ModelMixin
         | 
| 32 | 
            -
            from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero, | 
| 33 | 
             
                                                        CogVideoXLayerNormZero)
         | 
| 34 | 
             
            from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
         | 
| 35 | 
             
            from diffusers.utils.import_utils import is_xformers_available
         | 
| @@ -38,12 +38,11 @@ from einops import rearrange, repeat | |
| 38 | 
             
            from torch import nn
         | 
| 39 |  | 
| 40 | 
             
            from .motion_module import PositionalEncoding, get_motion_module
         | 
| 41 | 
            -
            from .norm import AdaLayerNormShift,  | 
| 42 | 
             
            from .processor import (EasyAnimateAttnProcessor2_0,
         | 
|  | |
| 43 | 
             
                                    LazyKVCompressionProcessor2_0)
         | 
| 44 |  | 
| 45 | 
            -
             | 
| 46 | 
            -
             | 
| 47 | 
             
            if is_xformers_available():
         | 
| 48 | 
             
                import xformers
         | 
| 49 | 
             
                import xformers.ops
         | 
| @@ -1042,7 +1041,9 @@ class EasyAnimateDiTBlock(nn.Module): | |
| 1042 | 
             
                    ff_bias: bool = True,
         | 
| 1043 | 
             
                    qk_norm: bool = True,
         | 
| 1044 | 
             
                    after_norm: bool = False,
         | 
| 1045 | 
            -
                    norm_type: str="fp32_layer_norm"
         | 
|  | |
|  | |
| 1046 | 
             
                ):
         | 
| 1047 | 
             
                    super().__init__()
         | 
| 1048 |  | 
| @@ -1051,6 +1052,7 @@ class EasyAnimateDiTBlock(nn.Module): | |
| 1051 | 
             
                        time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
         | 
| 1052 | 
             
                                             get_3d_sincos_pos_embed)
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero,
                                             CogVideoXLayerNormZero)
 from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging
 from diffusers.utils.import_utils import is_xformers_available

 from torch import nn

 from .motion_module import PositionalEncoding, get_motion_module
+from .norm import AdaLayerNormShift, EasyAnimateLayerNormZero, FP32LayerNorm
 from .processor import (EasyAnimateAttnProcessor2_0,
+                        EasyAnimateSWAttnProcessor2_0,
                         LazyKVCompressionProcessor2_0)

 if is_xformers_available():
     import xformers
     import xformers.ops

         ff_bias: bool = True,
         qk_norm: bool = True,
         after_norm: bool = False,
+        norm_type: str = "fp32_layer_norm",
+        is_mmdit_block: bool = True,
+        is_swa: bool = False,
     ):
         super().__init__()

             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
         )

+        self.is_swa = is_swa
         self.attn1 = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
@@ -1058,17 +1060,20 @@ class EasyAnimateDiTBlock(nn.Module):
             qk_norm="layer_norm" if qk_norm else None,
             eps=1e-6,
             bias=True,
-            processor=EasyAnimateAttnProcessor2_0(),
-        )
-        self.attn2 = Attention(
-            query_dim=dim,
-            dim_head=attention_head_dim,
-            heads=num_attention_heads,
-            qk_norm="layer_norm" if qk_norm else None,
-            eps=1e-6,
-            bias=True,
-            processor=EasyAnimateAttnProcessor2_0(),
+            processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
         )
+        if is_mmdit_block:
+            self.attn2 = Attention(
+                query_dim=dim,
+                dim_head=attention_head_dim,
+                heads=num_attention_heads,
+                qk_norm="layer_norm" if qk_norm else None,
+                eps=1e-6,
+                bias=True,
+                processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(),
+            )
+        else:
+            self.attn2 = None

         # FFN Part
         self.norm2 = EasyAnimateLayerNormZero(
@@ -1082,14 +1087,18 @@ class EasyAnimateDiTBlock(nn.Module):
             inner_dim=ff_inner_dim,
             bias=ff_bias,
         )
-        self.txt_ff = FeedForward(
-            dim,
-            dropout=dropout,
-            activation_fn=activation_fn,
-            final_dropout=final_dropout,
-            inner_dim=ff_inner_dim,
-            bias=ff_bias,
-        )
+        if is_mmdit_block:
+            self.txt_ff = FeedForward(
+                dim,
+                dropout=dropout,
+                activation_fn=activation_fn,
+                final_dropout=final_dropout,
+                inner_dim=ff_inner_dim,
+                bias=ff_bias,
+            )
+        else:
+            self.txt_ff = None
+
         if after_norm:
             self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         else:
@@ -1101,6 +1110,9 @@ class EasyAnimateDiTBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_frames = None,
+        height = None,
+        width = None
     ) -> torch.Tensor:
         # Norm
         norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
@@ -1108,12 +1120,23 @@ class EasyAnimateDiTBlock(nn.Module):
         )

         # Attn
-        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
-            hidden_states=norm_hidden_states,
-            encoder_hidden_states=norm_encoder_hidden_states,
-            image_rotary_emb=image_rotary_emb,
-            attn2=self.attn2,
-        )
+        if self.is_swa:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+                attn2=self.attn2,
+                num_frames=num_frames,
+                height=height,
+                width=width,
+            )
+        else:
+            attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+                hidden_states=norm_hidden_states,
+                encoder_hidden_states=norm_encoder_hidden_states,
+                image_rotary_emb=image_rotary_emb,
+                attn2=self.attn2
+            )
         hidden_states = hidden_states + gate_msa * attn_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

@@ -1125,10 +1148,16 @@ class EasyAnimateDiTBlock(nn.Module):
         # FFN
         if self.norm3 is not None:
             norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
-            norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+            if self.txt_ff is not None:
+                norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
+            else:
+                norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
         else:
             norm_hidden_states = self.ff(norm_hidden_states)
-            norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+            if self.txt_ff is not None:
+                norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
+            else:
+                norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
         hidden_states = hidden_states + gate_ff * norm_hidden_states
         encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states
         return hidden_states, encoder_hidden_states
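The block above now builds its second (text-stream) attention and feed-forward only when is_mmdit_block is set, and falls back to the shared modules otherwise. Below is a minimal standalone sketch of that fallback behaviour only; ToyDualStreamFFN and its layer sizes are made up for illustration and are not the repository API.

import torch
from torch import nn

class ToyDualStreamFFN(nn.Module):
    def __init__(self, dim: int, is_mmdit_block: bool = True):
        super().__init__()
        # Shared FFN, always present.
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Separate text-branch FFN only for MMDiT-style blocks, mirroring the diff.
        self.txt_ff = (
            nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
            if is_mmdit_block else None
        )

    def forward(self, vid_tokens: torch.Tensor, txt_tokens: torch.Tensor):
        vid_out = self.ff(vid_tokens)
        # Fall back to the shared FFN when no text branch was built.
        txt_out = self.txt_ff(txt_tokens) if self.txt_ff is not None else self.ff(txt_tokens)
        return vid_out, txt_out

x = torch.randn(1, 16, 64)   # video tokens
c = torch.randn(1, 4, 64)    # text tokens
print([t.shape for t in ToyDualStreamFFN(64)(x, c)])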
    	
        easyanimate/models/autoencoder_magvit.py
    CHANGED
    
@@ -44,6 +44,7 @@ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d,
                                                 CogVideoXDecoder3D,
                                                 CogVideoXEncoder3D,
                                                 CogVideoXSafeConv3d)
+from ..vae.ldm.models.omnigen_enc_dec import CausalConv3d
 from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
 from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder

@@ -96,6 +97,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         out_channels: int = 3,
         ch =  128,
         ch_mult = [ 1,2,4,4 ],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         down_block_types: tuple = None,
         up_block_types: tuple = None,
@@ -109,6 +111,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         scaling_factor: float = 0.1825,
+        force_upcast: float = True,
         slice_mag_vae=True,
         slice_compression_vae=False,
         cache_compression_vae=False,
@@ -130,8 +133,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             in_channels=in_channels,
             out_channels=latent_channels,
             down_block_types=down_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -154,8 +158,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             in_channels=latent_channels,
             out_channels=out_channels,
             up_block_types=up_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -196,81 +201,10 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
             module.gradient_checkpointing = value

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
-    def set_default_attn_processor(self):
-        """
-        Disables custom attention processors and sets the default attention implementation.
-        """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
-            processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
-            processor = AttnProcessor()
-        else:
-            raise ValueError(
-                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
-            )
-
-        self.set_attn_processor(processor)
+    def _clear_conv_cache(self):
+        for name, module in self.named_modules():
+            if isinstance(module, CausalConv3d):
+                module._clear_conv_cache()

     @apply_forward_hook
     def encode(
@@ -308,6 +242,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)

+        self._clear_conv_cache()
         if not return_dict:
             return (posterior,)

@@ -355,6 +290,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         else:
             decoded = self._decode(z).sample

+        self._clear_conv_cache()
         if not return_dict:
             return (decoded,)

@@ -519,44 +455,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):

         return DecoderOutput(sample=dec)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     @classmethod
     def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
         import json
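The new _clear_conv_cache helper walks every submodule and resets the temporal cache held by each CausalConv3d; it is called at the end of encode and decode. A small sketch of the same traversal idiom follows, assuming a hypothetical ToyCausalConv3d that caches the tail of the previous chunk (the real CausalConv3d lives in easyanimate/vae and is not reproduced here).

import torch
from torch import nn

class ToyCausalConv3d(nn.Conv3d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prev_features = None            # cached tail of the previous chunk (assumption)

    def _clear_conv_cache(self):
        self.prev_features = None            # drop the cache so the next clip starts fresh

class ToyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_in = ToyCausalConv3d(3, 8, kernel_size=3, padding=1)
        self.conv_out = ToyCausalConv3d(8, 3, kernel_size=3, padding=1)

    def _clear_conv_cache(self):
        # Same idiom as the diff: visit every submodule and reset causal-conv caches.
        for name, module in self.named_modules():
            if isinstance(module, ToyCausalConv3d):
                module._clear_conv_cache()

vae = ToyVAE()
vae._clear_conv_cache()   # would be called at the end of encode()/decode()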
    	
        easyanimate/models/embeddings.py
    CHANGED
    
@@ -4,8 +4,9 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
-from diffusers.models.embeddings import (PixArtAlphaTextProjection,
-                                         TimestepEmbedding, Timesteps)
+from diffusers.models.embeddings import (PixArtAlphaTextProjection,
+                                         TimestepEmbedding, Timesteps,
+                                         get_timestep_embedding)
 from einops import rearrange
 from torch import nn

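The only change here is importing get_timestep_embedding, the standard sinusoidal-embedding helper from diffusers, presumably for the new V5.1 embedding modules. A quick call showing what it returns (this uses the public diffusers helper, not repository code):

import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.tensor([0, 250, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=256)
print(emb.shape)   # torch.Size([3, 256]) - one sinusoidal vector per timestep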
    	
        easyanimate/models/norm.py
    CHANGED
    
@@ -25,6 +25,22 @@ class FP32LayerNorm(nn.LayerNorm):
                 inputs.float(), self.normalized_shape, None, None, self.eps
             ).to(origin_dtype)

+class EasyAnimateRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
 class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
     """
     For PixArt-Alpha.
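EasyAnimateRMSNorm computes y = w * x / sqrt(mean(x^2) + eps) in float32 and casts back to the input dtype. The standalone re-implementation below is for illustration only, to make the formula explicit; it is not part of the repository.

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x32 = x.to(torch.float32)
    variance = x32.pow(2).mean(-1, keepdim=True)          # mean of squares, computed in fp32
    return weight * (x32 * torch.rsqrt(variance + eps)).to(orig_dtype)

x = torch.randn(2, 5, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
y = rms_norm(x, w)
print(y.dtype, y.shape)   # torch.bfloat16 torch.Size([2, 5, 8])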
    	
        easyanimate/models/processor.py
    CHANGED
    
@@ -310,3 +310,149 @@ class EasyAnimateAttnProcessor2_0:
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
         return hidden_states, encoder_hidden_states
+
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import pad_input, unpad_input
+except:
+    print("Flash Attention is not installed. Please install with `pip install flash-attn`, if you want to use SWA.")
+
+class EasyAnimateSWAttnProcessor2_0:
+    def __init__(self, window_size=1024):
+        self.window_size = window_size
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        num_frames: int = None,
+        height: int = None,
+        width: int = None,
+        attn2: Attention = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attn2 is None:
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        if attn2 is not None:
+            query_txt = attn2.to_q(encoder_hidden_states)
+            key_txt = attn2.to_k(encoder_hidden_states)
+            value_txt = attn2.to_v(encoder_hidden_states)
+
+            inner_dim = key_txt.shape[-1]
+            head_dim = inner_dim // attn.heads
+
+            query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim)
+
+            if attn2.norm_q is not None:
+                query_txt = attn2.norm_q(query_txt)
+            if attn2.norm_k is not None:
+                key_txt = attn2.norm_k(key_txt)
+
+            query = torch.cat([query_txt, query], dim=2)
+            key = torch.cat([key_txt, key], dim=2)
+            value = torch.cat([value_txt, value], dim=1)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        query = query.transpose(1, 2).to(value)
+        key = key.transpose(1, 2).to(value)
+        interval = max((query.size(1) - text_seq_length) // (self.window_size - text_seq_length), 1)
+
+        cross_key = torch.cat([key[:, :text_seq_length], key[:, text_seq_length::interval]], dim=1)
+        cross_val = torch.cat([value[:, :text_seq_length], value[:, text_seq_length::interval]], dim=1)
+        cross_hidden_states = flash_attn_func(query, cross_key, cross_val, dropout_p=0.0, causal=False)
+
+        # Split and rearrange to six directions
+        querys = torch.tensor_split(query[:, text_seq_length:], 6, 2)
+        keys = torch.tensor_split(key[:, text_seq_length:], 6, 2)
+        values = torch.tensor_split(value[:, text_seq_length:], 6, 2)
+
+        new_querys = [querys[0]]
+        new_keys = [keys[0]]
+        new_values = [values[0]]
+        for index, mode in enumerate(
+            [
+                "bs (f h w) hn hd -> bs (f w h) hn hd",
+                "bs (f h w) hn hd -> bs (h f w) hn hd",
+                "bs (f h w) hn hd -> bs (h w f) hn hd",
+                "bs (f h w) hn hd -> bs (w f h) hn hd",
+                "bs (f h w) hn hd -> bs (w h f) hn hd"
+            ]
+        ):
+            new_querys.append(rearrange(querys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_keys.append(rearrange(keys[index + 1], mode, f=num_frames, h=height, w=width))
+            new_values.append(rearrange(values[index + 1], mode, f=num_frames, h=height, w=width))
+        query = torch.cat(new_querys, dim=2)
+        key = torch.cat(new_keys, dim=2)
+        value = torch.cat(new_values, dim=2)
+
+        # apply attention
+        hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False, window_size=(self.window_size, self.window_size))
+
+        hidden_states = torch.tensor_split(hidden_states, 6, 2)
+        new_hidden_states = [hidden_states[0]]
+        for index, mode in enumerate(
+            [
+                "bs (f w h) hn hd -> bs (f h w) hn hd",
+                "bs (h f w) hn hd -> bs (f h w) hn hd",
+                "bs (h w f) hn hd -> bs (f h w) hn hd",
+                "bs (w f h) hn hd -> bs (f h w) hn hd",
+                "bs (w h f) hn hd -> bs (f h w) hn hd"
+            ]
+        ):
+            new_hidden_states.append(rearrange(hidden_states[index + 1], mode, f=num_frames, h=height, w=width))
+        hidden_states = torch.cat([cross_hidden_states[:, :text_seq_length], torch.cat(new_hidden_states, dim=2)], dim=1) + cross_hidden_states
+
+        hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
+
+        if attn2 is None:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            encoder_hidden_states = attn2.to_out[0](encoder_hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn2.to_out[1](encoder_hidden_states)
+        return hidden_states, encoder_hidden_states
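The sliding-window processor above splits the attention heads into six groups and permutes the (frame, height, width) token order per group, so the same one-dimensional flash-attention window scans the video along different axes; the inverse rearranges restore the original order afterwards. The toy snippet below illustrates only that permutation step, with made-up sizes and einops as the sole dependency; it is not the repository code path.

import torch
from einops import rearrange

bs, f, h, w, hn, hd = 1, 4, 3, 2, 6, 8          # 6 head groups, one per scan order
tokens = torch.randn(bs, f * h * w, hn, hd)

groups = torch.tensor_split(tokens, 6, dim=2)   # same split as the processor (dim=2 is heads)
scan_orders = [
    "bs (f h w) hn hd -> bs (f w h) hn hd",
    "bs (f h w) hn hd -> bs (h f w) hn hd",
    "bs (f h w) hn hd -> bs (h w f) hn hd",
    "bs (f h w) hn hd -> bs (w f h) hn hd",
    "bs (f h w) hn hd -> bs (w h f) hn hd",
]
scanned = [groups[0]] + [
    rearrange(g, order, f=f, h=h, w=w) for g, order in zip(groups[1:], scan_orders)
]
# A windowed attention over dim=1 now sees a different spatio-temporal neighbourhood per group;
# the mirror-image rearranges in the diff bring each group back to (f h w) order afterwards.
print([g.shape for g in scanned])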
    	
        easyanimate/models/transformer3d.py
    CHANGED
    
    | @@ -39,8 +39,9 @@ from torch import nn | |
| 39 | 
             
            from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
         | 
| 40 | 
             
                                    SelfAttentionTemporalTransformerBlock,
         | 
| 41 | 
             
                                    TemporalTransformerBlock, zero_module)
         | 
| 42 | 
            -
            from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, | 
| 43 | 
            -
             | 
|  | |
| 44 | 
             
            from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
         | 
| 45 | 
             
                                TemporalUpsampler3D, UnPatch1D)
         | 
| 46 | 
             
            from .resampler import Resampler
         | 
| @@ -142,6 +143,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 142 | 
             
                    norm_eps: float = 1e-5,
         | 
| 143 | 
             
                    attention_type: str = "default",
         | 
| 144 | 
             
                    caption_channels: int = None,
         | 
|  | |
| 145 | 
             
                    # block type
         | 
| 146 | 
             
                    basic_block_type: str = "motionmodule",
         | 
| 147 | 
             
                    # enable_uvit
         | 
| @@ -168,6 +170,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 168 | 
             
                    after_norm = False,
         | 
| 169 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 170 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
|  | |
|  | |
| 171 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 172 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
| 173 | 
             
                ):
         | 
| @@ -192,6 +196,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 192 | 
             
                    self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
         | 
| 193 | 
             
                    interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
         | 
| 194 | 
             
                    interpolation_scale = max(interpolation_scale, 1)
         | 
|  | |
| 195 |  | 
| 196 | 
             
                    if self.casual_3d:
         | 
| 197 | 
             
                        self.pos_embed = CasualPatchEmbed3D(
         | 
| @@ -397,16 +402,22 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 397 | 
             
                def forward(
         | 
| 398 | 
             
                    self,
         | 
| 399 | 
             
                    hidden_states: torch.Tensor,
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 400 | 
             
                    inpaint_latents: torch.Tensor = None,
         | 
| 401 | 
             
                    control_latents: torch.Tensor = None,
         | 
| 402 | 
            -
                    encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 403 | 
            -
                    clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 404 | 
            -
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 405 | 
             
                    added_cond_kwargs: Dict[str, torch.Tensor] = None,
         | 
| 406 | 
             
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 407 | 
             
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 408 | 
             
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 409 | 
            -
                     | 
| 410 | 
             
                    clip_attention_mask: Optional[torch.Tensor] = None,
         | 
| 411 | 
             
                    return_dict: bool = True,
         | 
| 412 | 
             
                ):
         | 
| @@ -432,7 +443,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 432 | 
             
                            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
         | 
| 433 | 
             
                            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
         | 
| 434 | 
             
                            negative values to the attention scores corresponding to "discard" tokens.
         | 
| 435 | 
            -
                         | 
| 436 | 
             
                            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
         | 
| 437 |  | 
| 438 | 
             
                                * Mask `(batch, sequence_length)` True = keep, False = discard.
         | 
| @@ -466,11 +477,12 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 466 | 
             
                        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
         | 
| 467 | 
             
                        attention_mask = attention_mask.unsqueeze(1)
         | 
| 468 |  | 
|  | |
| 469 | 
             
                    if clip_attention_mask is not None:
         | 
| 470 | 
            -
                         | 
| 471 | 
             
                    # convert encoder_attention_mask to a bias the same way we do for attention_mask
         | 
| 472 | 
            -
                    if  | 
| 473 | 
            -
                        encoder_attention_mask = (1 -  | 
| 474 | 
             
                        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
         | 
| 475 |  | 
| 476 | 
             
                    if inpaint_latents is not None:
         | 
| @@ -637,7 +649,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 637 | 
             
                    return Transformer3DModelOutput(sample=output)
         | 
| 638 |  | 
| 639 | 
             
                @classmethod
         | 
| 640 | 
            -
                def from_pretrained_2d( | 
| 641 | 
             
                    if subfolder is not None:
         | 
| 642 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 643 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
| @@ -649,16 +664,73 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 649 | 
             
                        config = json.load(f)
         | 
| 650 |  | 
| 651 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
| 652 | 
            -
                    model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 653 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 654 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 655 | 
            -
             | 
| 656 | 
             
                        from safetensors.torch import load_file, safe_open
         | 
| 657 | 
             
                        state_dict = load_file(model_file_safetensors)
         | 
| 658 | 
             
                    else:
         | 
| 659 | 
            -
                         | 
| 660 | 
            -
             | 
| 661 | 
            -
                        state_dict =  | 
| 662 |  | 
| 663 | 
             
                    if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
         | 
| 664 | 
             
                        new_shape   = model.state_dict()['pos_embed.proj.weight'].size()
         | 
| @@ -692,6 +764,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): | |
| 692 | 
             
                    params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
         | 
| 693 | 
             
                    print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
         | 
| 694 |  | 
|  | |
| 695 | 
             
                    return model
         | 
| 696 |  | 
| 697 | 
             
            class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         | 
| @@ -769,6 +842,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): | |
| 769 | 
             
                    after_norm = False,
         | 
| 770 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 771 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
|  | |
| 772 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 773 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
| 774 | 
             
                ):
         | 
| @@ -909,6 +983,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): | |
| 909 | 
             
                    control_latents: torch.Tensor = None,
         | 
| 910 | 
             
                    clip_encoder_hidden_states: Optional[torch.Tensor]=None,
         | 
| 911 | 
             
                    clip_attention_mask: Optional[torch.Tensor]=None,
         | 
|  | |
| 912 | 
             
                    return_dict=True,
         | 
| 913 | 
             
                ):
         | 
| 914 | 
             
                    """
         | 
| @@ -1085,7 +1160,10 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1085 | 
             
                    return Transformer2DModelOutput(sample=output)
         | 
| 1086 |  | 
| 1087 | 
             
                @classmethod
         | 
| 1088 | 
            -
                def from_pretrained_2d( | 
| 1089 | 
             
                    if subfolder is not None:
         | 
| 1090 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 1091 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
| @@ -1097,16 +1175,73 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1097 | 
             
                        config = json.load(f)
         | 
| 1098 |  | 
| 1099 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
| 1100 | 
            -
                    model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 1101 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 1102 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 1103 | 
            -
             | 
| 1104 | 
             
                        from safetensors.torch import load_file, safe_open
         | 
| 1105 | 
             
                        state_dict = load_file(model_file_safetensors)
         | 
| 1106 | 
             
                    else:
         | 
| 1107 | 
            -
                         | 
| 1108 | 
            -
             | 
| 1109 | 
            -
                        state_dict =  | 
| 1110 |  | 
| 1111 | 
             
                    if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
         | 
| 1112 | 
             
                        new_shape   = model.state_dict()['pos_embed.proj.weight'].size()
         | 
| @@ -1156,6 +1291,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1156 | 
             
                    params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
         | 
| 1157 | 
             
                    print(f"### attn1 Parameters: {sum(params) / 1e6} M")
         | 
| 1158 |  | 
|  | |
| 1159 | 
             
                    return model
         | 
| 1160 |  | 
| 1161 | 
             
            class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         | 
| @@ -1178,8 +1314,11 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1178 | 
             
                    timestep_activation_fn: str = "silu",
         | 
| 1179 | 
             
                    freq_shift: int = 0,
         | 
| 1180 | 
             
                    num_layers: int = 30,
         | 
| 1181 | 
             
                    dropout: float = 0.0,
         | 
| 1182 | 
             
                    time_embed_dim: int = 512,
         | 
|  | |
| 1183 | 
             
                    text_embed_dim: int = 4096,
         | 
| 1184 | 
             
                    text_embed_dim_t5: int = 4096,
         | 
| 1185 | 
             
                    norm_eps: float = 1e-5,
         | 
| @@ -1191,8 +1330,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1191 | 
             
                    after_norm = False,
         | 
| 1192 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 1193 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
|  | |
| 1194 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 1195 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
|  | |
| 1196 | 
             
                ):
         | 
| 1197 | 
             
                    super().__init__()
         | 
| 1198 | 
             
                    self.num_heads = num_attention_heads
         | 
| @@ -1211,8 +1352,20 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1211 | 
             
                    self.proj = nn.Conv2d(
         | 
| 1212 | 
             
                        in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
         | 
| 1213 | 
             
                    )
         | 
| 1214 | 
            -
                     | 
| 1215 | 
            -
             | 
| 1216 |  | 
| 1217 | 
             
                    if ref_channels is not None:
         | 
| 1218 | 
             
                        self.ref_proj = nn.Conv2d(
         | 
| @@ -1224,23 +1377,45 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1224 |  | 
| 1225 | 
             
                    if clip_channels is not None:
         | 
| 1226 | 
             
                        self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
         | 
| 1227 | 
            -
             | 
| 1228 | 
            -
                    self. | 
| 1229 | 
            -
             | 
| 1230 | 
            -
             | 
| 1231 | 
            -
             | 
| 1232 | 
            -
                                 | 
| 1233 | 
            -
             | 
| 1234 | 
            -
             | 
| 1235 | 
            -
             | 
| 1236 | 
            -
             | 
| 1237 | 
            -
             | 
| 1238 | 
            -
             | 
| 1239 | 
            -
             | 
| 1240 | 
            -
             | 
| 1241 | 
            -
             | 
| 1242 | 
            -
             | 
| 1243 | 
            -
             | 
|  | |
| 1244 | 
             
                    self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
         | 
| 1245 |  | 
| 1246 | 
             
                    # 5. Output blocks
         | 
| @@ -1275,6 +1450,7 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1275 | 
             
                    ref_latents: Optional[torch.Tensor] = None,
         | 
| 1276 | 
             
                    clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 1277 | 
             
                    clip_attention_mask: Optional[torch.Tensor] = None,
         | 
|  | |
| 1278 | 
             
                    return_dict=True,
         | 
| 1279 | 
             
                ):
         | 
| 1280 | 
             
                    batch_size, channels, video_length, height, width = hidden_states.size()
         | 
| @@ -1343,6 +1519,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1343 | 
             
                                encoder_hidden_states,
         | 
| 1344 | 
             
                                temb,
         | 
| 1345 | 
             
                                image_rotary_emb,
         | 
| 1346 | 
             
                                **ckpt_kwargs,
         | 
| 1347 | 
             
                            )
         | 
| 1348 | 
             
                        else:
         | 
| @@ -1351,6 +1530,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1351 | 
             
                                encoder_hidden_states=encoder_hidden_states,
         | 
| 1352 | 
             
                                temb=temb,
         | 
| 1353 | 
             
                                image_rotary_emb=image_rotary_emb,
         | 
| 1354 | 
             
                            )
         | 
| 1355 |  | 
| 1356 | 
             
                    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         | 
| @@ -1371,7 +1553,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1371 | 
             
                    return Transformer2DModelOutput(sample=output)
         | 
| 1372 |  | 
| 1373 | 
             
                @classmethod
         | 
| 1374 | 
            -
                def from_pretrained_2d( | 
| 1375 | 
             
                    if subfolder is not None:
         | 
| 1376 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 1377 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
| @@ -1383,9 +1568,60 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1383 | 
             
                        config = json.load(f)
         | 
| 1384 |  | 
| 1385 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
| 1386 | 
            -
                    model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 1387 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 1388 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 1389 | 
             
                    if os.path.exists(model_file):
         | 
| 1390 | 
             
                        state_dict = torch.load(model_file, map_location="cpu")
         | 
| 1391 | 
             
                    elif os.path.exists(model_file_safetensors):
         | 
| @@ -1433,4 +1669,5 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): | |
| 1433 | 
             
                    params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
         | 
| 1434 | 
             
                    print(f"### attn1 Parameters: {sum(params) / 1e6} M")
         | 
| 1435 |  | 
|  | |
| 1436 | 
             
                    return model
         | 
|  | |
| 39 | 
             
            from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock,
         | 
| 40 | 
             
                                    SelfAttentionTemporalTransformerBlock,
         | 
| 41 | 
             
                                    TemporalTransformerBlock, zero_module)
         | 
| 42 | 
            +
            from .embeddings import (HunyuanCombinedTimestepTextSizeStyleEmbedding,
         | 
| 43 | 
            +
                                     TimePositionalEncoding)
         | 
| 44 | 
            +
            from .norm import AdaLayerNormSingle, EasyAnimateRMSNorm
         | 
| 45 | 
             
            from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D,
         | 
| 46 | 
             
                                TemporalUpsampler3D, UnPatch1D)
         | 
| 47 | 
             
            from .resampler import Resampler
         | 
|  | |
| 143 | 
             
                    norm_eps: float = 1e-5,
         | 
| 144 | 
             
                    attention_type: str = "default",
         | 
| 145 | 
             
                    caption_channels: int = None,
         | 
| 146 | 
            +
                    n_query=8,
         | 
| 147 | 
             
                    # block type
         | 
| 148 | 
             
                    basic_block_type: str = "motionmodule",
         | 
| 149 | 
             
                    # enable_uvit
         | 
|  | |
| 170 | 
             
                    after_norm = False,
         | 
| 171 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 172 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
| 173 | 
            +
                    position_of_clip_embedding: str = "head",
         | 
| 174 | 
            +
                    enable_zero_in_inpaint: bool = False,
         | 
| 175 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 176 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
| 177 | 
             
                ):
         | 
|  | |
| 196 | 
             
                    self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
         | 
| 197 | 
             
                    interpolation_scale = self.config.sample_size // 64  # => 64 (= 512 pixart) has interpolation scale 1
         | 
| 198 | 
             
                    interpolation_scale = max(interpolation_scale, 1)
         | 
| 199 | 
            +
                    self.n_query = n_query
         | 
| 200 |  | 
| 201 | 
             
                    if self.casual_3d:
         | 
| 202 | 
             
                        self.pos_embed = CasualPatchEmbed3D(
         | 
|  | |
| 402 | 
             
                def forward(
         | 
| 403 | 
             
                    self,
         | 
| 404 | 
             
                    hidden_states: torch.Tensor,
         | 
| 405 | 
            +
                    timestep: Optional[torch.LongTensor] = None,
         | 
| 406 | 
            +
                    timestep_cond = None,
         | 
| 407 | 
            +
                    encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 408 | 
            +
                    text_embedding_mask: Optional[torch.Tensor] = None,
         | 
| 409 | 
            +
                    encoder_hidden_states_t5: Optional[torch.Tensor] = None,
         | 
| 410 | 
            +
                    text_embedding_mask_t5: Optional[torch.Tensor] = None,
         | 
| 411 | 
            +
                    image_meta_size = None,
         | 
| 412 | 
            +
                    style = None,
         | 
| 413 | 
            +
                    image_rotary_emb: Optional[torch.Tensor] = None,
         | 
| 414 | 
             
                    inpaint_latents: torch.Tensor = None,
         | 
| 415 | 
             
                    control_latents: torch.Tensor = None,
         | 
|  | |
|  | |
|  | |
| 416 | 
             
                    added_cond_kwargs: Dict[str, torch.Tensor] = None,
         | 
| 417 | 
             
                    class_labels: Optional[torch.LongTensor] = None,
         | 
| 418 | 
             
                    cross_attention_kwargs: Dict[str, Any] = None,
         | 
| 419 | 
             
                    attention_mask: Optional[torch.Tensor] = None,
         | 
| 420 | 
            +
                    clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 421 | 
             
                    clip_attention_mask: Optional[torch.Tensor] = None,
         | 
| 422 | 
             
                    return_dict: bool = True,
         | 
| 423 | 
             
                ):
         | 
|  | |
| 443 | 
             
                            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
         | 
| 444 | 
             
                            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
         | 
| 445 | 
             
                            negative values to the attention scores corresponding to "discard" tokens.
         | 
| 446 | 
            +
                        text_embedding_mask ( `torch.Tensor`, *optional*):
         | 
| 447 | 
             
                            Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
         | 
| 448 |  | 
| 449 | 
             
                                * Mask `(batch, sequence_length)` True = keep, False = discard.
         | 
|  | |
| 477 | 
             
                        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
         | 
| 478 | 
             
                        attention_mask = attention_mask.unsqueeze(1)
         | 
| 479 |  | 
| 480 | 
            +
                    text_embedding_mask = text_embedding_mask.squeeze(1)
         | 
| 481 | 
             
                    if clip_attention_mask is not None:
         | 
| 482 | 
            +
                        text_embedding_mask = torch.cat([text_embedding_mask, clip_attention_mask], dim=1)
         | 
| 483 | 
             
                    # convert encoder_attention_mask to a bias the same way we do for attention_mask
         | 
| 484 | 
            +
                    if text_embedding_mask is not None and text_embedding_mask.ndim == 2:
         | 
| 485 | 
            +
                        encoder_attention_mask = (1 - text_embedding_mask.to(encoder_hidden_states.dtype)) * -10000.0
         | 
| 486 | 
             
                        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
         | 
| 487 |  | 
| 488 | 
             
                    if inpaint_latents is not None:
         | 
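
The hunk above folds the optional CLIP mask into `text_embedding_mask` and then turns the combined 0/1 keep-mask into an additive bias before it reaches the attention layers. A minimal, self-contained sketch of that conversion (shapes and values are illustrative, not taken from the model):

import torch

# 0/1 keep-mask over key tokens; 1 = keep, 0 = discard (illustrative shape)
keep_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 0]], dtype=torch.float32)

# same transformation as in the diff: kept tokens get a 0.0 bias, discarded
# tokens get a large negative bias that is added to the attention scores
bias = (1 - keep_mask) * -10000.0
bias = bias.unsqueeze(1)   # broadcast over the query / head dimension

print(bias.shape)          # torch.Size([2, 1, 5])
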
|  | |
| 649 | 
             
                    return Transformer3DModelOutput(sample=output)
         | 
| 650 |  | 
| 651 | 
             
                @classmethod
         | 
| 652 | 
            +
                def from_pretrained_2d(
         | 
| 653 | 
            +
                    cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={},
         | 
| 654 | 
            +
                    low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
         | 
| 655 | 
            +
                ):
         | 
| 656 | 
             
                    if subfolder is not None:
         | 
| 657 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 658 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
|  | |
| 664 | 
             
                        config = json.load(f)
         | 
| 665 |  | 
| 666 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
|  | |
| 667 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 668 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 669 | 
            +
             | 
| 670 | 
            +
                    if low_cpu_mem_usage:
         | 
| 671 | 
            +
                        try:
         | 
| 672 | 
            +
                            import re
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                            from diffusers.models.modeling_utils import \
         | 
| 675 | 
            +
                                load_model_dict_into_meta
         | 
| 676 | 
            +
                            from diffusers.utils import is_accelerate_available
         | 
| 677 | 
            +
                            if is_accelerate_available():
         | 
| 678 | 
            +
                                import accelerate
         | 
| 679 | 
            +
                            
         | 
| 680 | 
            +
                            # Instantiate model with empty weights
         | 
| 681 | 
            +
                            with accelerate.init_empty_weights():
         | 
| 682 | 
            +
                                model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                            param_device = "cpu"
         | 
| 685 | 
            +
                            from safetensors.torch import load_file, safe_open
         | 
| 686 | 
            +
                            state_dict = load_file(model_file_safetensors)
         | 
| 687 | 
            +
                            model._convert_deprecated_attention_blocks(state_dict)
         | 
| 688 | 
            +
                            # move the params from meta device to cpu
         | 
| 689 | 
            +
                            missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
         | 
| 690 | 
            +
                            if len(missing_keys) > 0:
         | 
| 691 | 
            +
                                raise ValueError(
         | 
| 692 | 
            +
                                    f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
         | 
| 693 | 
            +
                                    f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
         | 
| 694 | 
            +
                                    " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
         | 
| 695 | 
            +
                                    " those weights or else make sure your checkpoint file is correct."
         | 
| 696 | 
            +
                                )
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                            unexpected_keys = load_model_dict_into_meta(
         | 
| 699 | 
            +
                                model,
         | 
| 700 | 
            +
                                state_dict,
         | 
| 701 | 
            +
                                device=param_device,
         | 
| 702 | 
            +
                                dtype=torch_dtype,
         | 
| 703 | 
            +
                                model_name_or_path=pretrained_model_path,
         | 
| 704 | 
            +
                            )
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                            if cls._keys_to_ignore_on_load_unexpected is not None:
         | 
| 707 | 
            +
                                for pat in cls._keys_to_ignore_on_load_unexpected:
         | 
| 708 | 
            +
                                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
         | 
| 709 | 
            +
             | 
| 710 | 
            +
                            if len(unexpected_keys) > 0:
         | 
| 711 | 
            +
                                print(
         | 
| 712 | 
            +
                                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
         | 
| 713 | 
            +
                                )
         | 
| 714 | 
            +
                            return model
         | 
| 715 | 
            +
                        except Exception as e:
         | 
| 716 | 
            +
                            print(
         | 
| 717 | 
            +
                                f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
         | 
| 718 | 
            +
                            )
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                    model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 721 | 
            +
                    if os.path.exists(model_file):
         | 
| 722 | 
            +
                        state_dict = torch.load(model_file, map_location="cpu")
         | 
| 723 | 
            +
                    elif os.path.exists(model_file_safetensors):
         | 
| 724 | 
             
                        from safetensors.torch import load_file, safe_open
         | 
| 725 | 
             
                        state_dict = load_file(model_file_safetensors)
         | 
| 726 | 
             
                    else:
         | 
| 727 | 
            +
                        from safetensors.torch import load_file, safe_open
         | 
| 728 | 
            +
                        model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
         | 
| 729 | 
            +
                        state_dict = {}
         | 
| 730 | 
            +
                        for model_file_safetensors in model_files_safetensors:
         | 
| 731 | 
            +
                            _state_dict = load_file(model_file_safetensors)
         | 
| 732 | 
            +
                            for key in _state_dict:
         | 
| 733 | 
            +
                                state_dict[key] = _state_dict[key]
         | 
| 734 |  | 
| 735 | 
             
                    if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
         | 
| 736 | 
             
                        new_shape   = model.state_dict()['pos_embed.proj.weight'].size()
         | 
|  | |
| 764 | 
             
                    params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
         | 
| 765 | 
             
                    print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
         | 
| 766 |  | 
| 767 | 
            +
                    model = model.to(torch_dtype)
         | 
| 768 | 
             
                    return model
         | 
| 769 |  | 
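
With the new `low_cpu_mem_usage` and `torch_dtype` arguments, the transformer can be materialized directly in the requested dtype instead of loading full fp32 weights first. A hedged usage sketch; the checkpoint path is a placeholder, and `low_cpu_mem_usage=True` assumes `accelerate` is installed and a single-file .safetensors checkpoint, matching the branch above:

import torch

from easyanimate.models.transformer3d import Transformer3DModel

transformer = Transformer3DModel.from_pretrained_2d(
    "models/Diffusion_Transformer/EasyAnimateV5.1",  # placeholder path
    subfolder="transformer",
    low_cpu_mem_usage=True,       # meta-device init, weights loaded straight into torch_dtype
    torch_dtype=torch.bfloat16,   # also applied by the fallback path via model.to(torch_dtype)
)
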
| 770 | 
             
            class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
         | 
|  | |
| 842 | 
             
                    after_norm = False,
         | 
| 843 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 844 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
| 845 | 
            +
                    position_of_clip_embedding: str = "full",
         | 
| 846 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 847 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
| 848 | 
             
                ):
         | 
|  | |
| 983 | 
             
                    control_latents: torch.Tensor = None,
         | 
| 984 | 
             
                    clip_encoder_hidden_states: Optional[torch.Tensor]=None,
         | 
| 985 | 
             
                    clip_attention_mask: Optional[torch.Tensor]=None,
         | 
| 986 | 
            +
                    added_cond_kwargs: Dict[str, torch.Tensor] = None,
         | 
| 987 | 
             
                    return_dict=True,
         | 
| 988 | 
             
                ):
         | 
| 989 | 
             
                    """
         | 
|  | |
| 1160 | 
             
                    return Transformer2DModelOutput(sample=output)
         | 
| 1161 |  | 
| 1162 | 
             
                @classmethod
         | 
| 1163 | 
            +
                def from_pretrained_2d(
         | 
| 1164 | 
            +
                    cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
         | 
| 1165 | 
            +
                    low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
         | 
| 1166 | 
            +
                ):
         | 
| 1167 | 
             
                    if subfolder is not None:
         | 
| 1168 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 1169 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
|  | |
| 1175 | 
             
                        config = json.load(f)
         | 
| 1176 |  | 
| 1177 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
|  | |
| 1178 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 1179 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 1180 | 
            +
             | 
| 1181 | 
            +
                    if low_cpu_mem_usage:
         | 
| 1182 | 
            +
                        try:
         | 
| 1183 | 
            +
                            import re
         | 
| 1184 | 
            +
             | 
| 1185 | 
            +
                            from diffusers.models.modeling_utils import \
         | 
| 1186 | 
            +
                                load_model_dict_into_meta
         | 
| 1187 | 
            +
                            from diffusers.utils import is_accelerate_available
         | 
| 1188 | 
            +
                            if is_accelerate_available():
         | 
| 1189 | 
            +
                                import accelerate
         | 
| 1190 | 
            +
                            
         | 
| 1191 | 
            +
                            # Instantiate model with empty weights
         | 
| 1192 | 
            +
                            with accelerate.init_empty_weights():
         | 
| 1193 | 
            +
                                model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 1194 | 
            +
             | 
| 1195 | 
            +
                            param_device = "cpu"
         | 
| 1196 | 
            +
                            from safetensors.torch import load_file, safe_open
         | 
| 1197 | 
            +
                            state_dict = load_file(model_file_safetensors)
         | 
| 1198 | 
            +
                            model._convert_deprecated_attention_blocks(state_dict)
         | 
| 1199 | 
            +
                            # move the params from meta device to cpu
         | 
| 1200 | 
            +
                            missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
         | 
| 1201 | 
            +
                            if len(missing_keys) > 0:
         | 
| 1202 | 
            +
                                raise ValueError(
         | 
| 1203 | 
            +
                                    f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
         | 
| 1204 | 
            +
                                    f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
         | 
| 1205 | 
            +
                                    " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
         | 
| 1206 | 
            +
                                    " those weights or else make sure your checkpoint file is correct."
         | 
| 1207 | 
            +
                                )
         | 
| 1208 | 
            +
             | 
| 1209 | 
            +
                            unexpected_keys = load_model_dict_into_meta(
         | 
| 1210 | 
            +
                                model,
         | 
| 1211 | 
            +
                                state_dict,
         | 
| 1212 | 
            +
                                device=param_device,
         | 
| 1213 | 
            +
                                dtype=torch_dtype,
         | 
| 1214 | 
            +
                                model_name_or_path=pretrained_model_path,
         | 
| 1215 | 
            +
                            )
         | 
| 1216 | 
            +
             | 
| 1217 | 
            +
                            if cls._keys_to_ignore_on_load_unexpected is not None:
         | 
| 1218 | 
            +
                                for pat in cls._keys_to_ignore_on_load_unexpected:
         | 
| 1219 | 
            +
                                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
         | 
| 1220 | 
            +
             | 
| 1221 | 
            +
                            if len(unexpected_keys) > 0:
         | 
| 1222 | 
            +
                                print(
         | 
| 1223 | 
            +
                                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
         | 
| 1224 | 
            +
                                )
         | 
| 1225 | 
            +
                            return model
         | 
| 1226 | 
            +
                        except Exception as e:
         | 
| 1227 | 
            +
                            print(
         | 
| 1228 | 
            +
                                f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
         | 
| 1229 | 
            +
                            )
         | 
| 1230 | 
            +
             | 
| 1231 | 
            +
                    model = cls.from_config(config, **transformer_additional_kwargs)
         | 
| 1232 | 
            +
                    if os.path.exists(model_file):
         | 
| 1233 | 
            +
                        state_dict = torch.load(model_file, map_location="cpu")
         | 
| 1234 | 
            +
                    elif os.path.exists(model_file_safetensors):
         | 
| 1235 | 
             
                        from safetensors.torch import load_file, safe_open
         | 
| 1236 | 
             
                        state_dict = load_file(model_file_safetensors)
         | 
| 1237 | 
             
                    else:
         | 
| 1238 | 
            +
                        from safetensors.torch import load_file, safe_open
         | 
| 1239 | 
            +
                        model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
         | 
| 1240 | 
            +
                        state_dict = {}
         | 
| 1241 | 
            +
                        for model_file_safetensors in model_files_safetensors:
         | 
| 1242 | 
            +
                            _state_dict = load_file(model_file_safetensors)
         | 
| 1243 | 
            +
                            for key in _state_dict:
         | 
| 1244 | 
            +
                                state_dict[key] = _state_dict[key]
         | 
| 1245 |  | 
| 1246 | 
             
                    if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
         | 
| 1247 | 
             
                        new_shape   = model.state_dict()['pos_embed.proj.weight'].size()
         | 
|  | |
| 1291 | 
             
                    params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
         | 
| 1292 | 
             
                    print(f"### attn1 Parameters: {sum(params) / 1e6} M")
         | 
| 1293 |  | 
| 1294 | 
            +
                    model = model.to(torch_dtype)
         | 
| 1295 | 
             
                    return model
         | 
| 1296 |  | 
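
The fallback branch above gathers every `*.safetensors` shard in the checkpoint folder into one state dict. The same pattern in isolation (the directory name is a placeholder):

import glob
import os

from safetensors.torch import load_file

checkpoint_dir = "path/to/transformer"  # placeholder
state_dict = {}
for shard in glob.glob(os.path.join(checkpoint_dir, "*.safetensors")):
    # later shards overwrite earlier ones on duplicate keys, as in the loop above
    state_dict.update(load_file(shard))
print(f"merged {len(state_dict)} tensors from {checkpoint_dir}")
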
| 1297 | 
             
            class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
         | 
|  | |
| 1314 | 
             
                    timestep_activation_fn: str = "silu",
         | 
| 1315 | 
             
                    freq_shift: int = 0,
         | 
| 1316 | 
             
                    num_layers: int = 30,
         | 
| 1317 | 
            +
                    mmdit_layers: int = 10000,
         | 
| 1318 | 
            +
                    swa_layers: list = None,
         | 
| 1319 | 
             
                    dropout: float = 0.0,
         | 
| 1320 | 
             
                    time_embed_dim: int = 512,
         | 
| 1321 | 
            +
                    add_norm_text_encoder: bool = False,
         | 
| 1322 | 
             
                    text_embed_dim: int = 4096,
         | 
| 1323 | 
             
                    text_embed_dim_t5: int = 4096,
         | 
| 1324 | 
             
                    norm_eps: float = 1e-5,
         | 
|  | |
| 1330 | 
             
                    after_norm = False,
         | 
| 1331 | 
             
                    resize_inpaint_mask_directly: bool = False,
         | 
| 1332 | 
             
                    enable_clip_in_inpaint: bool = True,
         | 
| 1333 | 
            +
                    position_of_clip_embedding: str = "full",
         | 
| 1334 | 
             
                    enable_text_attention_mask: bool = True,
         | 
| 1335 | 
             
                    add_noise_in_inpaint_model: bool = False,
         | 
| 1336 | 
            +
                    add_ref_latent_in_control_model: bool = False,
         | 
| 1337 | 
             
                ):
         | 
| 1338 | 
             
                    super().__init__()
         | 
| 1339 | 
             
                    self.num_heads = num_attention_heads
         | 
|  | |
| 1352 | 
             
                    self.proj = nn.Conv2d(
         | 
| 1353 | 
             
                        in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
         | 
| 1354 | 
             
                    )
         | 
| 1355 | 
            +
                    if not add_norm_text_encoder:
         | 
| 1356 | 
            +
                        self.text_proj = nn.Linear(text_embed_dim, self.inner_dim)
         | 
| 1357 | 
            +
                        if text_embed_dim_t5 is not None:
         | 
| 1358 | 
            +
                            self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim)
         | 
| 1359 | 
            +
                    else:
         | 
| 1360 | 
            +
                        self.text_proj = nn.Sequential(
         | 
| 1361 | 
            +
                            EasyAnimateRMSNorm(text_embed_dim),
         | 
| 1362 | 
            +
                            nn.Linear(text_embed_dim, self.inner_dim)
         | 
| 1363 | 
            +
                        )
         | 
| 1364 | 
            +
                        if text_embed_dim_t5 is not None:
         | 
| 1365 | 
            +
                            self.text_proj_t5 = nn.Sequential(
         | 
| 1366 | 
            +
                                EasyAnimateRMSNorm(text_embed_dim),
         | 
| 1367 | 
            +
                                nn.Linear(text_embed_dim_t5, self.inner_dim)
         | 
| 1368 | 
            +
                            )
         | 
| 1369 |  | 
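
When `add_norm_text_encoder` is set, the text embeddings are RMS-normalized by `EasyAnimateRMSNorm` (defined in `norm.py`) before the linear projection. The sketch below uses a generic RMSNorm formulation that is only assumed to behave like it, with illustrative sizes:

import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    # generic RMSNorm: scale by the reciprocal root-mean-square, then a learned weight
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        hidden = x.float()
        hidden = hidden * torch.rsqrt(hidden.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * hidden).to(input_dtype)

text_embed_dim, inner_dim = 4096, 3072  # illustrative sizes
text_proj = nn.Sequential(SimpleRMSNorm(text_embed_dim),
                          nn.Linear(text_embed_dim, inner_dim))
print(text_proj(torch.randn(1, 77, text_embed_dim)).shape)  # torch.Size([1, 77, 3072])
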
| 1370 | 
             
                    if ref_channels is not None:
         | 
| 1371 | 
             
                        self.ref_proj = nn.Conv2d(
         | 
|  | |
| 1377 |  | 
| 1378 | 
             
                    if clip_channels is not None:
         | 
| 1379 | 
             
                        self.clip_proj = nn.Linear(clip_channels, self.inner_dim)
         | 
| 1380 | 
            +
                    
         | 
| 1381 | 
            +
                    self.swa_layers = swa_layers
         | 
| 1382 | 
            +
                    if swa_layers is not None:
         | 
| 1383 | 
            +
                        self.transformer_blocks = nn.ModuleList(
         | 
| 1384 | 
            +
                            [
         | 
| 1385 | 
            +
                                EasyAnimateDiTBlock(
         | 
| 1386 | 
            +
                                    dim=self.inner_dim,
         | 
| 1387 | 
            +
                                    num_attention_heads=num_attention_heads,
         | 
| 1388 | 
            +
                                    attention_head_dim=attention_head_dim,
         | 
| 1389 | 
            +
                                    time_embed_dim=time_embed_dim,
         | 
| 1390 | 
            +
                                    dropout=dropout,
         | 
| 1391 | 
            +
                                    activation_fn=activation_fn,
         | 
| 1392 | 
            +
                                    norm_elementwise_affine=norm_elementwise_affine,
         | 
| 1393 | 
            +
                                    norm_eps=norm_eps,
         | 
| 1394 | 
            +
                                    after_norm=after_norm,
         | 
| 1395 | 
            +
                                    is_mmdit_block=True if index < mmdit_layers else False,
         | 
| 1396 | 
            +
                                    is_swa=True if index in swa_layers else False,
         | 
| 1397 | 
            +
                                )
         | 
| 1398 | 
            +
                                for index in range(num_layers)
         | 
| 1399 | 
            +
                            ]
         | 
| 1400 | 
            +
                        )
         | 
| 1401 | 
            +
                    else:
         | 
| 1402 | 
            +
                        self.transformer_blocks = nn.ModuleList(
         | 
| 1403 | 
            +
                            [
         | 
| 1404 | 
            +
                                EasyAnimateDiTBlock(
         | 
| 1405 | 
            +
                                    dim=self.inner_dim,
         | 
| 1406 | 
            +
                                    num_attention_heads=num_attention_heads,
         | 
| 1407 | 
            +
                                    attention_head_dim=attention_head_dim,
         | 
| 1408 | 
            +
                                    time_embed_dim=time_embed_dim,
         | 
| 1409 | 
            +
                                    dropout=dropout,
         | 
| 1410 | 
            +
                                    activation_fn=activation_fn,
         | 
| 1411 | 
            +
                                    norm_elementwise_affine=norm_elementwise_affine,
         | 
| 1412 | 
            +
                                    norm_eps=norm_eps,
         | 
| 1413 | 
            +
                                    after_norm=after_norm,
         | 
| 1414 | 
            +
                                    is_mmdit_block=True if _ < mmdit_layers else False,
         | 
| 1415 | 
            +
                                )
         | 
| 1416 | 
            +
                                for _ in range(num_layers)
         | 
| 1417 | 
            +
                            ]
         | 
| 1418 | 
            +
                        )
         | 
| 1419 | 
             
                    self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine)
         | 
| 1420 |  | 
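
The two branches above differ only in whether a per-layer `is_swa` flag is set; in both, layers with index below `mmdit_layers` are built as MMDiT blocks. A tiny sketch of the per-layer flag selection with made-up values:

# made-up values, only to show which layers get which flags
num_layers, mmdit_layers = 8, 4
swa_layers = [2, 5]

for index in range(num_layers):
    is_mmdit_block = index < mmdit_layers
    is_swa = swa_layers is not None and index in swa_layers
    print(f"layer {index}: is_mmdit_block={is_mmdit_block}, is_swa={is_swa}")
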
| 1421 | 
             
                    # 5. Output blocks
         | 
|  | |
| 1450 | 
             
                    ref_latents: Optional[torch.Tensor] = None,
         | 
| 1451 | 
             
                    clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         | 
| 1452 | 
             
                    clip_attention_mask: Optional[torch.Tensor] = None,
         | 
| 1453 | 
            +
                    added_cond_kwargs: Dict[str, torch.Tensor] = None,
         | 
| 1454 | 
             
                    return_dict=True,
         | 
| 1455 | 
             
                ):
         | 
| 1456 | 
             
                    batch_size, channels, video_length, height, width = hidden_states.size()
         | 
|  | |
| 1519 | 
             
                                encoder_hidden_states,
         | 
| 1520 | 
             
                                temb,
         | 
| 1521 | 
             
                                image_rotary_emb,
         | 
| 1522 | 
            +
                                video_length,
         | 
| 1523 | 
            +
                                height // self.patch_size,
         | 
| 1524 | 
            +
                                width // self.patch_size,
         | 
| 1525 | 
             
                                **ckpt_kwargs,
         | 
| 1526 | 
             
                            )
         | 
| 1527 | 
             
                        else:
         | 
|  | |
| 1530 | 
             
                                encoder_hidden_states=encoder_hidden_states,
         | 
| 1531 | 
             
                                temb=temb,
         | 
| 1532 | 
             
                                image_rotary_emb=image_rotary_emb,
         | 
| 1533 | 
            +
                                num_frames=video_length,
         | 
| 1534 | 
            +
                                height=height // self.patch_size,
         | 
| 1535 | 
            +
                                width=width // self.patch_size
         | 
| 1536 | 
             
                            )
         | 
| 1537 |  | 
| 1538 | 
             
                    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
         | 
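
Both the checkpointed and the plain path now hand the post-patchify grid (`video_length`, `height // self.patch_size`, `width // self.patch_size`) to every block. A minimal sketch of forwarding such extra positional arguments through activation checkpointing, with a stand-in block and the usual `create_custom_forward` wrapper (names are illustrative, not this file's exact helpers):

import torch
import torch.nn as nn
import torch.utils.checkpoint

class DummyBlock(nn.Module):
    # stand-in for EasyAnimateDiTBlock; a real block would use the grid sizes
    def forward(self, hidden_states, num_frames, height, width):
        return hidden_states * 1.0

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

hidden_states = torch.randn(1, 16, 32, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(DummyBlock()),
    hidden_states, 4, 8, 8,  # num_frames, height, width after patchify
    use_reentrant=False,
)
print(out.shape)  # torch.Size([1, 16, 32])
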
|  | |
| 1553 | 
             
                    return Transformer2DModelOutput(sample=output)
         | 
| 1554 |  | 
| 1555 | 
             
                @classmethod
         | 
| 1556 | 
            +
                def from_pretrained_2d(
         | 
| 1557 | 
            +
                    cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
         | 
| 1558 | 
            +
                    low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
         | 
| 1559 | 
            +
                ):
         | 
| 1560 | 
             
                    if subfolder is not None:
         | 
| 1561 | 
             
                        pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
         | 
| 1562 | 
             
                    print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
         | 
|  | |
| 1568 | 
             
                        config = json.load(f)
         | 
| 1569 |  | 
| 1570 | 
             
                    from diffusers.utils import WEIGHTS_NAME
         | 
|  | |
| 1571 | 
             
                    model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
         | 
| 1572 | 
             
                    model_file_safetensors = model_file.replace(".bin", ".safetensors")
         | 
| 1573 | 
+                    if low_cpu_mem_usage:
+                        try:
+                            import re
+
+                            from diffusers.models.modeling_utils import \
+                                load_model_dict_into_meta
+                            from diffusers.utils import is_accelerate_available
+                            if is_accelerate_available():
+                                import accelerate
+
+                            # Instantiate model with empty weights
+                            with accelerate.init_empty_weights():
+                                model = cls.from_config(config, **transformer_additional_kwargs)
+
+                            param_device = "cpu"
+                            from safetensors.torch import load_file, safe_open
+                            state_dict = load_file(model_file_safetensors)
+                            model._convert_deprecated_attention_blocks(state_dict)
+                            # move the params from meta device to cpu
+                            missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                            if len(missing_keys) > 0:
+                                raise ValueError(
+                                    f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
+                                    f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                                    " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                                    " those weights or else make sure your checkpoint file is correct."
+                                )
+
+                            unexpected_keys = load_model_dict_into_meta(
+                                model,
+                                state_dict,
+                                device=param_device,
+                                dtype=torch_dtype,
+                                model_name_or_path=pretrained_model_path,
+                            )
+
+                            if cls._keys_to_ignore_on_load_unexpected is not None:
+                                for pat in cls._keys_to_ignore_on_load_unexpected:
+                                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                            if len(unexpected_keys) > 0:
+                                print(
+                                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                                )
+                            return model
+                        except Exception as e:
+                            print(
+                                f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
+                            )
+
+                    model = cls.from_config(config, **transformer_additional_kwargs)
                     if os.path.exists(model_file):
                         state_dict = torch.load(model_file, map_location="cpu")
                     elif os.path.exists(model_file_safetensors):
...
                     params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
                     print(f"### attn1 Parameters: {sum(params) / 1e6} M")

+                    model = model.to(torch_dtype)
                     return model
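The added `low_cpu_mem_usage` branch builds the transformer on the meta device and only materializes real storage once the safetensors checkpoint has been read. A minimal, self-contained sketch of that pattern (not the repo's code; the module, shapes and dtypes below are made up for illustration, and `assign=True` needs a reasonably recent PyTorch):

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights

# Build the module skeleton without allocating storage: parameters land on the "meta" device.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

# Stand-in for load_file(model_file_safetensors): a CPU state dict with matching shapes.
state_dict = {k: torch.zeros(v.shape, dtype=torch.float32) for k, v in model.state_dict().items()}

# assign=True swaps the meta tensors for the loaded ones instead of copying into them,
# so peak memory stays close to a single copy of the weights.
model.load_state_dict(state_dict, assign=True)
model = model.to(torch.bfloat16)
```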
    	
easyanimate/pipeline/pipeline_easyanimate.py
CHANGED

@@ -1,4 +1,4 @@
-# Copyright …
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,61 +12,113 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import copy
-import html
 import inspect
-import re
-import urllib.parse as ul
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import AutoencoderKL
-from diffusers.…
 from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
-                             is_bs4_available, is_ftfy_available, …
                              replace_example_docstring)
 from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from tqdm import tqdm
-from transformers import …

-from ..models…

-…

-if is_ftfy_available():
-    import ftfy


 EXAMPLE_DOC_STRING = """
     Examples:
-        ```
         >>> import torch
         >>> from diffusers import EasyAnimatePipeline
-        >>> …
-        >>> …
-        >>> …
         ```
 """
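The new example block in `EXAMPLE_DOC_STRING` is cut off in this view. For orientation, a hedged end-to-end sketch in the same spirit: the checkpoint id is a placeholder, `EasyAnimatePipeline`/`export_to_video` come from a sufficiently recent diffusers release, and the argument and output names there can differ from this repo's own pipeline (which takes `video_length` and returns `videos`).

```python
import torch
from diffusers import EasyAnimatePipeline          # requires a diffusers release that ships this pipeline
from diffusers.utils import export_to_video

# Placeholder model id; point it at the EasyAnimateV5.1 weights you actually use.
pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.bfloat16)
pipe.to("cuda")

result = pipe(
    prompt="A panda playing a guitar by a campfire, cinematic lighting",
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=6.0,
)
export_to_video(result.frames[0], "output.mp4", fps=8)
```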

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
     **kwargs,
 ):
     """
@@ -77,19 +129,23 @@ def retrieve_timesteps(
         scheduler (`SchedulerMixin`):
             The scheduler to get timesteps from.
         num_inference_steps (`int`):
-            The number of diffusion steps used when generating samples with a pre-trained model. If used, …
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            …

     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:
@@ -100,86 +156,113 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
     return timesteps, num_inference_steps
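A short usage sketch of the `retrieve_timesteps` helper defined above: either the scheduler spaces `num_inference_steps` steps itself, or you hand it an explicit descending schedule (the scheduler class and values below are only illustrative).

```python
import torch
from diffusers import EulerDiscreteScheduler

scheduler = EulerDiscreteScheduler(num_train_timesteps=1000)

# Let the scheduler choose evenly spaced steps.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=20, device=torch.device("cpu"))

# Or pass an explicit descending schedule. This only works when the scheduler's
# set_timesteps accepts a `timesteps` argument (recent diffusers releases);
# otherwise retrieve_timesteps raises a ValueError.
timesteps, num_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0], device=torch.device("cpu"))
```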

-@dataclass
-class EasyAnimatePipelineOutput(BaseOutput):
-    videos: Union[torch.Tensor, np.ndarray]

 class EasyAnimatePipeline(DiffusionPipeline):
     r"""
-    Pipeline for text-to-…

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     Args:
-        vae ([`…
-            Variational Auto-Encoder (VAE) Model to encode and decode …
-        text_encoder ([`…
-            …
     """
-    bad_punct_regex = re.compile(
-        r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
-    )  # noqa

-    …

     def __init__(
         self,
-        …
-        text_encoder: …
-        …
     ):
         super().__init__()

         self.register_modules(
-            …
         )

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.enable_autocast_float8_transformer_flag = False

-    # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
-    def mask_text_embeddings(self, emb, mask):
-        if emb.shape[0] == 1:
-            keep_index = mask.sum().item()
-            return emb[:, :, :keep_index, :], keep_index
-        else:
-            masked_feature = emb * mask[:, None, :, None]
-            return masked_feature, emb.shape[2]

-    …
     def encode_prompt(
         self,
-        prompt: …
-        …
         num_images_per_prompt: int = 1,
-        …
-        max_sequence_length: int = …
-        …
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -187,33 +270,46 @@ class EasyAnimatePipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            …
-                whether to use classifier free guidance or not
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
                 number of images that should be generated per prompt
-            …
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.…
-                Pre-generated negative text embeddings. …
-            …
         """

-        …
-            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

-        if …
-            …

         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -222,74 +318,199 @@ class EasyAnimatePipeline(DiffusionPipeline):
         else:
             batch_size = prompt_embeds.shape[0]

-        # See Section 3.1. of the paper.
-        max_length = max_sequence_length
-        …
         if prompt_embeds is None:
-            …
             )

-            …
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings …
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.…
-        prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            …

-            …

         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -299,14 +520,9 @@ class EasyAnimatePipeline(DiffusionPipeline):

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-            …
-        else:
-            negative_prompt_embeds = None
-            negative_prompt_attention_mask = None

-        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
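The reason `encode_prompt` returns both positive and negative embeddings is classifier-free guidance: the two branches are run through the transformer together and then recombined with the guidance scale. A generic sketch of that recombination step (not this pipeline's exact code):

```python
import torch

def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # noise_pred stacks the unconditional and conditional predictions along the batch axis,
    # which is what you get after duplicating latents and text embeddings for CFG.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```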

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -331,20 +547,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
         prompt,
         height,
         width,
-        negative_prompt,
-        callback_steps,
         prompt_embeds=None,
         negative_prompt_embeds=None,
     ):
-        if height % …
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

-        if …
-            …
         ):
             raise ValueError(
-                f"`…
-                f" {type(callback_steps)}."
             )

         if prompt is not None and prompt_embeds is not None:
@@ -356,14 +577,18 @@ class EasyAnimatePipeline(DiffusionPipeline):
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

-        if …
-            raise ValueError(
-                …
-            )

         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -371,6 +596,13 @@ class EasyAnimatePipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
             )

         if prompt_embeds is not None and negative_prompt_embeds is not None:
             if prompt_embeds.shape != negative_prompt_embeds.shape:
                 raise ValueError(
@@ -378,153 +610,25 @@ class EasyAnimatePipeline(DiffusionPipeline):
                     f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                     f" {negative_prompt_embeds.shape}."
                 )
-        …
-        if clean_caption and not is_ftfy_available():
-            logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
-            logger.warn("Setting `clean_caption` to False...")
-            clean_caption = False
-
-        if not isinstance(text, (tuple, list)):
-            text = [text]
-
-        def process(text: str):
-            if clean_caption:
-                text = self._clean_caption(text)
-                text = self._clean_caption(text)
-            else:
-                text = text.lower().strip()
-            return text
-
-        return [process(t) for t in text]
-
-    # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
-    def _clean_caption(self, caption):
-        caption = str(caption)
-        caption = ul.unquote_plus(caption)
-        caption = caption.strip().lower()
-        caption = re.sub("<person>", "person", caption)
-        # urls:
-        caption = re.sub(
-            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
-            "",
-            caption,
-        )  # regex for urls
-        caption = re.sub(
-            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
-            "",
-            caption,
-        )  # regex for urls
-        # html:
-        caption = BeautifulSoup(caption, features="html.parser").text
-
-        # @<nickname>
-        caption = re.sub(r"@[\w\d]+\b", "", caption)
-
-        # 31C0—31EF CJK Strokes
-        # 31F0—31FF Katakana Phonetic Extensions
-        # 3200—32FF Enclosed CJK Letters and Months
-        # 3300—33FF CJK Compatibility
-        # 3400—4DBF CJK Unified Ideographs Extension A
-        # 4DC0—4DFF Yijing Hexagram Symbols
-        # 4E00—9FFF CJK Unified Ideographs
-        caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
-        caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
-        caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
-        caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
-        caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
-        caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
-        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
-        #######################################################
-
-        # все виды тире / all types of dash --> "-"
-        caption = re.sub(
-            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
-            "-",
-            caption,
-        )
-
-        # кавычки к одному стандарту
-        caption = re.sub(r"[`´«»“”¨]", '"', caption)
-        caption = re.sub(r"[‘’]", "'", caption)
-
-        # &quot;
-        caption = re.sub(r"&quot;?", "", caption)
-        # &amp
-        caption = re.sub(r"&amp", "", caption)
-
-        # ip adresses:
-        caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
-
-        # article ids:
-        caption = re.sub(r"\d:\d\d\s+$", "", caption)
-
-        # \n
-        caption = re.sub(r"\\n", " ", caption)
-
-        # "#123"
-        caption = re.sub(r"#\d{1,3}\b", "", caption)
-        # "#12345.."
-        caption = re.sub(r"#\d{5,}\b", "", caption)
-        # "123456.."
-        caption = re.sub(r"\b\d{6,}\b", "", caption)
-        # filenames:
-        caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
-
-        #
-        caption = re.sub(r"[\"\']{2,}", r'"', caption)  # """AUSVERKAUFT"""
-        caption = re.sub(r"[\.]{2,}", r" ", caption)  # """AUSVERKAUFT"""
-
-        caption = re.sub(self.bad_punct_regex, r" ", caption)  # ***AUSVERKAUFT***, #AUSVERKAUFT
-        caption = re.sub(r"\s+\.\s+", r" ", caption)  # " . "
-
-        # this-is-my-cute-cat / this_is_my_cute_cat
-        regex2 = re.compile(r"(?:\-|\_)")
-        if len(re.findall(regex2, caption)) > 3:
-            caption = re.sub(regex2, " ", caption)
-
-        caption = ftfy.fix_text(caption)
-        caption = html.unescape(html.unescape(caption))
-
-        caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption)  # jc6640
-        caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption)  # jc6640vc
-        caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption)  # 6640vc231
-
-        caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
-        caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
-        caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
-        caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
-        caption = re.sub(r"\bpage\s+\d+\b", "", caption)
-
-        caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption)  # j2d1a2a...
-
-        caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
-
-        caption = re.sub(r"\b\s+\:\s+", r": ", caption)
-        caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
-        caption = re.sub(r"\s+", " ", caption)
-
-        caption.strip()
-
-        caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
-        caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
-        caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
-        caption = re.sub(r"^\.\S+$", "", caption)
-
-        return caption.strip()

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
-        if self.vae.quant_conv.weight.ndim==5:
-            …
         else:
             shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
@@ -538,11 +642,12 @@ class EasyAnimatePipeline(DiffusionPipeline):
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             latents = latents.to(device)
         # scale the initial noise by the standard deviation required by the scheduler
-        …
         return latents

     def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
         if video.size()[2] <= mini_batch_encoder:
             return video
@@ -558,16 +663,17 @@ class EasyAnimatePipeline(DiffusionPipeline):

         video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
         return video

     def decode_latents(self, latents):
         video_length = latents.shape[2]
         latents = 1 / self.vae.config.scaling_factor * latents
-        if self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
             video = self.vae.decode(latents)[0]
             video = video.clamp(-1, 1)
-            …
         else:
             latents = rearrange(latents, "b c f h w -> (b f) c h w")
             video = []
@@ -580,8 +686,28 @@ class EasyAnimatePipeline(DiffusionPipeline):
         video = video.cpu().float().numpy()
         return video
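`decode_latents` hands back the video as a float NumPy array laid out as `(batch, channels, frames, height, width)`. Turning that into frames for saving is a small, generic post-processing step; a sketch, assuming the values have already been normalized to `[0, 1]`:

```python
import numpy as np
from PIL import Image

def to_pil_frames(video: np.ndarray) -> list:
    # Assumes video is (batch, channels, frames, height, width) with values in [0, 1].
    clip = video[0]                          # first sample in the batch
    clip = np.transpose(clip, (1, 2, 3, 0))  # -> (frames, height, width, channels)
    clip = (clip * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in clip]
```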

-    …

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -589,103 +715,131 @@ class EasyAnimatePipeline(DiffusionPipeline):
         self,
         prompt: Union[str, List[str]] = None,
         video_length: Optional[int] = None,
-        negative_prompt: str = "",
-        num_inference_steps: int = 20,
-        timesteps: List[int] = None,
-        guidance_scale: float = 4.5,
-        num_images_per_prompt: Optional[int] = 1,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        …
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.…
-        prompt_embeds: Optional[torch.…
-        negative_prompt_embeds: Optional[torch.…
         output_type: Optional[str] = "latent",
         return_dict: bool = True,
-        …
         comfyui_progressbar: bool = False,
-        …
-    )…
-        """
-        …

-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            num_inference_steps (`int`, *optional*, defaults to 100):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
-                timesteps are used. Must be in descending order.
-            guidance_scale (`float`, *optional*, defaults to 7.0):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The width in pixels of the generated image.
-            eta (`float`, *optional*, defaults to 0.0):
-                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-                [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
-                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            callback (`Callable`, *optional*):
-                A function that will be called every `callback_steps` steps during inference. The function will be
-                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function will be called. If not specified, the callback will be
-                called at every step.
-            clean_caption (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
-                be installed. If the dependencies are not installed, the embeddings will be created from the raw
-                prompt.
-            mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked.

         Examples:

         Returns:
-            [`~pipelines.…
-                If `return_dict` is `True`, [`~pipelines.…
-                returned where the first element is a list with the generated images
         """
         # 1. Check inputs. Raise error if not correct
-        …

-        # 2. …
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -694,136 +848,223 @@ class EasyAnimatePipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        …

         # 3. Encode input prompt
         (
             prompt_embeds,
-            prompt_attention_mask,
             negative_prompt_embeds,
             negative_prompt_attention_mask,
         ) = self.encode_prompt(
-            prompt,
-            do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
             device=device,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
-            …
-            max_sequence_length=max_sequence_length,
         )
-        if …
-            …

         # 4. Prepare timesteps
-        …

-        # 5. Prepare …
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
-            …
             video_length,
             height,
             width,
-            …
             device,
             generator,
             latents,
         )

         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # …
         added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-        if self.transformer.config.sample_size == 128:
             resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
             aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
-            resolution = resolution.to(dtype=…
-            aspect_ratio = aspect_ratio.to(dtype=…

-            if do_classifier_free_guidance:
                 resolution = torch.cat([resolution, resolution], dim=0)
                 aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

             added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

-        …
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
-                …
-                current_timestep = current_timestep[None].to(latent_model_input.device)
-            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            current_timestep = current_timestep.expand(latent_model_input.shape[0])

-            # predict noise model_output
| 791 | 
             
                            noise_pred = self.transformer(
         | 
| 792 | 
             
                                latent_model_input,
         | 
|  | |
| 793 | 
             
                                encoder_hidden_states=prompt_embeds,
         | 
| 794 | 
            -
                                 | 
| 795 | 
            -
                                 | 
|  | |
|  | |
|  | |
|  | |
| 796 | 
             
                                added_cond_kwargs=added_cond_kwargs,
         | 
| 797 | 
             
                                return_dict=False,
         | 
| 798 | 
             
                            )[0]
         | 
|  | |
|  | |
|  | |
| 799 |  | 
| 800 | 
             
                            # perform guidance
         | 
| 801 | 
            -
                            if do_classifier_free_guidance:
         | 
| 802 | 
             
                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 803 | 
             
                                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
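The two lines above are the classifier-free guidance step: the batch stacks the unconditional and text-conditional predictions along dim 0, and the final prediction extrapolates from the unconditional one toward the text-conditional one by `guidance_scale`. A minimal standalone sketch on dummy tensors (shapes and scale are illustrative only):

```python
import torch

guidance_scale = 6.0
# [uncond | text] predictions stacked along the batch dimension, as in the loop above
noise_pred = torch.randn(2, 4, 8, 16, 16)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 8, 16, 16])
```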
         | 
| 804 |  | 
| 805 | 
            -
                             | 
| 806 | 
            -
             | 
| 807 | 
            -
                                noise_pred = noise_pred | 
| 808 | 
            -
                            else:
         | 
| 809 | 
            -
                                noise_pred = noise_pred
         | 
| 810 |  | 
| 811 | 
            -
                            # compute previous  | 
| 812 | 
             
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
         | 
| 813 |  | 
| 814 | 
            -
                             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 815 | 
             
                            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         | 
| 816 | 
             
                                progress_bar.update()
         | 
| 817 | 
            -
             | 
| 818 | 
            -
             | 
| 819 | 
            -
             | 
| 820 |  | 
| 821 | 
             
                            if comfyui_progressbar:
         | 
| 822 | 
             
                                pbar.update(1)
         | 
| 823 |  | 
| 824 | 
            -
                    if self.enable_autocast_float8_transformer_flag:
         | 
| 825 | 
            -
                        self.transformer = self.transformer.to("cpu", origin_weight_dtype)
         | 
| 826 | 
            -
             | 
| 827 | 
             
                    # Post-processing
         | 
| 828 | 
             
                    video = self.decode_latents(latents)
         | 
| 829 |  | 
| @@ -831,7 +1072,10 @@ class EasyAnimatePipeline(DiffusionPipeline): | |
| 831 | 
             
                    if output_type == "latent":
         | 
| 832 | 
             
                        video = torch.from_numpy(video)
         | 
| 833 |  | 
|  | |
|  | |
|  | |
| 834 | 
             
                    if not return_dict:
         | 
| 835 | 
             
                        return video
         | 
| 836 |  | 
| 837 | 
            -
                    return EasyAnimatePipelineOutput( | 
|  | |
| 1 | 
            +
            # Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
         | 
| 2 | 
             
            #
         | 
| 3 | 
             
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 4 | 
             
            # you may not use this file except in compliance with the License.
         | 
|  | |
| 12 | 
             
            # See the License for the specific language governing permissions and
         | 
| 13 | 
             
            # limitations under the License.
         | 
| 14 |  | 
|  | |
|  | |
| 15 | 
             
            import inspect
         | 
|  | |
|  | |
| 16 | 
             
            from dataclasses import dataclass
         | 
| 17 | 
            +
            from typing import Callable, Dict, List, Optional, Tuple, Union
         | 
| 18 |  | 
| 19 | 
             
            import numpy as np
         | 
| 20 | 
             
            import torch
         | 
| 21 | 
            +
            import torch.nn.functional as F
         | 
| 22 | 
            +
            from diffusers import DiffusionPipeline
         | 
| 23 | 
            +
            from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
         | 
| 24 | 
             
            from diffusers.image_processor import VaeImageProcessor
         | 
| 25 | 
            +
            from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
         | 
| 26 | 
            +
            from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
         | 
| 27 | 
            +
                                                     get_3d_rotary_pos_embed)
         | 
| 28 | 
            +
            from diffusers.pipelines.pipeline_utils import DiffusionPipeline
         | 
| 29 | 
            +
            from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
         | 
| 30 | 
            +
            from diffusers.pipelines.stable_diffusion.safety_checker import \
         | 
| 31 | 
            +
                StableDiffusionSafetyChecker
         | 
| 32 | 
            +
            from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
         | 
| 33 | 
             
            from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
         | 
| 34 | 
            +
                                         is_bs4_available, is_ftfy_available,
         | 
| 35 | 
            +
                                         is_torch_xla_available, logging,
         | 
| 36 | 
             
                                         replace_example_docstring)
         | 
| 37 | 
             
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 38 | 
             
            from einops import rearrange
         | 
| 39 | 
            +
            from PIL import Image
         | 
| 40 | 
             
            from tqdm import tqdm
         | 
| 41 | 
            +
            from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
         | 
| 42 | 
            +
                                      Qwen2Tokenizer, Qwen2VLForConditionalGeneration, 
         | 
| 43 | 
            +
                                      T5EncoderModel, T5Tokenizer)
         | 
| 44 |  | 
| 45 | 
            +
            from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
         | 
| 46 | 
            +
            from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput
         | 
| 47 |  | 
| 48 | 
            +
            if is_torch_xla_available():
         | 
| 49 | 
            +
                import torch_xla.core.xla_model as xm
         | 
| 50 |  | 
| 51 | 
            +
                XLA_AVAILABLE = True
         | 
| 52 | 
            +
            else:
         | 
| 53 | 
            +
                XLA_AVAILABLE = False
         | 
| 54 |  | 
|  | |
|  | |
| 55 |  | 
| 56 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 57 |  | 
| 58 | 
             
            EXAMPLE_DOC_STRING = """
         | 
| 59 | 
             
                Examples:
         | 
| 60 | 
            +
                    ```python
         | 
| 61 | 
             
                    >>> import torch
         | 
| 62 | 
             
                    >>> from diffusers import EasyAnimatePipeline
         | 
| 63 | 
            +
                    >>> from diffusers.utils import export_to_video
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" or "alibaba-pai/EasyAnimateV5.1-7b-zh"
         | 
| 66 | 
            +
                    >>> pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.float16).to("cuda")
         | 
| 67 | 
            +
                    >>> prompt = (
         | 
| 68 | 
            +
                    ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
         | 
| 69 | 
            +
                    ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
         | 
| 70 | 
            +
                    ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
         | 
| 71 | 
            +
                    ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
         | 
| 72 | 
            +
                    ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
         | 
| 73 | 
            +
                    ...     "atmosphere of this unique musical performance."
         | 
| 74 | 
            +
                    ... )
         | 
| 75 | 
            +
                    >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).sample[0]
         | 
| 76 | 
            +
                    >>> export_to_video(video, "output.mp4", fps=8)
         | 
| 77 | 
             
                    ```
         | 
| 78 | 
             
            """
         | 
| 79 |  | 
| 80 | 
            +
             | 
| 81 | 
            +
            # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
         | 
| 82 | 
            +
            def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
         | 
| 83 | 
            +
                tw = tgt_width
         | 
| 84 | 
            +
                th = tgt_height
         | 
| 85 | 
            +
                h, w = src
         | 
| 86 | 
            +
                r = h / w
         | 
| 87 | 
            +
                if r > (th / tw):
         | 
| 88 | 
            +
                    resize_height = th
         | 
| 89 | 
            +
                    resize_width = int(round(th / h * w))
         | 
| 90 | 
            +
                else:
         | 
| 91 | 
            +
                    resize_width = tw
         | 
| 92 | 
            +
                    resize_height = int(round(tw / w * h))
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                crop_top = int(round((th - resize_height) / 2.0))
         | 
| 95 | 
            +
                crop_left = int(round((tw - resize_width) / 2.0))
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
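A hedged usage sketch of the helper above: it fits a latent grid of shape `(h, w)` inside a target grid while preserving aspect ratio and returns the centered crop window, the kind of region consumed by the imported `get_2d_rotary_pos_embed`/`get_3d_rotary_pos_embed` helpers. The grid sizes below are illustrative.

```python
# Fit a 30x45 latent grid into a 64x64 base grid (illustrative values).
grid_height, grid_width = 30, 45
base_size = 64

(top, left), (bottom, right) = get_resize_crop_region_for_grid(
    (grid_height, grid_width), base_size, base_size
)
print((top, left), (bottom, right))  # centered crop window inside the 64x64 grid
```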
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
         | 
| 101 | 
            +
            def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
         | 
| 102 | 
            +
                """
         | 
| 103 | 
            +
                Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
         | 
| 104 | 
            +
                Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
         | 
| 105 | 
            +
                """
         | 
| 106 | 
            +
                std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
         | 
| 107 | 
            +
                std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
         | 
| 108 | 
            +
                # rescale the results from guidance (fixes overexposure)
         | 
| 109 | 
            +
                noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
         | 
| 110 | 
            +
                # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
         | 
| 111 | 
            +
                noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
         | 
| 112 | 
            +
                return noise_cfg
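A minimal sketch of where `rescale_noise_cfg` typically sits: right after the classifier-free guidance combination, with `guidance_rescale` blending the variance-rescaled prediction back toward the raw CFG output. Tensors and values are illustrative.

```python
import torch

noise_pred_uncond = torch.randn(1, 4, 8, 16, 16)
noise_pred_text = torch.randn(1, 4, 8, 16, 16)
guidance_scale, guidance_rescale = 6.0, 0.7

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# Rescale so the guided prediction's std matches the text-conditional prediction,
# then mix by guidance_rescale (Section 3.4 of the paper referenced above).
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
```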
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
             
            # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
         | 
| 116 | 
             
            def retrieve_timesteps(
         | 
| 117 | 
             
                scheduler,
         | 
| 118 | 
             
                num_inference_steps: Optional[int] = None,
         | 
| 119 | 
             
                device: Optional[Union[str, torch.device]] = None,
         | 
| 120 | 
             
                timesteps: Optional[List[int]] = None,
         | 
| 121 | 
            +
                sigmas: Optional[List[float]] = None,
         | 
| 122 | 
             
                **kwargs,
         | 
| 123 | 
             
            ):
         | 
| 124 | 
             
                """
         | 
|  | |
| 129 | 
             
                    scheduler (`SchedulerMixin`):
         | 
| 130 | 
             
                        The scheduler to get timesteps from.
         | 
| 131 | 
             
                    num_inference_steps (`int`):
         | 
| 132 | 
            +
                        The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
         | 
| 133 | 
            +
                        must be `None`.
         | 
| 134 | 
             
                    device (`str` or `torch.device`, *optional*):
         | 
| 135 | 
             
                        The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         | 
| 136 | 
             
                    timesteps (`List[int]`, *optional*):
         | 
| 137 | 
            +
                        Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
         | 
| 138 | 
            +
                        `num_inference_steps` and `sigmas` must be `None`.
         | 
| 139 | 
            +
                    sigmas (`List[float]`, *optional*):
         | 
| 140 | 
            +
                        Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
         | 
| 141 | 
            +
                        `num_inference_steps` and `timesteps` must be `None`.
         | 
| 142 |  | 
| 143 | 
             
                Returns:
         | 
| 144 | 
             
                    `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         | 
| 145 | 
             
                    second element is the number of inference steps.
         | 
| 146 | 
             
                """
         | 
| 147 | 
            +
                if timesteps is not None and sigmas is not None:
         | 
| 148 | 
            +
                    raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
         | 
| 149 | 
             
                if timesteps is not None:
         | 
| 150 | 
             
                    accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         | 
| 151 | 
             
                    if not accepts_timesteps:
         | 
|  | |
| 156 | 
             
                    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         | 
| 157 | 
             
                    timesteps = scheduler.timesteps
         | 
| 158 | 
             
                    num_inference_steps = len(timesteps)
         | 
| 159 | 
            +
                elif sigmas is not None:
         | 
| 160 | 
            +
                    accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         | 
| 161 | 
            +
                    if not accept_sigmas:
         | 
| 162 | 
            +
                        raise ValueError(
         | 
| 163 | 
            +
                            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
         | 
| 164 | 
            +
                            f" sigmas schedules. Please check whether you are using the correct scheduler."
         | 
| 165 | 
            +
                        )
         | 
| 166 | 
            +
                    scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
         | 
| 167 | 
            +
                    timesteps = scheduler.timesteps
         | 
| 168 | 
            +
                    num_inference_steps = len(timesteps)
         | 
| 169 | 
             
                else:
         | 
| 170 | 
             
                    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         | 
| 171 | 
             
                    timesteps = scheduler.timesteps
         | 
| 172 | 
             
                return timesteps, num_inference_steps
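A hedged sketch of calling `retrieve_timesteps` with a plain step count; the scheduler choice here is illustrative, and custom `timesteps`/`sigmas` are only valid for schedulers whose `set_timesteps` accepts them.

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler, num_inference_steps=25, device=torch.device("cpu")
)
print(num_inference_steps, timesteps[:3])  # 25 and the first few timestep values
```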
         | 
| 173 |  | 
|  | |
|  | |
|  | |
| 174 |  | 
| 175 | 
             
            class EasyAnimatePipeline(DiffusionPipeline):
         | 
| 176 | 
             
                r"""
         | 
| 177 | 
            +
                Pipeline for text-to-video generation using EasyAnimate.
         | 
| 178 |  | 
| 179 | 
             
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 180 | 
             
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 181 |  | 
| 182 | 
            +
    EasyAnimate uses a single text encoder, [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct), in V5.1.
         | 
| 183 | 
            +
    EasyAnimate uses two text encoders in V5: [mT5](https://huggingface.co/google/mt5-base) and a bilingual CLIP fine-tuned by the
         | 
| 184 | 
            +
    HunyuanDiT team.
         | 
| 185 | 
            +
             | 
| 186 | 
             
                Args:
         | 
| 187 | 
            +
                    vae ([`AutoencoderKLMagvit`]):
         | 
| 188 | 
            +
        Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
         | 
| 189 | 
            +
                    text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
         | 
| 190 | 
            +
        EasyAnimate uses [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
         | 
| 191 | 
            +
                        EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
         | 
| 192 | 
            +
                    tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
         | 
| 193 | 
            +
                        A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
         | 
| 194 | 
            +
                    transformer ([`EasyAnimateTransformer3DModel`]):
         | 
| 195 | 
            +
        The EasyAnimate transformer model designed by the EasyAnimate team.
         | 
| 196 | 
            +
                    text_encoder_2 (`T5EncoderModel`):
         | 
| 197 | 
            +
                        EasyAnimate does not use text_encoder_2 in V5.1.
         | 
| 198 | 
            +
        EasyAnimate uses the [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
         | 
| 199 | 
            +
                    tokenizer_2 (`T5Tokenizer`):
         | 
| 200 | 
            +
                        The tokenizer for the mT5 embedder.
         | 
| 201 | 
            +
                    scheduler ([`FlowMatchEulerDiscreteScheduler`]):
         | 
| 202 | 
            +
                        A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
         | 
| 203 | 
             
                """
         | 
|  | |
|  | |
|  | |
| 204 |  | 
| 205 | 
            +
                model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
         | 
| 206 | 
            +
                _optional_components = [
         | 
| 207 | 
            +
                    "text_encoder_2",
         | 
| 208 | 
            +
                    "tokenizer_2",
         | 
| 209 | 
            +
                    "text_encoder",
         | 
| 210 | 
            +
                    "tokenizer",
         | 
| 211 | 
            +
                ]
         | 
| 212 | 
            +
                _callback_tensor_inputs = [
         | 
| 213 | 
            +
                    "latents",
         | 
| 214 | 
            +
                    "prompt_embeds",
         | 
| 215 | 
            +
                    "negative_prompt_embeds",
         | 
| 216 | 
            +
                    "prompt_embeds_2",
         | 
| 217 | 
            +
                    "negative_prompt_embeds_2",
         | 
| 218 | 
            +
                ]
         | 
| 219 |  | 
| 220 | 
             
                def __init__(
         | 
| 221 | 
             
                    self,
         | 
| 222 | 
            +
                    vae: AutoencoderKLMagvit,
         | 
| 223 | 
            +
                    text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
         | 
| 224 | 
            +
                    tokenizer: Union[Qwen2Tokenizer, BertTokenizer], 
         | 
| 225 | 
            +
                    text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
         | 
| 226 | 
            +
                    tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
         | 
| 227 | 
            +
                    transformer: EasyAnimateTransformer3DModel,
         | 
| 228 | 
            +
                    scheduler: FlowMatchEulerDiscreteScheduler,
         | 
| 229 | 
             
                ):
         | 
| 230 | 
             
                    super().__init__()
         | 
| 231 |  | 
| 232 | 
             
                    self.register_modules(
         | 
| 233 | 
            +
                        vae=vae,
         | 
| 234 | 
            +
                        text_encoder=text_encoder,
         | 
| 235 | 
            +
                        text_encoder_2=text_encoder_2,
         | 
| 236 | 
            +
                        tokenizer=tokenizer,
         | 
| 237 | 
            +
                        tokenizer_2=tokenizer_2,
         | 
| 238 | 
            +
                        transformer=transformer,
         | 
| 239 | 
            +
                        scheduler=scheduler,
         | 
| 240 | 
             
                    )
         | 
| 241 |  | 
| 242 | 
             
                    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
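An illustrative check of the scale factor computed above: with four entries in `block_out_channels`, the VAE downsamples height and width by `2 ** 3 = 8`, so a 512x512 frame maps to a 64x64 latent. The config values below are assumptions for the sketch, not the shipped VAE config.

```python
block_out_channels = [128, 256, 512, 512]   # assumed VAE config values
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8
print(512 // vae_scale_factor)  # 64 latent pixels per side for a 512px frame
```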
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 243 |  | 
| 244 | 
            +
                def enable_sequential_cpu_offload(self, *args, **kwargs):
         | 
| 245 | 
            +
                    super().enable_sequential_cpu_offload(*args, **kwargs)
         | 
| 246 | 
            +
                    if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
         | 
| 247 | 
            +
                        import accelerate
         | 
| 248 | 
            +
                        accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
         | 
| 249 | 
            +
                        self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
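The override above removes the offload hook from `transformer.clip_projection` (when present) and moves it back to the GPU, trading a little extra GPU memory for a working sequential-offload path. A hedged usage sketch, assuming a `pipe` loaded as in the example docstring:

```python
# Every other submodule is streamed to the CPU between uses; clip_projection
# stays on "cuda" thanks to the override above.
pipe.enable_sequential_cpu_offload()
output = pipe(prompt="a panda playing guitar", num_inference_steps=30)
```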
         | 
| 250 | 
            +
             | 
| 251 | 
             
                def encode_prompt(
         | 
| 252 | 
             
                    self,
         | 
| 253 | 
            +
                    prompt: str,
         | 
| 254 | 
            +
                    device: torch.device,
         | 
| 255 | 
            +
                    dtype: torch.dtype,
         | 
| 256 | 
             
                    num_images_per_prompt: int = 1,
         | 
| 257 | 
            +
                    do_classifier_free_guidance: bool = True,
         | 
| 258 | 
            +
                    negative_prompt: Optional[str] = None,
         | 
| 259 | 
            +
                    prompt_embeds: Optional[torch.Tensor] = None,
         | 
| 260 | 
            +
                    negative_prompt_embeds: Optional[torch.Tensor] = None,
         | 
| 261 | 
            +
                    prompt_attention_mask: Optional[torch.Tensor] = None,
         | 
| 262 | 
            +
                    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         | 
| 263 | 
            +
                    max_sequence_length: Optional[int] = None,
         | 
| 264 | 
            +
                    text_encoder_index: int = 0,
         | 
| 265 | 
            +
                    actual_max_sequence_length: int = 256
         | 
| 266 | 
             
                ):
         | 
| 267 | 
             
                    r"""
         | 
| 268 | 
             
                    Encodes the prompt into text encoder hidden states.
         | 
|  | |
| 270 | 
             
                    Args:
         | 
| 271 | 
             
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 272 | 
             
                            prompt to be encoded
         | 
| 273 | 
            +
                        device: (`torch.device`):
         | 
| 274 | 
            +
                            torch device
         | 
| 275 | 
            +
                        dtype (`torch.dtype`):
         | 
| 276 | 
            +
                            torch dtype
         | 
| 277 | 
            +
                        num_images_per_prompt (`int`):
         | 
|  | |
|  | |
| 278 | 
             
                            number of images that should be generated per prompt
         | 
| 279 | 
            +
                        do_classifier_free_guidance (`bool`):
         | 
| 280 | 
            +
                            whether to use classifier free guidance or not
         | 
| 281 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 282 | 
            +
                            The prompt or prompts not to guide the image generation. If not defined, one has to pass
         | 
| 283 | 
            +
                            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
         | 
| 284 | 
            +
                            less than `1`).
         | 
| 285 | 
            +
                        prompt_embeds (`torch.Tensor`, *optional*):
         | 
| 286 | 
             
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 287 | 
             
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 288 | 
            +
                        negative_prompt_embeds (`torch.Tensor`, *optional*):
         | 
| 289 | 
            +
                            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
         | 
| 290 | 
            +
                            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
         | 
| 291 | 
            +
                            argument.
         | 
| 292 | 
            +
                        prompt_attention_mask (`torch.Tensor`, *optional*):
         | 
| 293 | 
            +
                            Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
         | 
| 294 | 
            +
                        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
         | 
| 295 | 
            +
                            Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
         | 
| 296 | 
            +
                        max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
         | 
| 297 | 
            +
                        text_encoder_index (`int`, *optional*):
         | 
| 298 | 
            +
                Index of the text encoder to use: `0` for the primary encoder (Qwen2-VL in V5.1, bilingual CLIP in V5) and `1` for the mT5 encoder.
         | 
| 299 | 
             
                    """
         | 
| 300 | 
            +
                    tokenizers = [self.tokenizer, self.tokenizer_2]
         | 
| 301 | 
            +
                    text_encoders = [self.text_encoder, self.text_encoder_2]
         | 
| 302 |  | 
| 303 | 
            +
                    tokenizer = tokenizers[text_encoder_index]
         | 
| 304 | 
            +
                    text_encoder = text_encoders[text_encoder_index]
         | 
|  | |
| 305 |  | 
| 306 | 
            +
                    if max_sequence_length is None:
         | 
| 307 | 
            +
                        if text_encoder_index == 0:
         | 
| 308 | 
            +
                            max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
         | 
| 309 | 
            +
                        if text_encoder_index == 1:
         | 
| 310 | 
            +
                            max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
         | 
| 311 | 
            +
                    else:
         | 
| 312 | 
            +
                        max_length = max_sequence_length
         | 
| 313 |  | 
| 314 | 
             
                    if prompt is not None and isinstance(prompt, str):
         | 
| 315 | 
             
                        batch_size = 1
         | 
|  | |
| 318 | 
             
                    else:
         | 
| 319 | 
             
                        batch_size = prompt_embeds.shape[0]
         | 
| 320 |  | 
|  | |
|  | |
|  | |
| 321 | 
             
                    if prompt_embeds is None:
         | 
| 322 | 
            +
                        if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
         | 
| 323 | 
            +
                            text_inputs = tokenizer(
         | 
| 324 | 
            +
                                prompt,
         | 
| 325 | 
            +
                                padding="max_length",
         | 
| 326 | 
            +
                                max_length=max_length,
         | 
| 327 | 
            +
                                truncation=True,
         | 
| 328 | 
            +
                                return_attention_mask=True,
         | 
| 329 | 
            +
                                return_tensors="pt",
         | 
| 330 | 
            +
                            )
         | 
| 331 | 
            +
                            text_input_ids = text_inputs.input_ids
         | 
| 332 | 
            +
                            if text_input_ids.shape[-1] > actual_max_sequence_length:
         | 
| 333 | 
            +
                                reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
         | 
| 334 | 
            +
                                text_inputs = tokenizer(
         | 
| 335 | 
            +
                                    reprompt,
         | 
| 336 | 
            +
                                    padding="max_length",
         | 
| 337 | 
            +
                                    max_length=max_length,
         | 
| 338 | 
            +
                                    truncation=True,
         | 
| 339 | 
            +
                                    return_attention_mask=True,
         | 
| 340 | 
            +
                                    return_tensors="pt",
         | 
| 341 | 
            +
                                )
         | 
| 342 | 
            +
                                text_input_ids = text_inputs.input_ids
         | 
| 343 | 
            +
                            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
         | 
| 346 | 
            +
                                text_input_ids, untruncated_ids
         | 
| 347 | 
            +
                            ):
         | 
| 348 | 
            +
                                _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
         | 
| 349 | 
            +
                                removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
         | 
| 350 | 
            +
                                logger.warning(
         | 
| 351 | 
            +
                                    "The following part of your input was truncated because CLIP can only handle sequences up to"
         | 
| 352 | 
            +
                                    f" {_actual_max_sequence_length} tokens: {removed_text}"
         | 
| 353 | 
            +
                                )
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                            prompt_attention_mask = text_inputs.attention_mask.to(device)
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 358 | 
            +
                                prompt_embeds = text_encoder(
         | 
| 359 | 
            +
                                    text_input_ids.to(device),
         | 
| 360 | 
            +
                                    attention_mask=prompt_attention_mask,
         | 
| 361 | 
            +
                                )
         | 
| 362 | 
            +
                            else:
         | 
| 363 | 
            +
                                prompt_embeds = text_encoder(
         | 
| 364 | 
            +
                                    text_input_ids.to(device)
         | 
| 365 | 
            +
                                )
         | 
| 366 | 
            +
                            prompt_embeds = prompt_embeds[0]
         | 
| 367 | 
            +
                            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 368 | 
            +
                        else:
         | 
| 369 | 
            +
                            if prompt is not None and isinstance(prompt, str):
         | 
| 370 | 
            +
                                messages = [
         | 
| 371 | 
            +
                                    {
         | 
| 372 | 
            +
                                        "role": "user",
         | 
| 373 | 
            +
                                        "content": [{"type": "text", "text": prompt}],
         | 
| 374 | 
            +
                                    }
         | 
| 375 | 
            +
                                ]
         | 
| 376 | 
            +
                            else:
         | 
| 377 | 
            +
                                messages = [
         | 
| 378 | 
            +
                                    {
         | 
| 379 | 
            +
                                        "role": "user",
         | 
| 380 | 
            +
                                        "content": [{"type": "text", "text": _prompt}],
         | 
| 381 | 
            +
                                    } for _prompt in prompt
         | 
| 382 | 
            +
                                ]
         | 
| 383 | 
            +
                            text = tokenizer.apply_chat_template(
         | 
| 384 | 
            +
                                messages, tokenize=False, add_generation_prompt=True
         | 
| 385 | 
             
                            )
         | 
| 386 |  | 
| 387 | 
            +
                            text_inputs = tokenizer(
         | 
| 388 | 
            +
                                text=[text],
         | 
| 389 | 
            +
                                padding="max_length",
         | 
| 390 | 
            +
                                max_length=max_length,
         | 
| 391 | 
            +
                                truncation=True,
         | 
| 392 | 
            +
                                return_attention_mask=True,
         | 
| 393 | 
            +
                                padding_side="right",
         | 
| 394 | 
            +
                                return_tensors="pt",
         | 
| 395 | 
            +
                            )
         | 
| 396 | 
            +
                            text_inputs = text_inputs.to(text_encoder.device)
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                            text_input_ids = text_inputs.input_ids
         | 
| 399 | 
            +
                            prompt_attention_mask = text_inputs.attention_mask
         | 
| 400 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 401 | 
            +
                                # Inference: Generation of the output
         | 
| 402 | 
            +
                                prompt_embeds = text_encoder(
         | 
| 403 | 
            +
                                    input_ids=text_input_ids,
         | 
| 404 | 
            +
                                    attention_mask=prompt_attention_mask,
         | 
| 405 | 
            +
                                    output_hidden_states=True).hidden_states[-2]
         | 
| 406 | 
            +
                            else:
         | 
| 407 | 
            +
                                raise ValueError("LLM needs attention_mask")
         | 
| 408 | 
            +
                            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 409 | 
            +
                    
         | 
| 410 | 
             
                    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         | 
| 411 |  | 
| 412 | 
             
                    bs_embed, seq_len, _ = prompt_embeds.shape
         | 
| 413 | 
            +
                    # duplicate text embeddings for each generation per prompt, using mps friendly method
         | 
| 414 | 
             
                    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 415 | 
             
                    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
         | 
| 416 | 
            +
                    prompt_attention_mask = prompt_attention_mask.to(device=device)
         | 
|  | |
| 417 |  | 
| 418 | 
             
                    # get unconditional embeddings for classifier free guidance
         | 
| 419 | 
             
                    if do_classifier_free_guidance and negative_prompt_embeds is None:
         | 
| 420 | 
            +
                        if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
         | 
| 421 | 
            +
                            uncond_tokens: List[str]
         | 
| 422 | 
            +
                            if negative_prompt is None:
         | 
| 423 | 
            +
                                uncond_tokens = [""] * batch_size
         | 
| 424 | 
            +
                            elif prompt is not None and type(prompt) is not type(negative_prompt):
         | 
| 425 | 
            +
                                raise TypeError(
         | 
| 426 | 
            +
                                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
         | 
| 427 | 
            +
                                    f" {type(prompt)}."
         | 
| 428 | 
            +
                                )
         | 
| 429 | 
            +
                            elif isinstance(negative_prompt, str):
         | 
| 430 | 
            +
                                uncond_tokens = [negative_prompt]
         | 
| 431 | 
            +
                            elif batch_size != len(negative_prompt):
         | 
| 432 | 
            +
                                raise ValueError(
         | 
| 433 | 
            +
                                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
         | 
| 434 | 
            +
                                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
         | 
| 435 | 
            +
                                    " the batch size of `prompt`."
         | 
| 436 | 
            +
                                )
         | 
| 437 | 
            +
                            else:
         | 
| 438 | 
            +
                                uncond_tokens = negative_prompt
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                            max_length = prompt_embeds.shape[1]
         | 
| 441 | 
            +
                            uncond_input = tokenizer(
         | 
| 442 | 
            +
                                uncond_tokens,
         | 
| 443 | 
            +
                                padding="max_length",
         | 
| 444 | 
            +
                                max_length=max_length,
         | 
| 445 | 
            +
                                truncation=True,
         | 
| 446 | 
            +
                                return_tensors="pt",
         | 
| 447 | 
            +
                            )
         | 
| 448 | 
            +
                            uncond_input_ids = uncond_input.input_ids
         | 
| 449 | 
            +
                            if uncond_input_ids.shape[-1] > actual_max_sequence_length:
         | 
| 450 | 
            +
                                reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
         | 
| 451 | 
            +
                                uncond_input = tokenizer(
         | 
| 452 | 
            +
                                    reuncond_tokens,
         | 
| 453 | 
            +
                                    padding="max_length",
         | 
| 454 | 
            +
                                    max_length=max_length,
         | 
| 455 | 
            +
                                    truncation=True,
         | 
| 456 | 
            +
                                    return_attention_mask=True,
         | 
| 457 | 
            +
                                    return_tensors="pt",
         | 
| 458 | 
            +
                                )
         | 
| 459 | 
            +
                                uncond_input_ids = uncond_input.input_ids
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
         | 
| 462 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 463 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 464 | 
            +
                                    uncond_input.input_ids.to(device),
         | 
| 465 | 
            +
                                    attention_mask=negative_prompt_attention_mask,
         | 
| 466 | 
            +
                                )
         | 
| 467 | 
            +
                            else:
         | 
| 468 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 469 | 
            +
                                    uncond_input.input_ids.to(device)
         | 
| 470 | 
            +
                                )
         | 
| 471 | 
            +
                            negative_prompt_embeds = negative_prompt_embeds[0]
         | 
| 472 | 
            +
                            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 473 | 
            +
                        else:
         | 
| 474 | 
            +
                            if negative_prompt is not None and isinstance(negative_prompt, str):
         | 
| 475 | 
            +
                                messages = [
         | 
| 476 | 
            +
                                    {
         | 
| 477 | 
            +
                                        "role": "user",
         | 
| 478 | 
            +
                                        "content": [{"type": "text", "text": negative_prompt}],
         | 
| 479 | 
            +
                                    }
         | 
| 480 | 
            +
                                ]
         | 
| 481 | 
            +
                            else:
         | 
| 482 | 
            +
                                messages = [
         | 
| 483 | 
            +
                                    {
         | 
| 484 | 
            +
                                        "role": "user",
         | 
| 485 | 
            +
                                        "content": [{"type": "text", "text": _negative_prompt}],
         | 
| 486 | 
            +
                                    } for _negative_prompt in negative_prompt
         | 
| 487 | 
            +
                                ]
         | 
| 488 | 
            +
                            text = tokenizer.apply_chat_template(
         | 
| 489 | 
            +
                                messages, tokenize=False, add_generation_prompt=True
         | 
| 490 | 
            +
                            )
         | 
| 491 |  | 
| 492 | 
            +
                            text_inputs = tokenizer(
         | 
| 493 | 
            +
                                text=[text],
         | 
| 494 | 
            +
                                padding="max_length",
         | 
| 495 | 
            +
                                max_length=max_length,
         | 
| 496 | 
            +
                                truncation=True,
         | 
| 497 | 
            +
                                return_attention_mask=True,
         | 
| 498 | 
            +
                                padding_side="right",
         | 
| 499 | 
            +
                                return_tensors="pt",
         | 
| 500 | 
            +
                            )
         | 
| 501 | 
            +
                            text_inputs = text_inputs.to(text_encoder.device)
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                            text_input_ids = text_inputs.input_ids
         | 
| 504 | 
            +
                            negative_prompt_attention_mask = text_inputs.attention_mask
         | 
| 505 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 506 | 
            +
                                # Inference: Generation of the output
         | 
| 507 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 508 | 
            +
                                    input_ids=text_input_ids,
         | 
| 509 | 
            +
                                    attention_mask=negative_prompt_attention_mask,
         | 
| 510 | 
            +
                                    output_hidden_states=True).hidden_states[-2]
         | 
| 511 | 
            +
                            else:
         | 
| 512 | 
            +
                                raise ValueError("LLM needs attention_mask")
         | 
| 513 | 
            +
                            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 514 |  | 
| 515 | 
             
                    if do_classifier_free_guidance:
         | 
| 516 | 
             
                        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         | 
|  | |
| 520 |  | 
| 521 | 
             
                        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 522 | 
             
                        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         | 
| 523 | 
            +
                        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
         | 
| 524 | 
            +
                        
         | 
| 525 | 
            +
                    return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
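A hedged sketch of how the four return values above are typically consumed: for classifier-free guidance the negative and positive embeddings (and masks) are concatenated along the batch dimension before the transformer call. `pipe`, the prompts, device, and dtype are illustrative.

```python
import torch

(
    prompt_embeds,
    negative_prompt_embeds,
    prompt_attention_mask,
    negative_prompt_attention_mask,
) = pipe.encode_prompt(
    "a panda playing guitar in a bamboo forest",
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    do_classifier_free_guidance=True,
    negative_prompt="blurry, low quality",
)
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
```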
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 526 |  | 
| 527 | 
             
                # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
         | 
| 528 | 
             
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
|  | |
| 547 | 
             
                    prompt,
         | 
| 548 | 
             
                    height,
         | 
| 549 | 
             
                    width,
         | 
| 550 | 
            +
                    negative_prompt=None,
         | 
|  | |
| 551 | 
             
                    prompt_embeds=None,
         | 
| 552 | 
             
                    negative_prompt_embeds=None,
         | 
| 553 | 
            +
                    prompt_attention_mask=None,
         | 
| 554 | 
            +
                    negative_prompt_attention_mask=None,
         | 
| 555 | 
            +
                    prompt_embeds_2=None,
         | 
| 556 | 
            +
                    negative_prompt_embeds_2=None,
         | 
| 557 | 
            +
                    prompt_attention_mask_2=None,
         | 
| 558 | 
            +
                    negative_prompt_attention_mask_2=None,
         | 
| 559 | 
            +
                    callback_on_step_end_tensor_inputs=None,
         | 
| 560 | 
             
                ):
         | 
| 561 | 
            +
                    if height % 16 != 0 or width % 16 != 0:
         | 
| 562 | 
             
                        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
         | 
| 563 |  | 
| 564 | 
            +
                    if callback_on_step_end_tensor_inputs is not None and not all(
         | 
| 565 | 
            +
                        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         | 
| 566 | 
             
                    ):
         | 
| 567 | 
             
                        raise ValueError(
         | 
| 568 | 
            +
                            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
         | 
|  | |
| 569 | 
             
                        )
         | 
| 570 |  | 
| 571 | 
             
                    if prompt is not None and prompt_embeds is not None:
         | 
|  | |
| 577 | 
             
                        raise ValueError(
         | 
| 578 | 
             
                            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
         | 
| 579 | 
             
                        )
         | 
| 580 | 
            +
                    elif prompt is None and prompt_embeds_2 is None:
         | 
| 581 | 
            +
                        raise ValueError(
         | 
| 582 | 
            +
                            "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
         | 
| 583 | 
            +
                        )
         | 
| 584 | 
             
                    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
         | 
| 585 | 
             
                        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         | 
| 586 |  | 
| 587 | 
            +
                    if prompt_embeds is not None and prompt_attention_mask is None:
         | 
| 588 | 
            +
                        raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
         | 
| 591 | 
            +
                        raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
         | 
| 592 |  | 
| 593 | 
             
                    if negative_prompt is not None and negative_prompt_embeds is not None:
         | 
| 594 | 
             
                        raise ValueError(
         | 
|  | |
| 596 | 
             
                            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
         | 
| 597 | 
             
                        )
         | 
| 598 |  | 
| 599 | 
            +
                    if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
         | 
| 600 | 
            +
                        raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                    if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
         | 
| 603 | 
            +
                        raise ValueError(
         | 
| 604 | 
            +
                            "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
         | 
| 605 | 
            +
                        )
         | 
| 606 | 
             
                    if prompt_embeds is not None and negative_prompt_embeds is not None:
         | 
| 607 | 
             
                        if prompt_embeds.shape != negative_prompt_embeds.shape:
         | 
| 608 | 
             
                            raise ValueError(
         | 
|  | |
| 610 | 
             
                                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
         | 
| 611 | 
             
                                f" {negative_prompt_embeds.shape}."
         | 
| 612 | 
             
                            )
         | 
| 613 | 
            +
                    if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
         | 
| 614 | 
            +
                        if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
         | 
| 615 | 
            +
                            raise ValueError(
         | 
| 616 | 
            +
                                "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
         | 
| 617 | 
            +
                                f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
         | 
| 618 | 
            +
                                f" {negative_prompt_embeds_2.shape}."
         | 
| 619 | 
            +
                            )
         | 
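As a reading aid, here is a minimal sketch (not code from the diff) of how the 16-pixel divisibility check in `check_inputs` relates to the `height = int((height // 16) * 16)` rounding that `__call__` performs before validation; the helper name is made up for illustration.

```python
# Illustrative only: the pipeline rounds requested sizes down to the nearest
# multiple of 16, so a 721x1281 request is actually generated at 720x1280
# and then passes the divisibility guard in check_inputs.
def round_to_multiple_of_16(value: int) -> int:
    return int((value // 16) * 16)

assert round_to_multiple_of_16(721) == 720
assert round_to_multiple_of_16(1281) == 1280
assert 720 % 16 == 0 and 1280 % 16 == 0
```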

      # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
      def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+         if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
+             if self.vae.cache_mag_vae:
+                 mini_batch_encoder = self.vae.mini_batch_encoder
+                 mini_batch_decoder = self.vae.mini_batch_decoder
+                 shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
+             else:
+                 mini_batch_encoder = self.vae.mini_batch_encoder
+                 mini_batch_decoder = self.vae.mini_batch_decoder
+                 shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
          else:
              shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)

...
              latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
          else:
              latents = latents.to(device)
+
          # scale the initial noise by the standard deviation required by the scheduler
+         if hasattr(self.scheduler, "init_noise_sigma"):
+             latents = latents * self.scheduler.init_noise_sigma
          return latents
+
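For orientation, a worked example (illustrative numbers, not values from the diff) of the temporal-length arithmetic in `prepare_latents` when `cache_mag_vae` is enabled:

```python
# Illustrative only: how many latent frames are allocated for a requested
# pixel-space video length when the magvit VAE caches the first frame.
video_length = 49            # requested frames (assumed example value)
mini_batch_encoder = 4       # assumed VAE temporal compression group
mini_batch_decoder = 1       # assumed decoder group

latent_frames = int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1)
print(latent_frames)         # -> 13: 48 frames compressed 4x, plus 1 cached frame
```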
      def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
          if video.size()[2] <= mini_batch_encoder:
              return video
...
          video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
          return video
+
      def decode_latents(self, latents):
          video_length = latents.shape[2]
          latents = 1 / self.vae.config.scaling_factor * latents
+         if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
              mini_batch_encoder = self.vae.mini_batch_encoder
              mini_batch_decoder = self.vae.mini_batch_decoder
              video = self.vae.decode(latents)[0]
              video = video.clamp(-1, 1)
+             if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
+                 video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
          else:
              latents = rearrange(latents, "b c f h w -> (b f) c h w")
              video = []
...
          video = video.cpu().float().numpy()
          return video

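The `smooth_output` averaging above blends a shifted decode with the original over the overlapping frames to hide seams between VAE mini-batches. A minimal 1-D sketch of that blending idea (toy tensors and made-up offsets, not the method itself):

```python
import torch

# Toy illustration: average the original decode with a stand-in for the
# re-decoded middle chunk over the overlapping region, as smooth_output does.
video = torch.arange(10, dtype=torch.float32).view(1, 1, 10, 1, 1)  # (b, c, f, h, w)
middle = video[:, :, 2:-2] + 0.5                                     # stand-in for middle_video

video[:, :, 2:-2] = (video[:, :, 2:-2] + middle) / 2                 # same averaging pattern
```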
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def guidance_rescale(self):
+         return self._guidance_rescale
+
+     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+     # corresponds to doing no classifier free guidance.
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def interrupt(self):
+         return self._interrupt

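Classifier-free guidance only activates when `guidance_scale > 1`, per the `do_classifier_free_guidance` property above. A short sketch (toy tensors, hypothetical values) of the weighting that the denoising loop applies later:

```python
import torch

guidance_scale = 5.0                        # example value; 1.0 disables CFG
noise_pred_uncond = torch.zeros(1, 4, 8, 8)
noise_pred_text = torch.ones(1, 4, 8, 8)

# Same weighting as the loop below: uncond + w * (text - uncond)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```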
      @torch.no_grad()
      @replace_example_docstring(EXAMPLE_DOC_STRING)
...
          self,
          prompt: Union[str, List[str]] = None,
          video_length: Optional[int] = None,
          height: Optional[int] = None,
          width: Optional[int] = None,
+         num_inference_steps: Optional[int] = 50,
+         guidance_scale: Optional[float] = 5.0,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         eta: Optional[float] = 0.0,
          generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.Tensor] = None,
+         prompt_embeds: Optional[torch.Tensor] = None,
+         prompt_embeds_2: Optional[torch.Tensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+         prompt_attention_mask: Optional[torch.Tensor] = None,
+         prompt_attention_mask_2: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+         negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
          output_type: Optional[str] = "latent",
          return_dict: bool = True,
+         callback_on_step_end: Optional[
+             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+         ] = None,
+         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+         guidance_rescale: float = 0.0,
+         original_size: Optional[Tuple[int, int]] = (1024, 1024),
+         target_size: Optional[Tuple[int, int]] = None,
+         crops_coords_top_left: Tuple[int, int] = (0, 0),
          comfyui_progressbar: bool = False,
+         timesteps: Optional[List[int]] = None,
+     ):
+         r"""
+         Generates images or video using the EasyAnimate pipeline based on the provided prompts.

          Examples:
+             prompt (`str` or `List[str]`, *optional*):
+                 Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
+             video_length (`int`, *optional*):
+                 Length of the generated video (in frames).
+             height (`int`, *optional*):
+                 Height of the generated image in pixels.
+             width (`int`, *optional*):
+                 Width of the generated image in pixels.
+             num_inference_steps (`int`, *optional*, defaults to 50):
+                 Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
+             guidance_scale (`float`, *optional*, defaults to 5.0):
+                 Encourages the model to align outputs with prompts. A higher value may decrease image quality.
+             negative_prompt (`str` or `List[str]`, *optional*):
+                 Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
+             num_images_per_prompt (`int`, *optional*, defaults to 1):
+                 Number of images to generate for each prompt.
+             eta (`float`, *optional*, defaults to 0.0):
+                 Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
+             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                 A generator to ensure reproducibility in image generation.
+             latents (`torch.Tensor`, *optional*):
+                 Predefined latent tensors to condition generation.
+             prompt_embeds (`torch.Tensor`, *optional*):
+                 Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
+             prompt_embeds_2 (`torch.Tensor`, *optional*):
+                 Secondary text embeddings to supplement or replace the initial prompt embeddings.
+             negative_prompt_embeds (`torch.Tensor`, *optional*):
+                 Embeddings for negative prompts. Overrides string inputs if defined.
+             negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+                 Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
+             prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Attention mask for the primary prompt embeddings.
+             prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+                 Attention mask for the secondary prompt embeddings.
+             negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+                 Attention mask for negative prompt embeddings.
+             negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
+                 Attention mask for secondary negative prompt embeddings.
+             output_type (`str`, *optional*, defaults to "latent"):
+                 Format of the generated output, either as a PIL image or as a NumPy array.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 If `True`, returns a structured output. Otherwise returns a simple tuple.
+             callback_on_step_end (`Callable`, *optional*):
+                 Functions called at the end of each denoising step.
+             callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+                 Tensor names to be included in callback function calls.
+             guidance_rescale (`float`, *optional*, defaults to 0.0):
+                 Adjusts noise levels based on guidance scale.
+             original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
+                 Original dimensions of the output.
+             target_size (`Tuple[int, int]`, *optional*):
+                 Desired output dimensions for calculations.
+             crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
+                 Coordinates for cropping.

          Returns:
+             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                 If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                 otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                 second element is a list of `bool`s indicating whether the corresponding generated image contains
+                 "not-safe-for-work" (nsfw) content.
          """
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 0. default height and width
+         height = int((height // 16) * 16)
+         width = int((width // 16) * 16)
+
          # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             height,
+             width,
+             negative_prompt,
+             prompt_embeds,
+             negative_prompt_embeds,
+             prompt_attention_mask,
+             negative_prompt_attention_mask,
+             prompt_embeds_2,
+             negative_prompt_embeds_2,
+             prompt_attention_mask_2,
+             negative_prompt_attention_mask_2,
+             callback_on_step_end_tensor_inputs,
+         )
+         self._guidance_scale = guidance_scale
+         self._guidance_rescale = guidance_rescale
+         self._interrupt = False

+         # 2. Define call parameters
          if prompt is not None and isinstance(prompt, str):
              batch_size = 1
          elif prompt is not None and isinstance(prompt, list):
...
              batch_size = prompt_embeds.shape[0]

          device = self._execution_device
+         if self.text_encoder is not None:
+             dtype = self.text_encoder.dtype
+         elif self.text_encoder_2 is not None:
+             dtype = self.text_encoder_2.dtype
+         else:
+             dtype = self.transformer.dtype

          # 3. Encode input prompt
          (
              prompt_embeds,
              negative_prompt_embeds,
+             prompt_attention_mask,
              negative_prompt_attention_mask,
          ) = self.encode_prompt(
+             prompt=prompt,
              device=device,
+             dtype=dtype,
+             num_images_per_prompt=num_images_per_prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             negative_prompt=negative_prompt,
              prompt_embeds=prompt_embeds,
              negative_prompt_embeds=negative_prompt_embeds,
              prompt_attention_mask=prompt_attention_mask,
              negative_prompt_attention_mask=negative_prompt_attention_mask,
+             text_encoder_index=0,
          )
+         if self.tokenizer_2 is not None:
+             (
+                 prompt_embeds_2,
+                 negative_prompt_embeds_2,
+                 prompt_attention_mask_2,
+                 negative_prompt_attention_mask_2,
+             ) = self.encode_prompt(
+                 prompt=prompt,
+                 device=device,
+                 dtype=dtype,
+                 num_images_per_prompt=num_images_per_prompt,
+                 do_classifier_free_guidance=self.do_classifier_free_guidance,
+                 negative_prompt=negative_prompt,
+                 prompt_embeds=prompt_embeds_2,
+                 negative_prompt_embeds=negative_prompt_embeds_2,
+                 prompt_attention_mask=prompt_attention_mask_2,
+                 negative_prompt_attention_mask=negative_prompt_attention_mask_2,
+                 text_encoder_index=1,
+             )
+         else:
+             prompt_embeds_2 = None
+             negative_prompt_embeds_2 = None
+             prompt_attention_mask_2 = None
+             negative_prompt_attention_mask_2 = None

          # 4. Prepare timesteps
+         if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+             timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+         else:
+             timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+         if comfyui_progressbar:
+             from comfy.utils import ProgressBar
+             pbar = ProgressBar(num_inference_steps + 1)

+         # 5. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels
          latents = self.prepare_latents(
              batch_size * num_images_per_prompt,
+             num_channels_latents,
              video_length,
              height,
              width,
+             dtype,
              device,
              generator,
              latents,
          )
+         if comfyui_progressbar:
+             pbar.update(1)

          # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
          extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+         # 7. Create image_rotary_emb, style embedding & time ids
+         grid_height = height // 8 // self.transformer.config.patch_size
+         grid_width = width // 8 // self.transformer.config.patch_size
+         if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
+             base_size_width = 720 // 8 // self.transformer.config.patch_size
+             base_size_height = 480 // 8 // self.transformer.config.patch_size
+
+             grid_crops_coords = get_resize_crop_region_for_grid(
+                 (grid_height, grid_width), base_size_width, base_size_height
+             )
+             image_rotary_emb = get_3d_rotary_pos_embed(
+                 self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
+                 temporal_size=latents.size(2), use_real=True,
+             )
+         else:
+             base_size = 512 // 8 // self.transformer.config.patch_size
+             grid_crops_coords = get_resize_crop_region_for_grid(
+                 (grid_height, grid_width), base_size, base_size
+             )
+             image_rotary_emb = get_2d_rotary_pos_embed(
+                 self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
+             )
+
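As a concrete check of the grid arithmetic above (illustrative numbers only): with an assumed `patch_size` of 2, a 480x720 clip yields a 30x45 token grid, which matches the 3D-RoPE base sizes used for the crop region.

```python
# Illustrative numbers: latent token grid used for the rotary position embedding.
height, width = 480, 720     # pixel resolution (assumed example)
vae_spatial_downsample = 8   # spatial compression of the VAE
patch_size = 2               # assumed transformer patch size

grid_height = height // vae_spatial_downsample // patch_size   # 30
grid_width = width // vae_spatial_downsample // patch_size     # 45
print(grid_height, grid_width)
```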
            +
                    # Get other hunyuan params
         | 
| 954 | 
            +
                    target_size = target_size or (height, width)
         | 
| 955 | 
            +
                    add_time_ids = list(original_size + target_size + crops_coords_top_left)
         | 
| 956 | 
            +
                    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         | 
| 957 | 
            +
                    style = torch.tensor([0], device=device)
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                    if self.do_classifier_free_guidance:
         | 
| 960 | 
            +
                        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
         | 
| 961 | 
            +
                        style = torch.cat([style] * 2, dim=0)
         | 
| 962 | 
            +
             | 
| 963 | 
            +
                    # To latents.device
         | 
| 964 | 
            +
                    add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
         | 
| 965 | 
            +
                        batch_size * num_images_per_prompt, 1
         | 
| 966 | 
            +
                    )
         | 
| 967 | 
            +
                    style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
         | 
| 968 | 
            +
             | 
| 969 | 
            +
                    # Get other pixart params
         | 
| 970 | 
             
                    added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
         | 
| 971 | 
            +
                    if self.transformer.config.get("sample_size", 64) == 128:
         | 
| 972 | 
             
                        resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
         | 
| 973 | 
             
                        aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
         | 
| 974 | 
            +
                        resolution = resolution.to(dtype=dtype, device=device)
         | 
| 975 | 
            +
                        aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
         | 
| 976 |  | 
| 977 | 
            +
                        if self.do_classifier_free_guidance:
         | 
| 978 | 
             
                            resolution = torch.cat([resolution, resolution], dim=0)
         | 
| 979 | 
             
                            aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
         | 
| 980 |  | 
| 981 | 
             
                        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
         | 
| 982 |  | 
| 983 | 
            +
                    if self.do_classifier_free_guidance:
         | 
| 984 | 
            +
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
         | 
| 985 | 
            +
                        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
         | 
| 986 | 
            +
                        if prompt_embeds_2 is not None:
         | 
| 987 | 
            +
                            prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
         | 
| 988 | 
            +
                            prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
         | 
| 989 | 
            +
             | 
| 990 | 
            +
                    # To latents.device
         | 
| 991 | 
            +
                    prompt_embeds = prompt_embeds.to(device=device)
         | 
| 992 | 
            +
                    prompt_attention_mask = prompt_attention_mask.to(device=device)
         | 
| 993 | 
            +
                    if prompt_embeds_2 is not None:
         | 
| 994 | 
            +
                        prompt_embeds_2 = prompt_embeds_2.to(device=device)
         | 
| 995 | 
            +
                        prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
         | 
| 996 | 
            +
             | 
| 997 | 
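When CFG is active, the negative and positive embeddings are stacked along the batch axis so both branches run in a single transformer forward pass. A minimal sketch with toy shapes (not taken from the diff):

```python
import torch

batch = 1
negative_prompt_embeds = torch.randn(batch, 256, 4096)   # toy shapes
prompt_embeds = torch.randn(batch, 256, 4096)

# Same stacking as above: the first half of the batch is unconditional and the
# second half conditional, which is why the loop later calls noise_pred.chunk(2).
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
print(prompt_embeds.shape)   # torch.Size([2, 256, 4096])
```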
            +
                    # 8. Denoising loop
         | 
| 998 | 
            +
                    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         | 
| 999 | 
            +
                    self._num_timesteps = len(timesteps)
         | 
| 1000 | 
             
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 1001 | 
             
                        for i, t in enumerate(timesteps):
         | 
| 1002 | 
            +
                            if self.interrupt:
         | 
| 1003 | 
            +
                                continue
         | 
| 1004 | 
            +
             | 
| 1005 | 
            +
                            # expand the latents if we are doing classifier free guidance
         | 
| 1006 | 
            +
                            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
         | 
| 1007 | 
            +
                            if hasattr(self.scheduler, "scale_model_input"):
         | 
| 1008 | 
            +
                                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         | 
| 1009 | 
            +
             | 
| 1010 | 
            +
                            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
         | 
| 1011 | 
            +
                            t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
         | 
| 1012 | 
            +
                                dtype=latent_model_input.dtype
         | 
| 1013 | 
            +
                            )
         | 
| 1014 | 
            +
             | 
| 1015 | 
            +
                            # predict the noise residual
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 1016 | 
             
                            noise_pred = self.transformer(
         | 
| 1017 | 
             
                                latent_model_input,
         | 
| 1018 | 
            +
                                t_expand,
         | 
| 1019 | 
             
                                encoder_hidden_states=prompt_embeds,
         | 
| 1020 | 
            +
                                text_embedding_mask=prompt_attention_mask,
         | 
| 1021 | 
            +
                                encoder_hidden_states_t5=prompt_embeds_2,
         | 
| 1022 | 
            +
                                text_embedding_mask_t5=prompt_attention_mask_2,
         | 
| 1023 | 
            +
                                image_meta_size=add_time_ids,
         | 
| 1024 | 
            +
                                style=style,
         | 
| 1025 | 
            +
                                image_rotary_emb=image_rotary_emb,
         | 
| 1026 | 
             
                                added_cond_kwargs=added_cond_kwargs,
         | 
| 1027 | 
             
                                return_dict=False,
         | 
| 1028 | 
             
                            )[0]
         | 
| 1029 | 
            +
                            
         | 
| 1030 | 
            +
                            if noise_pred.size()[1] != self.vae.config.latent_channels:
         | 
| 1031 | 
            +
                                noise_pred, _ = noise_pred.chunk(2, dim=1)
         | 
| 1032 |  | 
| 1033 | 
             
                            # perform guidance
         | 
| 1034 | 
            +
                            if self.do_classifier_free_guidance:
         | 
| 1035 | 
             
                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 1036 | 
             
                                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 1037 |  | 
| 1038 | 
            +
                            if self.do_classifier_free_guidance and guidance_rescale > 0.0:
         | 
| 1039 | 
            +
                                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
         | 
| 1040 | 
            +
                                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
         | 
|  | |
|  | |
| 1041 |  | 
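For reference, a sketch of the standard-deviation-matching rescale the loop invokes here. It mirrors the common diffusers-style `rescale_noise_cfg`; treat it as an illustration under that assumption, not the file's exact definition:

```python
import torch

def rescale_noise_cfg_sketch(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the CFG result so its per-sample std matches the text branch
    # (Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf), then blend with the original.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```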
+                 # compute the previous noisy sample x_t -> x_t-1
                  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                     prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
+                     negative_prompt_embeds_2 = callback_outputs.pop(
+                         "negative_prompt_embeds_2", negative_prompt_embeds_2
+                     )
+
                  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()

                  if comfyui_progressbar:
                      pbar.update(1)

          # Post-processing
          video = self.decode_latents(latents)

...
          if output_type == "latent":
              video = torch.from_numpy(video)

+         # Offload all models
+         self.maybe_free_model_hooks()
+
          if not return_dict:
              return video

+         return EasyAnimatePipelineOutput(frames=video)
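To make the updated `__call__` surface concrete, here is a hedged usage sketch. The class name, loader, checkpoint path, and prompt are assumptions following the repository's conventions, not lines from this commit; adjust them to the actual entry points.

```python
# Hypothetical usage sketch for the updated text-to-video pipeline.
# `EasyAnimatePipeline` and the checkpoint path below are assumed names.
import torch
from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline

pipe = EasyAnimatePipeline.from_pretrained(
    "path/to/EasyAnimate-checkpoint",   # assumed local checkpoint directory
    torch_dtype=torch.bfloat16,
).to("cuda")

output = pipe(
    prompt="A panda drinking tea in a bamboo forest",  # example prompt
    video_length=49,
    height=480,
    width=720,
    num_inference_steps=50,
    guidance_scale=5.0,
)
video = output.frames   # EasyAnimatePipelineOutput.frames, as returned above
```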
    	
easyanimate/pipeline/{pipeline_easyanimate_multi_text_encoder_control.py → pipeline_easyanimate_control.py} RENAMED
    | @@ -31,7 +31,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline | |
| 31 | 
             
            from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
         | 
| 32 | 
             
            from diffusers.pipelines.stable_diffusion.safety_checker import \
         | 
| 33 | 
             
                StableDiffusionSafetyChecker
         | 
| 34 | 
            -
            from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler
         | 
|  | |
| 35 | 
             
            from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
         | 
| 36 | 
             
                                         is_bs4_available, is_ftfy_available,
         | 
| 37 | 
             
                                         is_torch_xla_available, logging,
         | 
| @@ -41,11 +42,12 @@ from einops import rearrange | |
| 41 | 
             
            from PIL import Image
         | 
| 42 | 
             
            from tqdm import tqdm
         | 
| 43 | 
             
            from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
         | 
| 44 | 
            -
                                      CLIPVisionModelWithProjection,
         | 
| 45 | 
            -
                                      T5EncoderModel, | 
|  | |
| 46 |  | 
| 47 | 
             
            from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
         | 
| 48 | 
            -
            from . | 
| 49 |  | 
| 50 | 
             
            if is_torch_xla_available():
         | 
| 51 | 
             
                import torch_xla.core.xla_model as xm
         | 
| @@ -64,6 +66,7 @@ EXAMPLE_DOC_STRING = """ | |
| 64 | 
             
                    ```
         | 
| 65 | 
             
            """
         | 
| 66 |  | 
|  | |
| 67 | 
             
            def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
         | 
| 68 | 
             
                tw = tgt_width
         | 
| 69 | 
             
                th = tgt_height
         | 
| @@ -97,44 +100,140 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): | |
| 97 | 
             
                return noise_cfg
         | 
| 98 |  | 
| 99 |  | 
| 100 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 101 | 
             
                r"""
         | 
| 102 | 
             
                Pipeline for text-to-video generation using EasyAnimate.
         | 
| 103 |  | 
| 104 | 
             
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 105 | 
             
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 106 |  | 
|  | |
| 107 | 
             
                EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
         | 
| 108 | 
            -
                HunyuanDiT team)
         | 
| 109 |  | 
| 110 | 
             
                Args:
         | 
| 111 | 
             
                    vae ([`AutoencoderKLMagvit`]):
         | 
| 112 | 
             
                        Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. 
         | 
| 113 | 
            -
                    text_encoder (Optional[`~transformers. | 
| 114 | 
            -
                         | 
| 115 | 
            -
                        EasyAnimate uses  | 
| 116 | 
            -
                    tokenizer (Optional[`~transformers. | 
| 117 | 
            -
                        A ` | 
| 118 | 
             
                    transformer ([`EasyAnimateTransformer3DModel`]):
         | 
| 119 | 
            -
                        The EasyAnimate model designed by  | 
| 120 | 
             
                    text_encoder_2 (`T5EncoderModel`):
         | 
| 121 | 
            -
                         | 
|  | |
| 122 | 
             
                    tokenizer_2 (`T5Tokenizer`):
         | 
| 123 | 
             
                        The tokenizer for the mT5 embedder.
         | 
| 124 | 
            -
                    scheduler ([` | 
| 125 | 
             
                        A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
         | 
| 126 | 
             
                """
         | 
| 127 |  | 
| 128 | 
             
                model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
         | 
| 129 | 
             
                _optional_components = [
         | 
| 130 | 
            -
                    "safety_checker",
         | 
| 131 | 
            -
                    "feature_extractor",
         | 
| 132 | 
             
                    "text_encoder_2",
         | 
| 133 | 
             
                    "tokenizer_2",
         | 
| 134 | 
             
                    "text_encoder",
         | 
| 135 | 
             
                    "tokenizer",
         | 
| 136 | 
             
                ]
         | 
| 137 | 
            -
                _exclude_from_cpu_offload = ["safety_checker"]
         | 
| 138 | 
             
                _callback_tensor_inputs = [
         | 
| 139 | 
             
                    "latents",
         | 
| 140 | 
             
                    "prompt_embeds",
         | 
| @@ -146,53 +245,30 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): | |
| 146 | 
             
                def __init__(
         | 
| 147 | 
             
        self,
        vae: AutoencoderKLMagvit,
-       text_encoder: BertModel,
-       tokenizer: BertTokenizer,
-       text_encoder_2: T5EncoderModel,
-       tokenizer_2: T5Tokenizer,
        transformer: EasyAnimateTransformer3DModel,
-       scheduler: …
-       safety_checker: StableDiffusionSafetyChecker,
-       feature_extractor: CLIPImageProcessor,
-       requires_safety_checker: bool = True
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            transformer=transformer,
            scheduler=scheduler,
-           safety_checker=safety_checker,
-           feature_extractor=feature_extractor,
-           text_encoder_2=text_encoder_2
        )

-       if safety_checker is None and requires_safety_checker:
-           logger.warning(
-               f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-               " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-               " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-               " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-               " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-               " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-           )
-
-       if safety_checker is not None and feature_extractor is None:
-           raise ValueError(
-               "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-               " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
-           )
-
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
        )
-       self.enable_autocast_float8_transformer_flag = False
-       self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, *args, **kwargs):
        super().enable_sequential_cpu_offload(*args, **kwargs)
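The constructor now registers only what V5.1 actually uses: one text encoder/tokenizer pair, the transformer, the VAE and a flow-matching scheduler; the Stable Diffusion safety checker, feature extractor and float8-autocast flag are removed. A hedged sketch of wiring the renamed pipeline up by hand — the checkpoint path, subfolder names and import path are assumptions based on the usual diffusers layout, not read from this diff:

```python
# Minimal sketch, assuming a V5.1 control checkpoint stored in diffusers-style subfolders.
from transformers import Qwen2Tokenizer, Qwen2VLForConditionalGeneration
from diffusers import FlowMatchEulerDiscreteScheduler

from easyanimate.models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from easyanimate.pipeline.pipeline_easyanimate_control import EasyAnimateControlPipeline

model_path = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-Control"  # hypothetical path

pipeline = EasyAnimateControlPipeline(
    vae=AutoencoderKLMagvit.from_pretrained(model_path, subfolder="vae"),
    text_encoder=Qwen2VLForConditionalGeneration.from_pretrained(model_path, subfolder="text_encoder"),
    tokenizer=Qwen2Tokenizer.from_pretrained(model_path, subfolder="tokenizer"),
    text_encoder_2=None,   # V5.1 drops the second (mT5) text encoder
    tokenizer_2=None,
    transformer=EasyAnimateTransformer3DModel.from_pretrained(model_path, subfolder="transformer"),
    scheduler=FlowMatchEulerDiscreteScheduler.from_pretrained(model_path, subfolder="scheduler"),
)
```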
@@ -272,19 +348,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
-           text_inputs = tokenizer(
-               prompt,
-               padding="max_length",
-               max_length=max_length,
-               truncation=True,
-               return_attention_mask=True,
-               return_tensors="pt",
-           )
-           text_input_ids = text_inputs.input_ids
-           if text_input_ids.shape[-1] > actual_max_sequence_length:
-               reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                text_inputs = tokenizer(
-                   reprompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
@@ -292,91 +358,188 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
-           … (old lines 295–311 removed; content truncated in this rendering)
        else:
-           … (old lines 313–314 removed; content truncated in this rendering)
            )
-       prompt_embeds = prompt_embeds[0]
-       prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
-           … (old lines 328–348 removed; content truncated in this rendering)
-               uncond_tokens,
-               padding="max_length",
-               max_length=max_length,
-               truncation=True,
-               return_tensors="pt",
-           )
-           uncond_input_ids = uncond_input.input_ids
-           if uncond_input_ids.shape[-1] > actual_max_sequence_length:
-               reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                uncond_input = tokenizer(
-                   reuncond_tokens,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
-                   return_attention_mask=True,
                    return_tensors="pt",
                )
                uncond_input_ids = uncond_input.input_ids

-           … (old lines 368–373 removed; content truncated in this rendering)
            else:
-               … (old lines 375–376 removed; content truncated in this rendering)
                )
-           … (old lines 378–379 removed; content truncated in this rendering)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -386,24 +549,10 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-           …
        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask

-   # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
-   def run_safety_checker(self, image, device, dtype):
-       if self.safety_checker is None:
-           has_nsfw_concept = None
-       else:
-           if torch.is_tensor(image):
-               feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-           else:
-               feature_extractor_input = self.image_processor.numpy_to_pil(image)
-           safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-           image, has_nsfw_concept = self.safety_checker(
-               images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-           )
-       return image, has_nsfw_concept
-
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
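The safety checker goes away, but `prepare_extra_step_kwargs` stays: it only forwards `eta` and `generator` to schedulers whose `step()` signature accepts them (DDIM takes both; the flow-matching scheduler has no `eta`). A minimal standalone sketch of that pattern, mirroring the diffusers helper the comment points to rather than quoting this file's exact body:

```python
import inspect

def prepare_extra_step_kwargs(scheduler, generator, eta):
    # eta is only meaningful for DDIM-style schedulers; other schedulers ignore it.
    extra_step_kwargs = {}
    step_params = set(inspect.signature(scheduler.step).parameters.keys())
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    # not every scheduler accepts a generator either
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
```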
@@ -438,8 +587,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    ):
-       if height % …
-           raise ValueError(f"`height` and `width` have to be divisible by …

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -524,43 +673,44 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
-       …
        return latents

    def prepare_control_latents(
-       self, …
    ):
-       # resize the …
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision

-       if …
-           …
            bs = 1
-           …
-           for i in range(0, …
-           … (old lines 542–548 removed; content truncated in this rendering)
-       if …
-           …
            bs = 1
-           …
-           for i in range(0, …
-           … (old lines 554–559 removed; content truncated in this rendering)
        else:
-           …

-       return …

    def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
        if video.size()[2] <= mini_batch_encoder:
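Most of the removed `prepare_control_latents` body is truncated in this rendering; the surviving fragments (`bs = 1`, `for i in range(0, …`) suggest it encoded the control video through the VAE in small mini-batches. A hedged, self-contained sketch of that general pattern — the helper name, the `.latent_dist` access and the `scaling_factor` use are assumptions about a diffusers-style VAE API, not code recovered from the diff:

```python
import torch

@torch.no_grad()
def encode_video_in_minibatches(vae, video: torch.Tensor, bs: int = 1) -> torch.Tensor:
    """Encode a (B, C, F, H, W) pixel-space video into VAE latents, `bs` samples at a time."""
    latents = []
    for i in range(0, video.shape[0], bs):
        chunk = video[i : i + bs].to(device=vae.device, dtype=vae.dtype)
        # assumes the Magvit VAE mirrors the diffusers AutoencoderKL encode() API
        posterior = vae.encode(chunk).latent_dist
        latents.append(posterior.sample() * vae.config.scaling_factor)
    return torch.cat(latents, dim=0)
```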
@@ -623,9 +773,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
    def interrupt(self):
        return self._interrupt

-   def enable_autocast_float8_transformer(self):
-       self.enable_autocast_float8_transformer_flag = True
-
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -635,6 +782,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        height: Optional[int] = None,
        width: Optional[int] = None,
        control_video: Union[torch.FloatTensor] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -661,6 +810,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        target_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        comfyui_progressbar: bool = False,
    ):
        r"""
        Generates images or video using the EasyAnimate pipeline based on the provided prompts.
@@ -765,6 +915,12 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        (
@@ -775,7 +931,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
-           dtype=…
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
@@ -785,28 +941,36 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            text_encoder_index=0,
        )
-       … (old lines 788–806 removed; content truncated in this rendering)

        # 4. Prepare timesteps
-       self.scheduler…
        timesteps = self.scheduler.timesteps
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
@@ -820,7 +984,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
            video_length,
            height,
            width,
-           …
            device,
            generator,
            latents,
@@ -828,27 +992,69 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        if comfyui_progressbar:
            pbar.update(1)

-       if …
            video_length = control_video.shape[2]
            control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
            control_video = control_video.to(dtype=torch.float32)
            control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
        else:
-           … (old lines 837–851 removed; content truncated in this rendering)

        if comfyui_progressbar:
            pbar.update(1)
@@ -880,34 +1086,49 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        )

        # Get other hunyuan params
-       style = torch.tensor([0], device=device)
-
        target_size = target_size or (height, width)
        add_time_ids = list(original_size + target_size + crops_coords_top_left)
-       add_time_ids = torch.tensor([add_time_ids], dtype=…

        if self.do_classifier_free_guidance:
-           prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-           prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
-           prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
-           prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
            style = torch.cat([style] * 2, dim=0)

        # To latents.device
-       …
-       prompt_attention_mask = prompt_attention_mask.to(device=device)
-       prompt_embeds_2 = prompt_embeds_2.to(device=device)
-       prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
-       add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
            batch_size * num_images_per_prompt, 1
        )
        style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

-       … (old lines 907–910 removed; content truncated in this rendering)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
@@ -918,7 +1139,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-               …

                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
@@ -935,8 +1157,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
                    image_meta_size=add_time_ids,
                    style=style,
                    image_rotary_emb=image_rotary_emb,
-                   …
                    control_latents=control_latents,
                )[0]
                if noise_pred.size()[1] != self.vae.config.latent_channels:
                    noise_pred, _ = noise_pred.chunk(2, dim=1)
@@ -976,10 +1199,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
                if comfyui_progressbar:
                    pbar.update(1)

-       if self.enable_autocast_float8_transformer_flag:
-           self.transformer = self.transformer.to("cpu", origin_weight_dtype)
-
-       torch.cuda.empty_cache()
        # Post-processing
        video = self.decode_latents(latents)

@@ -993,4 +1212,4 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline):
        if not return_dict:
            return video

-       return EasyAnimatePipelineOutput(…
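Taken together, these removals strip the HunyuanDiT-era extras (second text-encoder pass, style/meta-size tensors, the float8-autocast round trip and the safety checker) out of the control `__call__`. A hedged sketch of driving the pipeline end to end; the frame count, value range of the control signal, the `video_length` keyword and the output field name are assumptions inferred from the surrounding context rather than code from this commit:

```python
import torch

# Usage sketch: `pipeline` is the EasyAnimateControlPipeline assembled earlier.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = pipeline.to(device)

video_length, height, width = 49, 512, 512
# (batch, channels, frames, height, width) control signal in [0, 1], e.g. a pose or depth render
control_video = torch.rand(1, 3, video_length, height, width)

with torch.no_grad():
    result = pipeline(
        prompt="a robot dancing in the rain, cinematic lighting",
        negative_prompt="blurry, low quality",
        video_length=video_length,      # assumed keyword
        height=height,
        width=width,
        control_video=control_video,
        num_inference_steps=50,
        guidance_scale=6.0,
        generator=torch.Generator(device=device).manual_seed(43),
    )
frames = result.frames  # assumed output field of EasyAnimatePipelineOutput
```

The remainder of the view shows the rewritten file itself, with added lines marked `+`.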
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.pipelines.stable_diffusion.safety_checker import \
     StableDiffusionSafetyChecker
+from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
+                                  FlowMatchEulerDiscreteScheduler)
 from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
                              is_bs4_available, is_ftfy_available,
                              is_torch_xla_available, logging,
 … (unchanged lines 39–41 collapsed in the diff view)
 from PIL import Image
 from tqdm import tqdm
 from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
+                          CLIPVisionModelWithProjection, Qwen2Tokenizer,
+                          Qwen2VLForConditionalGeneration, T5EncoderModel,
+                          T5Tokenizer)

 from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
+from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput

 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 … (unchanged lines 54–65, the EXAMPLE_DOC_STRING body, collapsed in the diff view)
     ```
 """

+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
 def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
     tw = tgt_width
     th = tgt_height
 … (unchanged lines 73–99 collapsed in the diff view)
     return noise_cfg


+# Resize mask information in magvit
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps

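`retrieve_timesteps` lets the caller override the scheduler's spacing with explicit `timesteps` or `sigmas`, and otherwise falls back to a plain `set_timesteps(num_inference_steps)`. A small sketch of the call patterns against a `FlowMatchEulerDiscreteScheduler`, using the helper defined above; the linear sigma ramp is purely illustrative, and whether the installed diffusers version accepts custom sigmas is exactly what the helper's `inspect` check guards against:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# 1) Default: let the scheduler space the steps itself.
timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# 2) Custom sigmas: a simple decreasing ramp (illustrative only).
sigmas = torch.linspace(1.0, 0.0, 31)[:-1].tolist()
timesteps, n = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)

# 3) Passing both `timesteps` and `sigmas` raises a ValueError.
```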
+class EasyAnimateControlPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-video generation using EasyAnimate.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

+    EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
     EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
+    HunyuanDiT team) in V5.

     Args:
         vae ([`AutoencoderKLMagvit`]):
             Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
+        text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
+            EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
+            EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5.
+        tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
+            A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
         transformer ([`EasyAnimateTransformer3DModel`]):
+            The EasyAnimate model designed by EasyAnimate Team.
         text_encoder_2 (`T5EncoderModel`):
+            EasyAnimate does not use text_encoder_2 in V5.1.
+            EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5.
         tokenizer_2 (`T5Tokenizer`):
             The tokenizer for the mT5 embedder.
+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
             A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
     _optional_components = [
         "text_encoder_2",
         "tokenizer_2",
         "text_encoder",
         "tokenizer",
     ]
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
 … (unchanged lines 240–244 collapsed in the diff view)
     def __init__(
         self,
         vae: AutoencoderKLMagvit,
+        text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
+        tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
+        text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]],
+        tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]],
         transformer: EasyAnimateTransformer3DModel,
+        scheduler: FlowMatchEulerDiscreteScheduler,
     ):
         super().__init__()

         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
             tokenizer=tokenizer,
             tokenizer_2=tokenizer_2,
             transformer=transformer,
             scheduler=scheduler,
         )

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
         )

     def enable_sequential_cpu_offload(self, *args, **kwargs):
         super().enable_sequential_cpu_offload(*args, **kwargs)
 … (unchanged lines 275–347 collapsed in the diff view)
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
                 text_inputs = tokenizer(
+                    prompt,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
 … (unchanged line 357 collapsed in the diff view)
                     return_tensors="pt",
                 )
                 text_input_ids = text_inputs.input_ids
+                if text_input_ids.shape[-1] > actual_max_sequence_length:
+                    reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
+                    text_inputs = tokenizer(
+                        reprompt,
+                        padding="max_length",
+                        max_length=max_length,
+                        truncation=True,
+                        return_attention_mask=True,
+                        return_tensors="pt",
+                    )
+                    text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {_actual_max_sequence_length} tokens: {removed_text}"
+                    )
+
+                prompt_attention_mask = text_inputs.attention_mask.to(device)
+
+                if self.transformer.config.enable_text_attention_mask:
+                    prompt_embeds = text_encoder(
+                        text_input_ids.to(device),
+                        attention_mask=prompt_attention_mask,
+                    )
+                else:
+                    prompt_embeds = text_encoder(
+                        text_input_ids.to(device)
+                    )
+                prompt_embeds = prompt_embeds[0]
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
             else:
+                if prompt is not None and isinstance(prompt, str):
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": prompt}],
+                        }
+                    ]
+                else:
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [{"type": "text", "text": _prompt}],
+                        } for _prompt in prompt
+                    ]
+                text = tokenizer.apply_chat_template(
+                    messages, tokenize=False, add_generation_prompt=True
                 )

+                text_inputs = tokenizer(
+                    text=[text],
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_attention_mask=True,
+                    padding_side="right",
+                    return_tensors="pt",
+                )
+                text_inputs = text_inputs.to(text_encoder.device)
+
+                text_input_ids = text_inputs.input_ids
+                prompt_attention_mask = text_inputs.attention_mask
+                if self.transformer.config.enable_text_attention_mask:
+                    # Inference: Generation of the output
+                    prompt_embeds = text_encoder(
+                        input_ids=text_input_ids,
+                        attention_mask=prompt_attention_mask,
+                        output_hidden_states=True).hidden_states[-2]
+                else:
+                    raise ValueError("LLM needs attention_mask")
+                prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_attention_mask = prompt_attention_mask.to(device=device)

         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
+            if type(tokenizer) in [BertTokenizer, T5Tokenizer]:
+                uncond_tokens: List[str]
+                if negative_prompt is None:
+                    uncond_tokens = [""] * batch_size
+                elif prompt is not None and type(prompt) is not type(negative_prompt):
+                    raise TypeError(
+                        f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                        f" {type(prompt)}."
+                    )
+                elif isinstance(negative_prompt, str):
+                    uncond_tokens = [negative_prompt]
+                elif batch_size != len(negative_prompt):
+                    raise ValueError(
+                        f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                        f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                        " the batch size of `prompt`."
+                    )
+                else:
+                    uncond_tokens = negative_prompt
+
+                max_length = prompt_embeds.shape[1]
                 uncond_input = tokenizer(
+                    uncond_tokens,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
                     return_tensors="pt",
                 )
                 uncond_input_ids = uncond_input.input_ids
+                if uncond_input_ids.shape[-1] > actual_max_sequence_length:
+                    reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
+                    uncond_input = tokenizer(
+                        reuncond_tokens,
+                        padding="max_length",
+                        max_length=max_length,
+                        truncation=True,
+                        return_attention_mask=True,
+                        return_tensors="pt",
+                    )
+                    uncond_input_ids = uncond_input.input_ids

+                negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
         | 
| 491 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 492 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 493 | 
            +
                                    uncond_input.input_ids.to(device),
         | 
| 494 | 
            +
                                    attention_mask=negative_prompt_attention_mask,
         | 
| 495 | 
            +
                                )
         | 
| 496 | 
            +
                            else:
         | 
| 497 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 498 | 
            +
                                    uncond_input.input_ids.to(device)
         | 
| 499 | 
            +
                                )
         | 
| 500 | 
            +
                            negative_prompt_embeds = negative_prompt_embeds[0]
         | 
| 501 | 
            +
                            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 502 | 
             
                        else:
         | 
| 503 | 
            +
                            if negative_prompt is not None and isinstance(negative_prompt, str):
         | 
| 504 | 
            +
                                messages = [
         | 
| 505 | 
            +
                                    {
         | 
| 506 | 
            +
                                        "role": "user",
         | 
| 507 | 
            +
                                        "content": [{"type": "text", "text": negative_prompt}],
         | 
| 508 | 
            +
                                    }
         | 
| 509 | 
            +
                                ]
         | 
| 510 | 
            +
                            else:
         | 
| 511 | 
            +
                                messages = [
         | 
| 512 | 
            +
                                    {
         | 
| 513 | 
            +
                                        "role": "user",
         | 
| 514 | 
            +
                                        "content": [{"type": "text", "text": _negative_prompt}],
         | 
| 515 | 
            +
                                    } for _negative_prompt in negative_prompt
         | 
| 516 | 
            +
                                ]
         | 
| 517 | 
            +
                            text = tokenizer.apply_chat_template(
         | 
| 518 | 
            +
                                messages, tokenize=False, add_generation_prompt=True
         | 
| 519 | 
             
                            )
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                            text_inputs = tokenizer(
         | 
| 522 | 
            +
                                text=[text],
         | 
| 523 | 
            +
                                padding="max_length",
         | 
| 524 | 
            +
                                max_length=max_length,
         | 
| 525 | 
            +
                                truncation=True,
         | 
| 526 | 
            +
                                return_attention_mask=True,
         | 
| 527 | 
            +
                                padding_side="right",
         | 
| 528 | 
            +
                                return_tensors="pt",
         | 
| 529 | 
            +
                            )
         | 
| 530 | 
            +
                            text_inputs = text_inputs.to(text_encoder.device)
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                            text_input_ids = text_inputs.input_ids
         | 
| 533 | 
            +
                            negative_prompt_attention_mask = text_inputs.attention_mask
         | 
| 534 | 
            +
                            if self.transformer.config.enable_text_attention_mask:
         | 
| 535 | 
            +
                                # Inference: Generation of the output
         | 
| 536 | 
            +
                                negative_prompt_embeds = text_encoder(
         | 
| 537 | 
            +
                                    input_ids=text_input_ids,
         | 
| 538 | 
            +
                                    attention_mask=negative_prompt_attention_mask,
         | 
| 539 | 
            +
                                    output_hidden_states=True).hidden_states[-2]
         | 
| 540 | 
            +
                            else:
         | 
| 541 | 
            +
                                raise ValueError("LLM needs attention_mask")
         | 
| 542 | 
            +
                            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
         | 
| 543 |  | 
| 544 | 
             
                    if do_classifier_free_guidance:
         | 
| 545 | 
             
                        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         | 
|  | |
| 549 |  | 
| 550 | 
             
                        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 551 | 
             
                        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
         | 
| 552 | 
            +
                        negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
         | 
| 553 | 
            +
                        
         | 
| 554 | 
             
                    return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
         | 
| 555 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
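For readers skimming the diff: the new branch above encodes prompts with a chat-style LLM text encoder instead of BERT/T5. Below is a minimal, self-contained sketch of that path; the checkpoint path, max length and dtype are placeholders (assumptions), not values taken from this commit.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; EasyAnimateV5.1 ships its own text-encoder weights.
tokenizer = AutoTokenizer.from_pretrained("path/to/llm-text-encoder")
text_encoder = AutoModelForCausalLM.from_pretrained("path/to/llm-text-encoder", torch_dtype=torch.bfloat16)

messages = [{"role": "user", "content": [{"type": "text", "text": "A cat running on the grass"}]}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Assumes the tokenizer defines a pad token; mirrors the padding/truncation settings in the diff.
inputs = tokenizer([text], padding="max_length", max_length=256, truncation=True,
                   return_attention_mask=True, return_tensors="pt")
with torch.no_grad():
    outputs = text_encoder(input_ids=inputs.input_ids,
                           attention_mask=inputs.attention_mask,
                           output_hidden_states=True)
prompt_embeds = outputs.hidden_states[-2]  # penultimate hidden state, as used above
```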
@@ ... @@
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ ... @@
         negative_prompt_attention_mask_2=None,
         callback_on_step_end_tensor_inputs=None,
     ):
+        if height % 16 != 0 or width % 16 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ ... @@
             latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
+        if hasattr(self.scheduler, "init_noise_sigma"):
+            latents = latents * self.scheduler.init_noise_sigma
         return latents

     def prepare_control_latents(
+        self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
+        # resize the control to latents shape as we concatenate the control to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision

+        if control is not None:
+            control = control.to(device=device, dtype=dtype)
             bs = 1
+            new_control = []
+            for i in range(0, control.shape[0], bs):
+                control_bs = control[i : i + bs]
+                control_bs = self.vae.encode(control_bs)[0]
+                control_bs = control_bs.mode()
+                new_control.append(control_bs)
+            control = torch.cat(new_control, dim = 0)
+            control = control * self.vae.config.scaling_factor
+
+        if control_image is not None:
+            control_image = control_image.to(device=device, dtype=dtype)
             bs = 1
+            new_control_pixel_values = []
+            for i in range(0, control_image.shape[0], bs):
+                control_pixel_values_bs = control_image[i : i + bs]
+                control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
+                control_pixel_values_bs = control_pixel_values_bs.mode()
+                new_control_pixel_values.append(control_pixel_values_bs)
+            control_image_latents = torch.cat(new_control_pixel_values, dim = 0)
+            control_image_latents = control_image_latents * self.vae.config.scaling_factor
         else:
+            control_image_latents = None

+        return control, control_image_latents
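Both loops in `prepare_control_latents` follow the same mini-batched encoding pattern. A stripped-down sketch of that pattern (the `vae` argument stands in for the pipeline's AutoencoderKLMagvit; the batch size is illustrative):

```py
import torch

def encode_in_chunks(vae, pixel_values: torch.Tensor, bs: int = 1) -> torch.Tensor:
    # Push the input through the VAE in small slices to bound peak memory,
    # take the deterministic mode of the posterior instead of sampling,
    # and apply the VAE scaling factor expected by the transformer.
    latents = []
    for i in range(0, pixel_values.shape[0], bs):
        posterior = vae.encode(pixel_values[i : i + bs])[0]
        latents.append(posterior.mode())
    return torch.cat(latents, dim=0) * vae.config.scaling_factor
```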

     def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
         if video.size()[2] <= mini_batch_encoder:
@@ ... @@
     def interrupt(self):
         return self._interrupt

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ ... @@
         height: Optional[int] = None,
         width: Optional[int] = None,
         control_video: Union[torch.FloatTensor] = None,
+        control_camera_video: Union[torch.FloatTensor] = None,
+        ref_image: Union[torch.FloatTensor] = None,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 5.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ ... @@
         target_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         comfyui_progressbar: bool = False,
+        timesteps: Optional[List[int]] = None,
     ):
         r"""
         Generates images or video using the EasyAnimate pipeline based on the provided prompts.
@@ ... @@
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
+        if self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        elif self.text_encoder_2 is not None:
+            dtype = self.text_encoder_2.dtype
+        else:
+            dtype = self.transformer.dtype

         # 3. Encode input prompt
         (
@@ ... @@
         ) = self.encode_prompt(
             prompt=prompt,
             device=device,
+            dtype=dtype,
             num_images_per_prompt=num_images_per_prompt,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
@@ ... @@
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             text_encoder_index=0,
         )
+        if self.tokenizer_2 is not None:
+            (
+                prompt_embeds_2,
+                negative_prompt_embeds_2,
+                prompt_attention_mask_2,
+                negative_prompt_attention_mask_2,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                num_images_per_prompt=num_images_per_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                negative_prompt=negative_prompt,
+                prompt_embeds=prompt_embeds_2,
+                negative_prompt_embeds=negative_prompt_embeds_2,
+                prompt_attention_mask=prompt_attention_mask_2,
+                negative_prompt_attention_mask=negative_prompt_attention_mask_2,
+                text_encoder_index=1,
+            )
+        else:
+            prompt_embeds_2 = None
+            negative_prompt_embeds_2 = None
+            prompt_attention_mask_2 = None
+            negative_prompt_attention_mask_2 = None

         # 4. Prepare timesteps
+        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
+        else:
+            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         timesteps = self.scheduler.timesteps
         if comfyui_progressbar:
             from comfy.utils import ProgressBar
@@ ... @@
             video_length,
             height,
             width,
+            dtype,
             device,
             generator,
             latents,
@@ ... @@
         if comfyui_progressbar:
             pbar.update(1)

+        if control_camera_video is not None:
+            control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
+            control_video_latents = control_video_latents * 6
+            control_latents = (
+                torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+            ).to(device, dtype)
+        elif control_video is not None:
             video_length = control_video.shape[2]
             control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
             control_video = control_video.to(dtype=torch.float32)
             control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
+            control_video_latents = self.prepare_control_latents(
+                None,
+                control_video,
+                batch_size,
+                height,
+                width,
+                dtype,
+                device,
+                generator,
+                self.do_classifier_free_guidance
+            )[1]
+            control_latents = (
+                torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+            ).to(device, dtype)
         else:
+            control_video_latents = torch.zeros_like(latents).to(device, dtype)
+            control_latents = (
+                torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
+            ).to(device, dtype)
+
+        if ref_image is not None:
+            video_length = ref_image.shape[2]
+            ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width)
+            ref_image = ref_image.to(dtype=torch.float32)
+            ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length)
+
+            ref_image_latentes = self.prepare_control_latents(
+                None,
+                ref_image,
+                batch_size,
+                height,
+                width,
+                prompt_embeds.dtype,
+                device,
+                generator,
+                self.do_classifier_free_guidance
+            )[1]
+
+            ref_image_latentes_conv_in = torch.zeros_like(latents)
+            if latents.size()[2] != 1:
+                ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes
+            ref_image_latentes_conv_in = (
+                torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
+            ).to(device, dtype)
+            control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
+        else:
+            if self.transformer.config.get("add_ref_latent_in_control_model", False):
+                ref_image_latentes_conv_in = torch.zeros_like(latents)
+                ref_image_latentes_conv_in = (
+                    torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in
+                ).to(device, dtype)
+                control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1)
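Both the control video and the reference image above are flattened from (b, c, f, h, w) to (b·f, c, h, w) so that per-image preprocessing and VAE encoding can be reused, then folded back. A minimal illustration of that reshaping with dummy shapes (not the pipeline's own tensors):

```py
import torch
from einops import rearrange

video = torch.rand(1, 3, 16, 256, 256)                        # b c f h w, values in [0, 1]
frames = rearrange(video, "b c f h w -> (b f) c h w")          # frames become a batch of images
frames = frames * 2.0 - 1.0                                    # e.g. the [-1, 1] normalization an image processor applies
video = rearrange(frames, "(b f) c h w -> b c f h w", f=16)    # restore the temporal axis
print(video.shape)  # torch.Size([1, 3, 16, 256, 256])
```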

         if comfyui_progressbar:
             pbar.update(1)
@@ ... @@
         )

         # Get other hunyuan params
@@ ... @@
         target_size = target_size or (height, width)
         add_time_ids = list(original_size + target_size + crops_coords_top_left)
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        style = torch.tensor([0], device=device)

         if self.do_classifier_free_guidance:
@@ ... @@
             add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
             style = torch.cat([style] * 2, dim=0)

         # To latents.device
+        add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat(
@@ ... @@
             batch_size * num_images_per_prompt, 1
         )
         style = style.to(device=device).repeat(batch_size * num_images_per_prompt)

+        # Get other pixart params
+        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
+        if self.transformer.config.get("sample_size", 64) == 128:
+            resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
+            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
+            resolution = resolution.to(dtype=dtype, device=device)
+            aspect_ratio = aspect_ratio.to(dtype=dtype, device=device)
+
+            if self.do_classifier_free_guidance:
+                resolution = torch.cat([resolution, resolution], dim=0)
+                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
+
+            added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
+            if prompt_embeds_2 is not None:
+                prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
+                prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
+
+        # To latents.device
+        prompt_embeds = prompt_embeds.to(device=device)
+        prompt_attention_mask = prompt_attention_mask.to(device=device)
+        if prompt_embeds_2 is not None:
+            prompt_embeds_2 = prompt_embeds_2.to(device=device)
+            prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
+
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ ... @@

                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                if hasattr(self.scheduler, "scale_model_input"):
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
                 t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
@@ ... @@
                     image_meta_size=add_time_ids,
                     style=style,
                     image_rotary_emb=image_rotary_emb,
+                    added_cond_kwargs=added_cond_kwargs,
                     control_latents=control_latents,
+                    return_dict=False,
                 )[0]
                 if noise_pred.size()[1] != self.vae.config.latent_channels:
                     noise_pred, _ = noise_pred.chunk(2, dim=1)
@@ ... @@
                 if comfyui_progressbar:
                     pbar.update(1)

@@ ... @@
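The loop above runs the transformer once on a doubled batch (unconditional followed by conditional, matching the `torch.cat` ordering earlier). The recombination step is not visible in the rendered hunks; the standard classifier-free-guidance rule it relies on looks like this sketch (illustrative, not copied from this file):

```py
import torch

def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Batch is ordered [negative, positive]; split it and push the prediction
    # away from the unconditional branch by the guidance scale.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```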
         # Post-processing
         video = self.decode_latents(latents)

@@ ... @@
         if not return_dict:
             return video

+        return EasyAnimatePipelineOutput(frames=video)
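End of the rewritten control pipeline. As an illustration only (the variable `pipeline` is assumed to be an already-loaded instance of the class defined in this file, and the tensors are random stand-ins), a call exercising the new `control_video`/`ref_image` arguments could look like:

```py
import torch

control_video = torch.rand(1, 3, 49, 384, 672)   # (b, c, f, h, w), values in [0, 1]
ref_image = torch.rand(1, 3, 1, 384, 672)

result = pipeline(
    prompt="A dancer turning under warm stage light",
    height=384,
    width=672,
    control_video=control_video,
    ref_image=ref_image,
    num_inference_steps=50,
    guidance_scale=5.0,
    negative_prompt="blurry, low quality",
)
video = result.frames
```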
    	
easyanimate/pipeline/pipeline_easyanimate_inpaint.py
CHANGED
The diff for this file is too large to render. See raw diff.

easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py
DELETED
@@ -1,925 +0,0 @@
-# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
-                                         get_3d_rotary_pos_embed)
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import \
-    StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler
-from diffusers.utils import (is_torch_xla_available, logging,
-                             replace_example_docstring)
-from diffusers.utils.torch_utils import randn_tensor
-from einops import rearrange
-from tqdm import tqdm
-from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
-                          T5Tokenizer, T5EncoderModel)
-
-from .pipeline_easyanimate import EasyAnimatePipelineOutput
-from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
-
-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-else:
-    XLA_AVAILABLE = False
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-        ```py
-        >>> pass
-        ```
-"""
-
-
-def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
-    tw = tgt_width
-    th = tgt_height
-    h, w = src
-    r = h / w
-    if r > (th / tw):
-        resize_height = th
-        resize_width = int(round(th / h * w))
-    else:
-        resize_width = tw
-        resize_height = int(round(tw / w * h))
-
-    crop_top = int(round((th - resize_height) / 2.0))
-    crop_left = int(round((tw - resize_width) / 2.0))
-
-    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
-
-
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
-def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
-    """
-    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
-    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
-    # rescale the results from guidance (fixes overexposure)
-    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
-    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
-    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
-    return noise_cfg
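In equation form, the rescaling implemented in the removed helper above first matches the standard deviation of the guided prediction to that of the text-conditioned prediction and then blends with the unrescaled result by the `guidance_rescale` factor \(\phi\):

\[
\epsilon_{\text{rescaled}} = \epsilon_{\text{cfg}} \cdot \frac{\sigma(\epsilon_{\text{text}})}{\sigma(\epsilon_{\text{cfg}})},
\qquad
\epsilon_{\text{out}} = \phi\,\epsilon_{\text{rescaled}} + (1 - \phi)\,\epsilon_{\text{cfg}}
\]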
| 89 | 
            -
             | 
| 90 | 
            -
             | 
| 91 | 
            -
            class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline):
         | 
| 92 | 
            -
                r"""
         | 
| 93 | 
            -
                Pipeline for text-to-video generation using EasyAnimate.
         | 
| 94 | 
            -
             | 
| 95 | 
            -
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 96 | 
            -
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 97 | 
            -
             | 
| 98 | 
            -
                EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
         | 
| 99 | 
            -
                HunyuanDiT team)
         | 
| 100 | 
            -
             | 
| 101 | 
            -
                Args:
         | 
| 102 | 
            -
                    vae ([`AutoencoderKLMagvit`]):
         | 
| 103 | 
            -
                        Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. 
         | 
| 104 | 
            -
                    text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
         | 
| 105 | 
            -
                        Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
         | 
| 106 | 
            -
                        EasyAnimate uses a fine-tuned [bilingual CLIP].
         | 
| 107 | 
            -
                    tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
         | 
| 108 | 
            -
                        A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
         | 
| 109 | 
            -
                    transformer ([`EasyAnimateTransformer3DModel`]):
         | 
| 110 | 
            -
                        The EasyAnimate model designed by Tencent Hunyuan.
         | 
| 111 | 
            -
                    text_encoder_2 (`T5EncoderModel`):
         | 
| 112 | 
            -
                        The mT5 embedder. 
         | 
| 113 | 
            -
                    tokenizer_2 (`T5Tokenizer`):
         | 
| 114 | 
            -
                        The tokenizer for the mT5 embedder.
         | 
| 115 | 
            -
                    scheduler ([`DDIMScheduler`]):
         | 
| 116 | 
            -
                        A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
         | 
| 117 | 
            -
                """
         | 
| 118 | 
            -
             | 
| 119 | 
            -
                model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
         | 
| 120 | 
            -
                _optional_components = [
         | 
| 121 | 
            -
                    "safety_checker",
         | 
| 122 | 
            -
                    "feature_extractor",
         | 
| 123 | 
            -
                    "text_encoder_2",
         | 
| 124 | 
            -
                    "tokenizer_2",
         | 
| 125 | 
            -
                    "text_encoder",
         | 
| 126 | 
            -
                    "tokenizer",
         | 
| 127 | 
            -
                ]
         | 
| 128 | 
            -
                _exclude_from_cpu_offload = ["safety_checker"]
         | 
| 129 | 
            -
                _callback_tensor_inputs = [
         | 
| 130 | 
            -
                    "latents",
         | 
| 131 | 
            -
                    "prompt_embeds",
         | 
| 132 | 
            -
                    "negative_prompt_embeds",
         | 
| 133 | 
            -
                    "prompt_embeds_2",
         | 
| 134 | 
            -
                    "negative_prompt_embeds_2",
         | 
| 135 | 
            -
                ]
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                def __init__(
         | 
| 138 | 
            -
                    self,
         | 
| 139 | 
            -
                    vae: AutoencoderKLMagvit,
         | 
| 140 | 
            -
                    text_encoder: BertModel,
         | 
| 141 | 
            -
                    tokenizer: BertTokenizer,
         | 
| 142 | 
            -
                    text_encoder_2: T5EncoderModel,
         | 
| 143 | 
            -
                    tokenizer_2: T5Tokenizer,
         | 
| 144 | 
            -
                    transformer: EasyAnimateTransformer3DModel,
         | 
| 145 | 
            -
                    scheduler: DDIMScheduler,
         | 
| 146 | 
            -
                    safety_checker: StableDiffusionSafetyChecker,
         | 
| 147 | 
            -
                    feature_extractor: CLIPImageProcessor,
         | 
| 148 | 
            -
                    requires_safety_checker: bool = True,
         | 
| 149 | 
            -
                ):
         | 
| 150 | 
            -
                    super().__init__()
         | 
| 151 | 
            -
             | 
| 152 | 
            -
                    self.register_modules(
         | 
| 153 | 
            -
                        vae=vae,
         | 
| 154 | 
            -
                        text_encoder=text_encoder,
         | 
| 155 | 
            -
                        tokenizer=tokenizer,
         | 
| 156 | 
            -
                        tokenizer_2=tokenizer_2,
         | 
| 157 | 
            -
                        transformer=transformer,
         | 
| 158 | 
            -
                        scheduler=scheduler,
         | 
| 159 | 
            -
                        safety_checker=safety_checker,
         | 
| 160 | 
            -
                        feature_extractor=feature_extractor,
         | 
| 161 | 
            -
                        text_encoder_2=text_encoder_2,
         | 
| 162 | 
            -
                    )
         | 
| 163 | 
            -
             | 
| 164 | 
            -
                    if safety_checker is None and requires_safety_checker:
         | 
| 165 | 
            -
                        logger.warning(
         | 
| 166 | 
            -
                            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
         | 
| 167 | 
            -
                            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
         | 
| 168 | 
            -
                            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
         | 
| 169 | 
            -
                            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
         | 
| 170 | 
            -
                            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
         | 
| 171 | 
            -
                            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
         | 
| 172 | 
            -
                        )
         | 
| 173 | 
            -
             | 
| 174 | 
            -
                    if safety_checker is not None and feature_extractor is None:
         | 
| 175 | 
            -
                        raise ValueError(
         | 
| 176 | 
            -
                            "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
         | 
| 177 | 
            -
                            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
         | 
| 178 | 
            -
                        )
         | 
| 179 | 
            -
             | 
| 180 | 
            -
                    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         | 
| 181 | 
            -
                    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         | 
| 182 | 
            -
                    self.enable_autocast_float8_transformer_flag = False
         | 
| 183 | 
            -
                    self.register_to_config(requires_safety_checker=requires_safety_checker)
         | 
| 184 | 
            -
             | 
| 185 | 
            -
                def enable_sequential_cpu_offload(self, *args, **kwargs):
         | 
| 186 | 
            -
                    super().enable_sequential_cpu_offload(*args, **kwargs)
         | 
| 187 | 
            -
                    if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
         | 
| 188 | 
            -
                        import accelerate
         | 
| 189 | 
            -
                        accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
         | 
| 190 | 
            -
                        self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
         | 
| 191 | 
            -
             | 
| 192 | 
            -
                def encode_prompt(
         | 
| 193 | 
            -
                    self,
         | 
| 194 | 
            -
                    prompt: str,
         | 
| 195 | 
            -
                    device: torch.device,
         | 
| 196 | 
            -
                    dtype: torch.dtype,
         | 
| 197 | 
            -
                    num_images_per_prompt: int = 1,
        do_classifier_free_guidance: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        max_sequence_length: Optional[int] = None,
        text_encoder_index: int = 0,
        actual_max_sequence_length: int = 256
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            dtype (`torch.dtype`):
                torch dtype
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
            text_encoder_index (`int`, *optional*):
                Index of the text encoder to use. `0` for clip and `1` for T5.
        """
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        tokenizer = tokenizers[text_encoder_index]
        text_encoder = text_encoders[text_encoder_index]

        if max_sequence_length is None:
            if text_encoder_index == 0:
                max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
            if text_encoder_index == 1:
                max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
        else:
            max_length = max_sequence_length

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_attention_mask=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            if text_input_ids.shape[-1] > actual_max_sequence_length:
                reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                text_inputs = tokenizer(
                    reprompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
                removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {_actual_max_sequence_length} tokens: {removed_text}"
                )
            prompt_attention_mask = text_inputs.attention_mask.to(device)

            if self.transformer.config.enable_text_attention_mask:
                prompt_embeds = text_encoder(
                    text_input_ids.to(device),
                    attention_mask=prompt_attention_mask,
                )
            else:
                prompt_embeds = text_encoder(
                    text_input_ids.to(device)
                )
            prompt_embeds = prompt_embeds[0]
            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_input_ids = uncond_input.input_ids
            if uncond_input_ids.shape[-1] > actual_max_sequence_length:
                reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
                uncond_input = tokenizer(
                    reuncond_tokens,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                uncond_input_ids = uncond_input.input_ids

            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
            if self.transformer.config.enable_text_attention_mask:
                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    attention_mask=negative_prompt_attention_mask,
                )
            else:
                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device)
                )
            negative_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
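
For orientation, a hedged usage sketch that is not part of the diff: this is roughly how `encode_prompt` is invoked twice in `__call__` further down, once per text encoder (`text_encoder_index=0` for the CLIP-style encoder, `1` for the T5-style encoder). `pipe` stands for an already-instantiated EasyAnimate pipeline; the prompt strings and dtype are illustrative only.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
prompt = "a corgi running on the beach"
negative_prompt = "blurry, low quality"

# First text encoder (CLIP-like, short context window).
emb, neg_emb, mask, neg_mask = pipe.encode_prompt(
    prompt=prompt,
    device=device,
    dtype=torch.float16,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt=negative_prompt,
    text_encoder_index=0,
)
# Second text encoder (T5-like, longer context window).
emb_2, neg_emb_2, mask_2, neg_mask_2 = pipe.encode_prompt(
    prompt=prompt,
    device=device,
    dtype=torch.float16,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
    negative_prompt=negative_prompt,
    text_encoder_index=1,
)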

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs
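
A minimal, self-contained sketch of the signature-introspection pattern used above, assuming only the standard library: it filters a candidate kwarg dict down to what a callable actually accepts, which is how the pipeline decides whether to forward `eta` and `generator` to `scheduler.step`. The helper names below are invented for illustration.

import inspect

def accepted_kwargs(fn, candidates):
    # Keep only the keyword arguments that `fn` declares in its signature.
    params = set(inspect.signature(fn).parameters.keys())
    return {k: v for k, v in candidates.items() if k in params}

def fake_scheduler_step(model_output, timestep, sample, eta=0.0):
    # Stand-in for a scheduler.step signature that accepts `eta` but not `generator`.
    return sample

extra_step_kwargs = accepted_kwargs(fake_scheduler_step, {"eta": 0.0, "generator": None})
print(extra_step_kwargs)  # {'eta': 0.0}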

    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is None and prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
            raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
            raise ValueError(
                "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
        if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
                raise ValueError(
                    "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
                    f" {negative_prompt_embeds_2.shape}."
                )
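
A tiny illustration of the spatial constraint enforced by `check_inputs`: both dimensions must be divisible by 8, the VAE's spatial down-sampling factor. The sample sizes are arbitrary.

for h, w in [(512, 512), (480, 720), (500, 500)]:
    ok = h % 8 == 0 and w % 8 == 0
    print(f"{h}x{w}: {'accepted' if ok else 'rejected'}")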

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
            if self.vae.cache_mag_vae:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
            else:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
        else:
            shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
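
A hedged sketch of the latent-shape arithmetic in `prepare_latents` for the causal video-VAE branch, assuming `cache_mag_vae=True`, `mini_batch_encoder=4`, `mini_batch_decoder=1`, a spatial `vae_scale_factor` of 8 and 16 latent channels; the real values come from the loaded VAE and transformer configs.

def latent_shape(batch_size, channels, video_length, height, width,
                 mini_batch_encoder=4, mini_batch_decoder=1, vae_scale_factor=8):
    # Temporal compression: mini_batch_encoder pixel frames per latent frame,
    # plus the separately cached first frame when video_length > 1.
    if video_length != 1:
        frames = (video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1
    else:
        frames = 1
    return (batch_size, channels, frames, height // vae_scale_factor, width // vae_scale_factor)

print(latent_shape(1, 16, 49, 512, 512))  # (1, 16, 13, 64, 64)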

    def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
        if video.size()[2] <= mini_batch_encoder:
            return video
        prefix_index_before = mini_batch_encoder // 2
        prefix_index_after = mini_batch_encoder - prefix_index_before
        pixel_values = video[:, :, prefix_index_before:-prefix_index_after]

        # Encode middle videos
        latents = self.vae.encode(pixel_values)[0]
        latents = latents.mode()
        # Decode middle videos
        middle_video = self.vae.decode(latents)[0]

        video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
        return video

    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / self.vae.config.scaling_factor * latents
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim == 5:
            mini_batch_encoder = self.vae.mini_batch_encoder
            mini_batch_decoder = self.vae.mini_batch_decoder
            video = self.vae.decode(latents)[0]
            video = video.clamp(-1, 1)
            if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
                video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
        else:
            latents = rearrange(latents, "b c f h w -> (b f) c h w")
            video = []
            for frame_idx in tqdm(range(latents.shape[0])):
                video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
            video = torch.cat(video)
            video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video
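
A minimal sketch (assumptions only) of the normalisation around VAE decoding in `decode_latents`: latents are rescaled by `1 / scaling_factor` before decoding, and the decoded video is mapped from roughly [-1, 1] to [0, 1]. `torch.tanh` stands in for the actual `self.vae.decode` call, and the scaling factor below is a placeholder for `vae.config.scaling_factor`.

import torch

scaling_factor = 0.18215                  # placeholder value
latents = torch.randn(1, 4, 8, 64, 64)    # (batch, channels, frames, height, width)
latents = latents / scaling_factor
decoded = torch.tanh(latents)             # stand-in for self.vae.decode(latents)[0], range ~[-1, 1]
video = (decoded / 2 + 0.5).clamp(0, 1)   # same mapping as decode_latents
print(video.min().item() >= 0.0, video.max().item() <= 1.0)  # True True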

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    def enable_autocast_float8_transformer(self):
        self.enable_autocast_float8_transformer_flag = True
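
A small sketch of the switch encoded by the `do_classifier_free_guidance` property: guidance is active only when `guidance_scale > 1`, in which case the conditional and unconditional predictions are later combined in the usual classifier-free-guidance form. The tensors here are dummies, and the combination shown is the standard formulation rather than code copied from this file.

import torch

guidance_scale = 5.0
noise_pred_uncond = torch.zeros(1, 4, 8, 8)
noise_pred_text = torch.ones(1, 4, 8, 8)

if guidance_scale > 1:  # mirrors the `do_classifier_free_guidance` property
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
else:
    noise_pred = noise_pred_text
print(noise_pred.mean().item())  # 5.0 when guidance is active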

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        video_length: Optional[int] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "latent",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = (1024, 1024),
        target_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        comfyui_progressbar: bool = False,
    ):
        r"""
        Generates images or video using the EasyAnimate pipeline based on the provided prompts.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
            video_length (`int`, *optional*):
                Length of the generated video (in frames).
            height (`int`, *optional*):
                Height of the generated image in pixels.
            width (`int`, *optional*):
                Width of the generated image in pixels.
            num_inference_steps (`int`, *optional*, defaults to 50):
                Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Encourages the model to align outputs with prompts. A higher value may decrease image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images to generate for each prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A generator to ensure reproducibility in image generation.
            latents (`torch.Tensor`, *optional*):
                Predefined latent tensors to condition generation.
            prompt_embeds (`torch.Tensor`, *optional*):
                Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
            prompt_embeds_2 (`torch.Tensor`, *optional*):
                Secondary text embeddings to supplement or replace the initial prompt embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Embeddings for negative prompts. Overrides string inputs if defined.
            negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
                Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the primary prompt embeddings.
            prompt_attention_mask_2 (`torch.Tensor`, *optional*):
                Attention mask for the secondary prompt embeddings.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for negative prompt embeddings.
            negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
                Attention mask for secondary negative prompt embeddings.
            output_type (`str`, *optional*, defaults to `"latent"`):
                Format of the generated output, either as a PIL image or as a NumPy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                If `True`, returns a structured output. Otherwise returns a simple tuple.
            callback_on_step_end (`Callable`, *optional*):
                Functions called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
                Tensor names to be included in callback function calls.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Adjusts noise levels based on guidance scale.
            original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
                Original dimensions of the output.
            target_size (`Tuple[int, int]`, *optional*):
                Desired output dimensions for calculations.
            crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
                Coordinates for cropping.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
         | 
| 684 | 
            -
             | 
| 685 | 
            -
                    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
         | 
| 686 | 
            -
                        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
         | 
| 687 | 
            -
             | 
| 688 | 
            -
                    # 0. default height and width
         | 
| 689 | 
            -
                    height = int((height // 16) * 16)
         | 
| 690 | 
            -
                    width = int((width // 16) * 16)
         | 
| 691 | 
            -
             | 
| 692 | 
            -
                    # 1. Check inputs. Raise error if not correct
         | 
| 693 | 
            -
                    self.check_inputs(
         | 
| 694 | 
            -
                        prompt,
         | 
| 695 | 
            -
                        height,
         | 
| 696 | 
            -
                        width,
         | 
| 697 | 
            -
                        negative_prompt,
         | 
| 698 | 
            -
                        prompt_embeds,
         | 
| 699 | 
            -
                        negative_prompt_embeds,
         | 
| 700 | 
            -
                        prompt_attention_mask,
         | 
| 701 | 
            -
                        negative_prompt_attention_mask,
         | 
| 702 | 
            -
                        prompt_embeds_2,
         | 
| 703 | 
            -
                        negative_prompt_embeds_2,
         | 
| 704 | 
            -
                        prompt_attention_mask_2,
         | 
| 705 | 
            -
                        negative_prompt_attention_mask_2,
         | 
| 706 | 
            -
                        callback_on_step_end_tensor_inputs,
         | 
| 707 | 
            -
                    )
         | 
| 708 | 
            -
                    self._guidance_scale = guidance_scale
         | 
| 709 | 
            -
                    self._guidance_rescale = guidance_rescale
         | 
| 710 | 
            -
                    self._interrupt = False
         | 
| 711 | 
            -
             | 
| 712 | 
            -
                    # 2. Define call parameters
         | 
| 713 | 
            -
                    if prompt is not None and isinstance(prompt, str):
         | 
| 714 | 
            -
                        batch_size = 1
         | 
| 715 | 
            -
                    elif prompt is not None and isinstance(prompt, list):
         | 
| 716 | 
            -
                        batch_size = len(prompt)
         | 
| 717 | 
            -
                    else:
         | 
| 718 | 
            -
                        batch_size = prompt_embeds.shape[0]
         | 
| 719 | 
            -
             | 
| 720 | 
            -
                    device = self._execution_device
         | 
| 721 | 
            -
             | 
| 722 | 
            -
                    # 3. Encode input prompt
         | 
| 723 | 
            -
                    (
         | 
| 724 | 
            -
                        prompt_embeds,
         | 
| 725 | 
            -
                        negative_prompt_embeds,
         | 
| 726 | 
            -
                        prompt_attention_mask,
         | 
| 727 | 
            -
                        negative_prompt_attention_mask,
         | 
| 728 | 
            -
                    ) = self.encode_prompt(
         | 
| 729 | 
            -
                        prompt=prompt,
         | 
| 730 | 
            -
                        device=device,
         | 
| 731 | 
            -
                        dtype=self.transformer.dtype,
         | 
| 732 | 
            -
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 733 | 
            -
                        do_classifier_free_guidance=self.do_classifier_free_guidance,
         | 
| 734 | 
            -
                        negative_prompt=negative_prompt,
         | 
| 735 | 
            -
                        prompt_embeds=prompt_embeds,
         | 
| 736 | 
            -
                        negative_prompt_embeds=negative_prompt_embeds,
         | 
| 737 | 
            -
                        prompt_attention_mask=prompt_attention_mask,
         | 
| 738 | 
            -
                        negative_prompt_attention_mask=negative_prompt_attention_mask,
         | 
| 739 | 
            -
                        text_encoder_index=0,
         | 
| 740 | 
            -
                    )
         | 
| 741 | 
            -
                    (
         | 
| 742 | 
            -
                        prompt_embeds_2,
         | 
| 743 | 
            -
                        negative_prompt_embeds_2,
         | 
| 744 | 
            -
                        prompt_attention_mask_2,
         | 
| 745 | 
            -
                        negative_prompt_attention_mask_2,
         | 
| 746 | 
            -
                    ) = self.encode_prompt(
         | 
| 747 | 
            -
                        prompt=prompt,
         | 
| 748 | 
            -
                        device=device,
         | 
| 749 | 
            -
                        dtype=self.transformer.dtype,
         | 
| 750 | 
            -
                        num_images_per_prompt=num_images_per_prompt,
         | 
| 751 | 
            -
                        do_classifier_free_guidance=self.do_classifier_free_guidance,
         | 
| 752 | 
            -
                        negative_prompt=negative_prompt,
         | 
| 753 | 
            -
                        prompt_embeds=prompt_embeds_2,
         | 
| 754 | 
            -
                        negative_prompt_embeds=negative_prompt_embeds_2,
         | 
| 755 | 
            -
                        prompt_attention_mask=prompt_attention_mask_2,
         | 
| 756 | 
            -
                        negative_prompt_attention_mask=negative_prompt_attention_mask_2,
         | 
| 757 | 
            -
                        text_encoder_index=1,
         | 
| 758 | 
            -
                    )
         | 
| 759 | 
            -
                    torch.cuda.empty_cache()
         | 
| 760 | 
            -
             | 
| 761 | 
            -
                    # 4. Prepare timesteps
         | 
| 762 | 
            -
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         | 
| 763 | 
            -
                    timesteps = self.scheduler.timesteps
         | 
| 764 | 
            -
                    if comfyui_progressbar:
         | 
| 765 | 
            -
                        from comfy.utils import ProgressBar
         | 
| 766 | 
            -
                        pbar = ProgressBar(num_inference_steps + 1)
         | 
| 767 | 
            -
             | 
| 768 | 
            -
                    # 5. Prepare latent variables
         | 
| 769 | 
            -
                    num_channels_latents = self.transformer.config.in_channels
         | 
| 770 | 
            -
                    latents = self.prepare_latents(
         | 
| 771 | 
            -
                        batch_size * num_images_per_prompt,
         | 
| 772 | 
            -
                        num_channels_latents,
         | 
| 773 | 
            -
                        video_length,
         | 
| 774 | 
            -
                        height,
         | 
| 775 | 
            -
                        width,
         | 
| 776 | 
            -
                        prompt_embeds.dtype,
         | 
| 777 | 
            -
                        device,
         | 
| 778 | 
            -
                        generator,
         | 
| 779 | 
            -
                        latents,
         | 
| 780 | 
            -
                    )
         | 
| 781 | 
            -
                    if comfyui_progressbar:
         | 
| 782 | 
            -
                        pbar.update(1)
         | 
| 783 | 
            -
             | 
| 784 | 
            -
                    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         | 
| 785 | 
            -
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         | 
| 786 | 
            -
             | 
| 787 | 
            -
                    # 7 create image_rotary_emb, style embedding & time ids
         | 
| 788 | 
            -
                    grid_height = height // 8 // self.transformer.config.patch_size
         | 
| 789 | 
            -
                    grid_width = width // 8 // self.transformer.config.patch_size
         | 
| 790 | 
            -
                    if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
         | 
| 791 | 
            -
                        base_size_width = 720 // 8 // self.transformer.config.patch_size
         | 
| 792 | 
            -
                        base_size_height = 480 // 8 // self.transformer.config.patch_size
         | 
| 793 | 
            -
             | 
| 794 | 
            -
                        grid_crops_coords = get_resize_crop_region_for_grid(
         | 
| 795 | 
            -
                            (grid_height, grid_width), base_size_width, base_size_height
         | 
| 796 | 
            -
                        )
         | 
| 797 | 
            -
                        image_rotary_emb = get_3d_rotary_pos_embed(
         | 
| 798 | 
            -
                            self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
         | 
| 799 | 
            -
                            temporal_size=latents.size(2), use_real=True,
         | 
| 800 | 
            -
                        )
         | 
| 801 | 
            -
                    else:
         | 
| 802 | 
            -
                        base_size = 512 // 8 // self.transformer.config.patch_size
         | 
| 803 | 
            -
                        grid_crops_coords = get_resize_crop_region_for_grid(
         | 
| 804 | 
            -
                            (grid_height, grid_width), base_size, base_size
         | 
| 805 | 
            -
                        )
         | 
| 806 | 
            -
                        image_rotary_emb = get_2d_rotary_pos_embed(
         | 
| 807 | 
            -
                            self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
         | 
| 808 | 
            -
                        )
         | 
| 809 | 
            -
             | 
-        # Get other hunyuan params
-        style = torch.tensor([0], device=device)
-
-        target_size = target_size or (height, width)
-        add_time_ids = list(original_size + target_size + crops_coords_top_left)
-        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
-
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
-            prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
-            prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
-            add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
-            style = torch.cat([style] * 2, dim=0)
-
-        # To latents.device
-        prompt_embeds = prompt_embeds.to(device=device)
-        prompt_attention_mask = prompt_attention_mask.to(device=device)
-        prompt_embeds_2 = prompt_embeds_2.to(device=device)
-        prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
-        add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
-            batch_size * num_images_per_prompt, 1
-        )
-        style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
-
-        torch.cuda.empty_cache()
-        if self.enable_autocast_float8_transformer_flag:
-            origin_weight_dtype = self.transformer.dtype
-            self.transformer = self.transformer.to(torch.float8_e4m3fn)
-        # 8. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-        self._num_timesteps = len(timesteps)
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-                # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
-                t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
-                    dtype=latent_model_input.dtype
-                )
-
-                # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    text_embedding_mask=prompt_attention_mask,
-                    encoder_hidden_states_t5=prompt_embeds_2,
-                    text_embedding_mask_t5=prompt_attention_mask_2,
-                    image_meta_size=add_time_ids,
-                    style=style,
-                    image_rotary_emb=image_rotary_emb,
-                    return_dict=False,
-                )[0]
-
-                if noise_pred.size()[1] != self.vae.config.latent_channels:
-                    noise_pred, _ = noise_pred.chunk(2, dim=1)
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
-                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
-                    negative_prompt_embeds_2 = callback_outputs.pop(
-                        "negative_prompt_embeds_2", negative_prompt_embeds_2
-                    )
-
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-                if comfyui_progressbar:
-                    pbar.update(1)
-
-        if self.enable_autocast_float8_transformer_flag:
-            self.transformer = self.transformer.to("cpu", origin_weight_dtype)
-
-        torch.cuda.empty_cache()
-        # Post-processing
-        video = self.decode_latents(latents)
-
-        # Convert to tensor
-        if output_type == "latent":
-            video = torch.from_numpy(video)
-
-        # Offload all models
-        self.maybe_free_model_hooks()
-
-        if not return_dict:
-            return video
-
-        return EasyAnimatePipelineOutput(videos=video)
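Note (illustration, not part of the diff): the deleted loop above uses the standard classifier-free-guidance bookkeeping of batching the unconditional and conditional branches together and splitting the prediction afterwards. A minimal, self-contained sketch with made-up tensor shapes:

    import torch

    # Illustrative only: CFG batching/splitting as in the loop above, with hypothetical shapes.
    guidance_scale = 7.5
    latents = torch.randn(1, 4, 16, 30, 45)            # (batch, channels, frames, h, w), made up
    latent_model_input = torch.cat([latents] * 2)      # duplicate for the uncond/cond branches

    noise_pred = torch.randn_like(latent_model_input)  # stand-in for the transformer output
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    print(noise_pred.shape)                            # torch.Size([1, 4, 16, 30, 45])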
easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py
DELETED
@@ -1,1334 +0,0 @@
-# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import inspect
-from typing import Callable, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from diffusers import DiffusionPipeline
-from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
-from diffusers.image_processor import VaeImageProcessor
-from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
-from diffusers.models.embeddings import (get_2d_rotary_pos_embed,
-                                         get_3d_rotary_pos_embed)
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.safety_checker import \
-    StableDiffusionSafetyChecker
-from diffusers.schedulers import DDIMScheduler
-from diffusers.utils import (is_torch_xla_available, logging,
-                             replace_example_docstring)
-from diffusers.utils.torch_utils import randn_tensor
-from einops import rearrange
-from PIL import Image
-from tqdm import tqdm
-from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
-                          CLIPVisionModelWithProjection, T5Tokenizer,
-                          T5EncoderModel)
-
-from .pipeline_easyanimate import EasyAnimatePipelineOutput
-from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
-
-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-else:
-    XLA_AVAILABLE = False
-
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-        ```py
-        >>> pass
-        ```
-"""
-
-
-def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
-    tw = tgt_width
-    th = tgt_height
-    h, w = src
-    r = h / w
-    if r > (th / tw):
-        resize_height = th
-        resize_width = int(round(th / h * w))
-    else:
-        resize_width = tw
-        resize_height = int(round(tw / w * h))
-
-    crop_top = int(round((th - resize_height) / 2.0))
-    crop_left = int(round((tw - resize_width) / 2.0))
-
-    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
-
-
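Note (illustration, not part of the diff): the helper deleted above fits a latent grid into a base grid while preserving aspect ratio and centers it, returning the crop corners used for the rotary position embedding. A worked sketch, assuming a hypothetical 40x40 latent grid measured against a 45x30 base grid:

    # Illustrative only: re-deriving get_resize_crop_region_for_grid((40, 40), 45, 30) inline.
    src, tgt_width, tgt_height = (40, 40), 45, 30
    h, w = src
    if h / w > tgt_height / tgt_width:
        resize_height, resize_width = tgt_height, int(round(tgt_height / h * w))  # 30, 30
    else:
        resize_width, resize_height = tgt_width, int(round(tgt_width / w * h))
    crop_top = int(round((tgt_height - resize_height) / 2.0))                     # 0
    crop_left = int(round((tgt_width - resize_width) / 2.0))                      # 8
    print((crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width))
    # ((0, 8), (30, 38)) -> the grid is centered inside the base grid before RoPE cropping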
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
-def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
-    """
-    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
-    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
-    # rescale the results from guidance (fixes overexposure)
-    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
-    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
-    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
-    return noise_cfg
-
-
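Note (illustration, not part of the diff): the guidance-rescale formula above pulls the per-sample standard deviation of the CFG output back toward that of the text-conditioned prediction. A minimal sketch with stand-in tensors (shapes are made up):

    import torch

    # Illustrative only: applying the same rescale formula to toy tensors.
    guidance_rescale = 0.7
    noise_pred_text = torch.randn(2, 4, 8, 8)
    noise_cfg = noise_pred_text * 3.0          # exaggerated CFG output with inflated std

    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)                       # match the text std
    mixed = guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
    print(noise_cfg.std().item(), mixed.std().item())                 # std shrinks toward std_text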
-def resize_mask(mask, latent, process_first_frame_only=True):
-    latent_size = latent.size()
-
-    if process_first_frame_only:
-        target_size = list(latent_size[2:])
-        target_size[0] = 1
-        first_frame_resized = F.interpolate(
-            mask[:, :, 0:1, :, :],
-            size=target_size,
-            mode='trilinear',
-            align_corners=False
-        )
-
-        target_size = list(latent_size[2:])
-        target_size[0] = target_size[0] - 1
-        if target_size[0] != 0:
-            remaining_frames_resized = F.interpolate(
-                mask[:, :, 1:, :, :],
-                size=target_size,
-                mode='trilinear',
-                align_corners=False
-            )
-            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
-        else:
-            resized_mask = first_frame_resized
-    else:
-        target_size = list(latent_size[2:])
-        resized_mask = F.interpolate(
-            mask,
-            size=target_size,
-            mode='trilinear',
-            align_corners=False
-        )
-    return resized_mask
-
-
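Note (illustration, not part of the diff): resize_mask above resizes the first mask frame and the remaining frames separately so that the first frame keeps its own latent slot (matching the VAE's causal temporal compression). A shape sketch with made-up sizes:

    import torch
    import torch.nn.functional as F

    # Illustrative only: 17 mask frames -> 5 latent frames, 64x64 -> 32x32 (all sizes made up).
    mask = torch.rand(1, 1, 17, 64, 64)
    latent = torch.zeros(1, 16, 5, 32, 32)

    first = F.interpolate(mask[:, :, 0:1], size=(1, 32, 32), mode="trilinear", align_corners=False)
    rest = F.interpolate(mask[:, :, 1:], size=(4, 32, 32), mode="trilinear", align_corners=False)
    resized = torch.cat([first, rest], dim=2)
    print(resized.shape)  # torch.Size([1, 1, 5, 32, 32]) -- matches the latent's frame/spatial size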
-def add_noise_to_reference_video(image, ratio=None):
-    if ratio is None:
-        sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
-        sigma = torch.exp(sigma).to(image.dtype)
-    else:
-        sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
-
-    image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
-    image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
-    image = image + image_noise
-    return image
-
-
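Note (illustration, not part of the diff): when no ratio is given, the helper above draws a log-normal noise level per sample, roughly centered around exp(-3) ~ 0.05, and leaves padding pixels (value -1) untouched. A small sketch with hypothetical shapes:

    import torch

    # Illustrative only: the per-sample noise strength drawn in add_noise_to_reference_video.
    torch.manual_seed(0)
    image = torch.rand(2, 3, 4, 8, 8) * 2 - 1              # hypothetical video in [-1, 1]
    sigma = torch.exp(torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)))
    noise = torch.randn_like(image) * sigma[:, None, None, None, None]
    noise = torch.where(image == -1, torch.zeros_like(image), noise)  # padding stays untouched
    print(sigma)                                            # small per-sample noise strengths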
-class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline):
-    r"""
-    Pipeline for text-to-video generation using EasyAnimate.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by
-    HunyuanDiT team)
-
-    Args:
-        vae ([`AutoencoderKLMagvit`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
-        text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]):
-            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
-            EasyAnimate uses a fine-tuned [bilingual CLIP].
-        tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]):
-            A `BertTokenizer` or `CLIPTokenizer` to tokenize text.
-        transformer ([`EasyAnimateTransformer3DModel`]):
-            The EasyAnimate model designed by Tencent Hunyuan.
-        text_encoder_2 (`T5EncoderModel`):
-            The mT5 embedder.
-        tokenizer_2 (`T5Tokenizer`):
-            The tokenizer for the mT5 embedder.
-        scheduler ([`DDIMScheduler`]):
-            A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
-        clip_image_processor (`CLIPImageProcessor`):
-            The CLIP image embedder.
-        clip_image_encoder (`CLIPVisionModelWithProjection`):
-            The image processor for the CLIP image embedder.
-    """
-
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae"
-    _optional_components = [
-        "safety_checker",
-        "feature_extractor",
-        "text_encoder_2",
-        "tokenizer_2",
-        "text_encoder",
-        "tokenizer",
-        "clip_image_encoder",
-    ]
-    _exclude_from_cpu_offload = ["safety_checker"]
-    _callback_tensor_inputs = [
-        "latents",
-        "prompt_embeds",
-        "negative_prompt_embeds",
-        "prompt_embeds_2",
-        "negative_prompt_embeds_2",
-    ]
-
-    def __init__(
-        self,
-        vae: AutoencoderKLMagvit,
-        text_encoder: BertModel,
-        tokenizer: BertTokenizer,
-        text_encoder_2: T5EncoderModel,
-        tokenizer_2: T5Tokenizer,
-        transformer: EasyAnimateTransformer3DModel,
-        scheduler: DDIMScheduler,
-        safety_checker: StableDiffusionSafetyChecker,
-        feature_extractor: CLIPImageProcessor,
-        requires_safety_checker: bool = True,
-        clip_image_processor: CLIPImageProcessor = None,
-        clip_image_encoder: CLIPVisionModelWithProjection = None,
-    ):
-        super().__init__()
-
-        self.register_modules(
-            vae=vae,
-            text_encoder=text_encoder,
-            tokenizer=tokenizer,
-            tokenizer_2=tokenizer_2,
-            transformer=transformer,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
-            text_encoder_2=text_encoder_2,
-            clip_image_processor=clip_image_processor,
-            clip_image_encoder=clip_image_encoder,
-        )
-
-        if safety_checker is None and requires_safety_checker:
-            logger.warning(
-                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-            )
-
-        if safety_checker is not None and feature_extractor is None:
-            raise ValueError(
-                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
-            )
-
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
-        )
-        self.enable_autocast_float8_transformer_flag = False
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
-
-    def enable_sequential_cpu_offload(self, *args, **kwargs):
-        super().enable_sequential_cpu_offload(*args, **kwargs)
-        if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
-            import accelerate
-            accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
-            self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
-
-    def encode_prompt(
-        self,
-        prompt: str,
-        device: torch.device,
-        dtype: torch.dtype,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[str] = None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_attention_mask: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: Optional[int] = None,
-        text_encoder_index: int = 0,
-        actual_max_sequence_length: int = 256
-    ):
-        r"""
-        Encodes the prompt into text encoder hidden states.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            device: (`torch.device`):
-                torch device
-            dtype (`torch.dtype`):
-                torch dtype
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            prompt_attention_mask (`torch.Tensor`, *optional*):
-                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
-            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
-                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
-            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
-            text_encoder_index (`int`, *optional*):
-                Index of the text encoder to use. `0` for clip and `1` for T5.
-        """
-        tokenizers = [self.tokenizer, self.tokenizer_2]
-        text_encoders = [self.text_encoder, self.text_encoder_2]
-
-        tokenizer = tokenizers[text_encoder_index]
-        text_encoder = text_encoders[text_encoder_index]
-
-        if max_sequence_length is None:
-            if text_encoder_index == 0:
-                max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length)
-            if text_encoder_index == 1:
-                max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length)
-        else:
-            max_length = max_sequence_length
-
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            text_inputs = tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_attention_mask=True,
-                return_tensors="pt",
-            )
-            text_input_ids = text_inputs.input_ids
-            if text_input_ids.shape[-1] > actual_max_sequence_length:
-                reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
-                text_inputs = tokenizer(
-                    reprompt,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_attention_mask=True,
-                    return_tensors="pt",
-                )
-                text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length)
-                removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1])
-                logger.warning(
-                    "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {_actual_max_sequence_length} tokens: {removed_text}"
-                )
-            prompt_attention_mask = text_inputs.attention_mask.to(device)
-            if self.transformer.config.enable_text_attention_mask:
-                prompt_embeds = text_encoder(
-                    text_input_ids.to(device),
-                    attention_mask=prompt_attention_mask,
-                )
-            else:
-                prompt_embeds = text_encoder(
-                    text_input_ids.to(device)
-                )
-            prompt_embeds = prompt_embeds[0]
-            prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-            uncond_input_ids = uncond_input.input_ids
-            if uncond_input_ids.shape[-1] > actual_max_sequence_length:
-                reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True)
-                uncond_input = tokenizer(
-                    reuncond_tokens,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_attention_mask=True,
-                    return_tensors="pt",
-                )
-                uncond_input_ids = uncond_input.input_ids
-
-            negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
-            if self.transformer.config.enable_text_attention_mask:
-                negative_prompt_embeds = text_encoder(
-                    uncond_input.input_ids.to(device),
-                    attention_mask=negative_prompt_attention_mask,
-                )
-            else:
-                negative_prompt_embeds = text_encoder(
-                    uncond_input.input_ids.to(device)
-                )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-            negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
-
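Note (illustration, not part of the diff): the tail of encode_prompt above duplicates the text embeddings once per requested generation using the mps-friendly repeat/view pattern. A shape-only sketch with made-up sizes:

    import torch

    # Illustrative only: 2 prompts, 3 videos per prompt, seq=256, dim=1024 (all made up).
    num_images_per_prompt = 3
    prompt_embeds = torch.randn(2, 256, 1024)

    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
    print(prompt_embeds.shape)  # torch.Size([6, 256, 1024])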
| 444 | 
            -
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

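    # The signature inspection above keeps the pipeline scheduler-agnostic: DDIMScheduler.step accepts
    # both `eta` and `generator`, while other schedulers may accept only `generator` (or neither), so
    # only the kwargs a given scheduler actually supports are forwarded.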
    def check_inputs(
        self,
        prompt,
        height,
        width,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        prompt_attention_mask=None,
        negative_prompt_attention_mask=None,
        prompt_embeds_2=None,
        negative_prompt_embeds_2=None,
        prompt_attention_mask_2=None,
        negative_prompt_attention_mask_2=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is None and prompt_embeds_2 is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if prompt_embeds is not None and prompt_attention_mask is None:
            raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")

        if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
            raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
            raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")

        if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
            raise ValueError(
                "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
            )
        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
        if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
            if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
                raise ValueError(
                    "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
                    f" {negative_prompt_embeds_2.shape}."
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]

        return timesteps, num_inference_steps - t_start

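    # Worked example for get_timesteps: with num_inference_steps=50 and strength=0.6,
    # init_timestep = min(int(50 * 0.6), 50) = 30 and t_start = 50 - 30 = 20, so denoising runs
    # over the final 30 timesteps of the schedule; strength=1.0 keeps all 50 steps.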
    def prepare_mask_latents(
        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
    ):
        # resize the mask to latents shape as we concatenate the mask to the latents
        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
        # and half precision
        if mask is not None:
            mask = mask.to(device=device, dtype=self.vae.dtype)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
                bs = 1
                new_mask = []
                for i in range(0, mask.shape[0], bs):
                    mask_bs = mask[i : i + bs]
                    mask_bs = self.vae.encode(mask_bs)[0]
                    mask_bs = mask_bs.mode()
                    new_mask.append(mask_bs)
                mask = torch.cat(new_mask, dim = 0)
                mask = mask * self.vae.config.scaling_factor

            else:
                if mask.shape[1] == 4:
                    mask = mask
                else:
                    video_length = mask.shape[2]
                    mask = rearrange(mask, "b c f h w -> (b f) c h w")
                    mask = self._encode_vae_image(mask, generator=generator)
                    mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)

        if masked_image is not None:
            masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
            if self.transformer.config.add_noise_in_inpaint_model:
                masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
                bs = 1
                new_mask_pixel_values = []
                for i in range(0, masked_image.shape[0], bs):
                    mask_pixel_values_bs = masked_image[i : i + bs]
                    mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                    mask_pixel_values_bs = mask_pixel_values_bs.mode()
                    new_mask_pixel_values.append(mask_pixel_values_bs)
                masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
                masked_image_latents = masked_image_latents * self.vae.config.scaling_factor

            else:
                if masked_image.shape[1] == 4:
                    masked_image_latents = masked_image
                else:
                    video_length = masked_image.shape[2]
                    masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
                    masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
                    masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)

            # aligning device to prevent device errors when concatenating it with the latent model input
            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
        else:
            masked_image_latents = None

        return mask, masked_image_latents

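    # In both branches above, the check on `self.vae.quant_conv` distinguishes the 3D video VAE
    # (which encodes whole `b c f h w` clips in mini-batches of 1) from a per-frame 2D VAE
    # (which rearranges to `(b f) c h w`, encodes frame by frame, then restores the frame axis).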
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        video_length,
        dtype,
        device,
        generator,
        latents=None,
        video=None,
        timestep=None,
        is_strength_max=True,
        return_noise=False,
        return_video_latents=False,
    ):
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
            if self.vae.cache_mag_vae:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
            else:
                mini_batch_encoder = self.vae.mini_batch_encoder
                mini_batch_decoder = self.vae.mini_batch_decoder
                shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
        else:
            shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if return_video_latents or (latents is None and not is_strength_max):
            video = video.to(device=device, dtype=self.vae.dtype)
            if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
                bs = 1
                new_video = []
                for i in range(0, video.shape[0], bs):
                    video_bs = video[i : i + bs]
                    video_bs = self.vae.encode(video_bs)[0]
                    video_bs = video_bs.sample()
                    new_video.append(video_bs)
                video = torch.cat(new_video, dim = 0)
                video = video * self.vae.config.scaling_factor

            else:
                if video.shape[1] == 4:
                    video = video
                else:
                    video_length = video.shape[2]
                    video = rearrange(video, "b c f h w -> (b f) c h w")
                    video = self._encode_vae_image(video, generator=generator)
                    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
            video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
            video_latents = video_latents.to(device=device, dtype=dtype)

        if latents is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
            # if strength is 1. then initialise the latents to noise, else initialise to image + noise
            latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
            # if pure noise then scale the initial latents by the scheduler's init sigma
            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
        else:
            noise = latents.to(device)
            latents = noise * self.scheduler.init_noise_sigma

        # scale the initial noise by the standard deviation required by the scheduler
        outputs = (latents,)

        if return_noise:
            outputs += (noise,)

        if return_video_latents:
            outputs += (video_latents,)

        return outputs

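    # Temporal latent size sketch (illustrative values only, not taken from a specific config):
    # with cache_mag_vae enabled, mini_batch_encoder=4 and mini_batch_decoder=1 would map a
    # 49-frame input to (49 - 1) // 4 * 1 + 1 = 13 latent frames; the real ratios come from the loaded VAE.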
    def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
        if video.size()[2] <= mini_batch_encoder:
            return video
        prefix_index_before = mini_batch_encoder // 2
        prefix_index_after = mini_batch_encoder - prefix_index_before
        pixel_values = video[:, :, prefix_index_before:-prefix_index_after]

        # Encode middle videos
        latents = self.vae.encode(pixel_values)[0]
        latents = latents.mode()
        # Decode middle videos
        middle_video = self.vae.decode(latents)[0]

        video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
        return video

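    # smooth_output re-encodes and re-decodes the clip shifted by half a mini-batch and averages it
    # with the original decode, which softens visible seams between independently decoded chunks.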
    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / self.vae.config.scaling_factor * latents
        if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5:
            mini_batch_encoder = self.vae.mini_batch_encoder
            mini_batch_decoder = self.vae.mini_batch_decoder
            video = self.vae.decode(latents)[0]
            video = video.clamp(-1, 1)
            if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae:
                video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
        else:
            latents = rearrange(latents, "b c f h w -> (b f) c h w")
            video = []
            for frame_idx in tqdm(range(latents.shape[0])):
                video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
            video = torch.cat(video)
            video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

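    # decode_latents undoes the VAE scaling (multiplying by 1 / scaling_factor), decodes to pixel space
    # in [-1, 1], then maps to [0, 1] via video / 2 + 0.5 before returning a float32 numpy array.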
    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    def enable_autocast_float8_transformer(self):
        self.enable_autocast_float8_transformer_flag = True

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        video_length: Optional[int] = None,
        video: Union[torch.FloatTensor] = None,
        mask_video: Union[torch.FloatTensor] = None,
        masked_video_latents: Union[torch.FloatTensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        prompt_embeds_2: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds_2: Optional[torch.Tensor] = None,
        prompt_attention_mask: Optional[torch.Tensor] = None,
        prompt_attention_mask_2: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "latent",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = (1024, 1024),
        target_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        clip_image: Image = None,
        clip_apply_ratio: float = 0.40,
        strength: float = 1.0,
        noise_aug_strength: float = 0.0563,
        comfyui_progressbar: bool = False,
    ):
                    r"""
         | 
| 809 | 
            -
                    The call function to the pipeline for generation with HunyuanDiT.
         | 
| 810 | 
            -
             | 
| 811 | 
            -
                    Examples:
         | 
| 812 | 
            -
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 813 | 
            -
                            The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
         | 
| 814 | 
            -
                        video_length (`int`, *optional*):
         | 
| 815 | 
            -
                            Length of the video to be generated in seconds. This parameter influences the number of frames and
         | 
| 816 | 
            -
                            continuity of generated content.
         | 
| 817 | 
            -
                        video (`torch.FloatTensor`, *optional*):
         | 
| 818 | 
            -
                            A tensor representing an input video, which can be modified depending on the prompts provided.
         | 
| 819 | 
            -
                        mask_video (`torch.FloatTensor`, *optional*):
         | 
| 820 | 
            -
                            A tensor to specify areas of the video to be masked (omitted from generation).
         | 
| 821 | 
            -
                        masked_video_latents (`torch.FloatTensor`, *optional*):
         | 
| 822 | 
            -
                            Latents from masked portions of the video, utilized during image generation.
         | 
| 823 | 
            -
                        height (`int`, *optional*):
         | 
| 824 | 
            -
                            The height in pixels of the generated image or video frames.
         | 
| 825 | 
            -
                        width (`int`, *optional*):
         | 
| 826 | 
            -
                            The width in pixels of the generated image or video frames.
         | 
| 827 | 
            -
                        num_inference_steps (`int`, *optional*, defaults to 50):
         | 
| 828 | 
            -
                            The number of denoising steps. More denoising steps usually lead to a higher quality image but slower
         | 
| 829 | 
            -
                            inference time. This parameter is modulated by `strength`.
         | 
| 830 | 
            -
                        guidance_scale (`float`, *optional*, defaults to 5.0):
         | 
| 831 | 
            -
                            A higher guidance scale value encourages the model to generate images closely linked to the text 
         | 
| 832 | 
            -
                            `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`.
         | 
| 833 | 
            -
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 834 | 
            -
                            The prompt or prompts to guide what to exclude in image generation. If not defined, you need to
         | 
| 835 | 
            -
                            provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`).
         | 
| 836 | 
            -
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 837 | 
            -
                            The number of images to generate per prompt.
         | 
| 838 | 
            -
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 839 | 
            -
                            A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the
         | 
| 840 | 
            -
                            [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the 
         | 
| 841 | 
            -
                            inference process.
         | 
| 842 | 
            -
                        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
         | 
| 843 | 
            -
                            A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting
         | 
| 844 | 
            -
                            random seeds which helps in making generation deterministic.
         | 
| 845 | 
            -
                        latents (`torch.Tensor`, *optional*):
         | 
| 846 | 
            -
                            A pre-computed latent representation which can be used to guide the generation process.
         | 
| 847 | 
            -
                        prompt_embeds (`torch.Tensor`, *optional*):
         | 
| 848 | 
            -
                            Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
         | 
| 849 | 
            -
                            provided, embeddings are generated from the `prompt` input argument.
         | 
| 850 | 
            -
                        prompt_embeds_2 (`torch.Tensor`, *optional*):
         | 
| 851 | 
            -
                            Secondary set of pre-generated text embeddings, useful for advanced prompt weighting.
         | 
| 852 | 
            -
                        negative_prompt_embeds (`torch.Tensor`, *optional*):
         | 
| 853 | 
            -
                            Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs.
         | 
| 854 | 
            -
                            If not provided, embeddings are generated from the `negative_prompt` argument.
         | 
| 855 | 
            -
                        negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
         | 
| 856 | 
            -
                            Secondary set of pre-generated negative text embeddings for further control.
         | 
| 857 | 
            -
                        prompt_attention_mask (`torch.Tensor`, *optional*):
         | 
| 858 | 
            -
                            Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using
         | 
| 859 | 
            -
                            `prompt_embeds`.
         | 
| 860 | 
            -
                        prompt_attention_mask_2 (`torch.Tensor`, *optional*):
         | 
| 861 | 
            -
                            Attention mask for the secondary prompt embedding.
         | 
| 862 | 
            -
                        negative_prompt_attention_mask (`torch.Tensor`, *optional*):
         | 
| 863 | 
            -
                            Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used.
         | 
| 864 | 
            -
                        negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
         | 
| 865 | 
            -
                            Attention mask for the secondary negative prompt embedding.
         | 
| 866 | 
            -
                        output_type (`str`, *optional*, defaults to `"latent"`):
         | 
| 867 | 
            -
                            The output format of the generated image. Choose between `PIL.Image` and `np.array` to define
         | 
| 868 | 
            -
                            how you want the results to be formatted.
         | 
| 869 | 
            -
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 870 | 
            -
                            If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned;
         | 
| 871 | 
            -
                            otherwise, a tuple containing the generated images and safety flags will be returned.
         | 
| 872 | 
            -
                        callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
         | 
| 873 | 
            -
                            A callback function (or a list of them) that will be executed at the end of each denoising step,
         | 
| 874 | 
            -
                            allowing for custom processing during generation.
         | 
| 875 | 
            -
                        callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
         | 
| 876 | 
            -
                            Specifies which tensor inputs should be included in the callback function. If not defined, all tensor
         | 
| 877 | 
            -
                            inputs will be passed, facilitating enhanced logging or monitoring of the generation process.
         | 
| 878 | 
            -
                        guidance_rescale (`float`, *optional*, defaults to 0.0):
         | 
| 879 | 
            -
                            Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from
         | 
| 880 | 
            -
                            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
         | 
| 881 | 
            -
                        original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
         | 
| 882 | 
            -
                            The original dimensions of the image. Used to compute time ids during the generation process.
         | 
| 883 | 
            -
                        target_size (`Tuple[int, int]`, *optional*):
         | 
| 884 | 
            -
                            The targeted dimensions of the generated image, also utilized in the time id calculations.
         | 
| 885 | 
            -
                        crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
         | 
| 886 | 
            -
                            Coordinates defining the top left corner of any cropping, utilized while calculating the time ids.
         | 
| 887 | 
            -
                        clip_image (`Image`, *optional*):
         | 
| 888 | 
            -
                            An optional image to assist in the generation process. It may be used as an additional visual cue.
         | 
| 889 | 
            -
                        clip_apply_ratio (`float`, *optional*, defaults to 0.40):
         | 
| 890 | 
            -
                            Ratio indicating how much influence the clip image should exert over the generated content.
         | 
| 891 | 
            -
                        strength (`float`, *optional*, defaults to 1.0):
         | 
| 892 | 
            -
                            Affects the overall styling or quality of the generated output. Values closer to 1 usually provide direct 
         | 
| 893 | 
            -
                            adherence to prompts.
         | 
| 894 | 
            -
                        comfyui_progressbar (`bool`, *optional*, defaults to `False`):
         | 
| 895 | 
            -
                            Enables a progress bar in ComfyUI, providing visual feedback during the generation process.
         | 
| 896 | 
            -
             | 
| 897 | 
            -
                    Examples:
         | 
| 898 | 
            -
                        # Example usage of the function for generating images based on prompts.
         | 
| 899 | 
            -
                    
         | 
| 900 | 
            -
                    Returns:
         | 
| 901 | 
            -
                        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
         | 
| 902 | 
            -
                            Returns either a structured output containing generated images and their metadata when `return_dict` is
         | 
| 903 | 
            -
                            `True`, or a simpler tuple, where the first element is a list of generated images and the second
         | 
| 904 | 
            -
                            element indicates if any of them contain "not-safe-for-work" (NSFW) content.
         | 
| 905 | 
            -
                    """
         | 
| 906 | 
            -
             | 
| 907 | 
            -
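        # Minimal usage sketch (hypothetical model path and argument values; the canonical example
        # is substituted into the docstring from EXAMPLE_DOC_STRING by @replace_example_docstring):
        #   pipe = <this pipeline class>.from_pretrained("models/...", torch_dtype=torch.bfloat16)
        #   result = pipe(prompt="A cat playing the piano", video_length=49, height=512, width=512,
        #                 num_inference_steps=50, guidance_scale=5.0, output_type="latent")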
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 0. default height and width
        height = int(height // 16 * 16)
        width = int(width // 16 * 16)

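        # The rounding above snaps both sizes down to multiples of 16, e.g. int(1080 // 16 * 16) = 1072,
        # keeping the spatial dimensions aligned with the VAE and transformer patch/downsampling strides.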
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
            callback_on_step_end_tensor_inputs,
        )
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            text_encoder_index=0,
        )
        (
            prompt_embeds_2,
            negative_prompt_embeds_2,
            prompt_attention_mask_2,
            negative_prompt_attention_mask_2,
        ) = self.encode_prompt(
            prompt=prompt,
            device=device,
            dtype=self.transformer.dtype,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds_2,
            negative_prompt_embeds=negative_prompt_embeds_2,
            prompt_attention_mask=prompt_attention_mask_2,
            negative_prompt_attention_mask=negative_prompt_attention_mask_2,
            text_encoder_index=1,
        )
        torch.cuda.empty_cache()

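        # The two encode_prompt calls above feed the pipeline's dual text encoders
        # (text_encoder_index=0 and text_encoder_index=1); each returns embeddings plus the
        # matching attention masks for both the positive and the negative prompt.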
        # 4. set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )
        if comfyui_progressbar:
            from comfy.utils import ProgressBar
            pbar = ProgressBar(num_inference_steps + 3)
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

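        # With strength == 1.0 the latents are initialised from pure noise; for strength < 1.0 the
        # encoded input video is noised to `latent_timestep` and then denoised from that point on.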
        if video is not None:
            video_length = video.shape[2]
            init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
            init_video = init_video.to(dtype=torch.float32)
            init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
        else:
            init_video = None

        # Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_transformer = self.transformer.config.in_channels
        return_image_latents = num_channels_transformer == num_channels_latents

        # 5. Prepare latents.
        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            video_length,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            video=init_video,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_video_latents=return_image_latents,
        )
        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        if comfyui_progressbar:
            pbar.update(1)

                    # 6. Prepare clip latents if it needs.
         | 
| 1035 | 
            -
                    if clip_image is not None and self.transformer.enable_clip_in_inpaint:
         | 
| 1036 | 
            -
                        inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
         | 
| 1037 | 
            -
                        inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
         | 
| 1038 | 
            -
                        clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
         | 
| 1039 | 
            -
                        clip_encoder_hidden_states_neg = torch.zeros(
         | 
| 1040 | 
            -
                            [
         | 
| 1041 | 
            -
                                batch_size, 
         | 
| 1042 | 
            -
                                int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, 
         | 
| 1043 | 
            -
                                int(self.clip_image_encoder.config.hidden_size)
         | 
| 1044 | 
            -
                            ]
         | 
| 1045 | 
            -
                        ).to(latents.device, dtype=latents.dtype)
         | 
| 1046 | 
            -
             | 
| 1047 | 
            -
                        clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
         | 
| 1048 | 
            -
                        clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
         | 
| 1049 | 
            -
             | 
| 1050 | 
            -
                        clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
         | 
| 1051 | 
            -
                        clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
         | 
| 1052 | 
            -
             | 
| 1053 | 
            -
                    elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint:
         | 
| 1054 | 
            -
                        clip_encoder_hidden_states = torch.zeros(
         | 
| 1055 | 
            -
                            [
         | 
| 1056 | 
            -
                                batch_size, 
         | 
| 1057 | 
            -
                                int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, 
         | 
| 1058 | 
            -
                                int(self.clip_image_encoder.config.hidden_size)
         | 
| 1059 | 
            -
                            ]
         | 
| 1060 | 
            -
                        ).to(latents.device, dtype=latents.dtype)
         | 
| 1061 | 
            -
             | 
| 1062 | 
            -
                        clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
         | 
| 1063 | 
            -
                        clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
         | 
| 1064 | 
            -
             | 
| 1065 | 
            -
                        clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
         | 
| 1066 | 
            -
                        clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
         | 
| 1067 | 
            -
             | 
| 1068 | 
            -
                    else:
         | 
| 1069 | 
            -
                        clip_encoder_hidden_states_input = None
         | 
| 1070 | 
            -
                        clip_attention_mask_input = None
         | 
| 1071 | 
            -
                    if comfyui_progressbar:
         | 
| 1072 | 
            -
                        pbar.update(1)
         | 
| 1073 | 
            -
             | 
| 1074 | 
            -
                    # 7. Prepare inpaint latents if it needs.
         | 
| 1075 | 
            -
                    if mask_video is not None:
         | 
| 1076 | 
            -
                        if (mask_video == 255).all():
         | 
| 1077 | 
            -
                            # Use zero latents if we want to t2v.
         | 
| 1078 | 
            -
                            if self.transformer.resize_inpaint_mask_directly:
         | 
| 1079 | 
            -
                                mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype)
         | 
| 1080 | 
            -
                            else:
         | 
| 1081 | 
            -
                                mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
         | 
| 1082 | 
            -
                            masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
         | 
| 1083 | 
            -
             | 
| 1084 | 
            -
                            mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
         | 
| 1085 | 
            -
                            masked_video_latents_input = (
         | 
| 1086 | 
            -
                                torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
         | 
| 1087 | 
            -
                            )
         | 
| 1088 | 
            -
                            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
         | 
| 1089 | 
            -
                        else:
         | 
| 1090 | 
            -
                            # Prepare mask latent variables
         | 
| 1091 | 
            -
                            video_length = video.shape[2]
         | 
| 1092 | 
            -
                            mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) 
         | 
| 1093 | 
            -
                            mask_condition = mask_condition.to(dtype=torch.float32)
         | 
| 1094 | 
            -
                            mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
         | 
| 1095 | 
            -
             | 
| 1096 | 
            -
                            if num_channels_transformer != num_channels_latents:
         | 
| 1097 | 
            -
                                mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
         | 
| 1098 | 
            -
                                if masked_video_latents is None:
         | 
| 1099 | 
            -
                                    masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
         | 
| 1100 | 
            -
                                else:
         | 
| 1101 | 
            -
                                    masked_video = masked_video_latents
         | 
| 1102 | 
            -
                                
         | 
| 1103 | 
            -
                                if self.transformer.resize_inpaint_mask_directly:
         | 
| 1104 | 
            -
                                    _, masked_video_latents = self.prepare_mask_latents(
         | 
| 1105 | 
            -
                                        None,
         | 
| 1106 | 
            -
                                        masked_video,
         | 
| 1107 | 
            -
                                        batch_size,
         | 
| 1108 | 
            -
                                        height,
         | 
| 1109 | 
            -
                                        width,
         | 
| 1110 | 
            -
                                        prompt_embeds.dtype,
         | 
| 1111 | 
            -
                                        device,
         | 
| 1112 | 
            -
                                        generator,
         | 
| 1113 | 
            -
                                        self.do_classifier_free_guidance,
         | 
| 1114 | 
            -
                                        noise_aug_strength=noise_aug_strength,
         | 
| 1115 | 
            -
                                    )
         | 
| 1116 | 
            -
                                    mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae)
         | 
| 1117 | 
            -
                                    mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
         | 
| 1118 | 
            -
                                else:
         | 
| 1119 | 
            -
                                    mask_latents, masked_video_latents = self.prepare_mask_latents(
         | 
| 1120 | 
            -
                                        mask_condition_tile,
         | 
| 1121 | 
            -
                                        masked_video,
         | 
| 1122 | 
            -
                                        batch_size,
         | 
| 1123 | 
            -
                                        height,
         | 
| 1124 | 
            -
                                        width,
         | 
| 1125 | 
            -
                                        prompt_embeds.dtype,
         | 
| 1126 | 
            -
                                        device,
         | 
| 1127 | 
            -
                                        generator,
         | 
| 1128 | 
            -
                                        self.do_classifier_free_guidance,
         | 
| 1129 | 
            -
                                        noise_aug_strength=noise_aug_strength,
         | 
| 1130 | 
            -
                                    )
         | 
| 1131 | 
            -
                                
         | 
| 1132 | 
            -
                                mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
         | 
| 1133 | 
            -
                                masked_video_latents_input = (
         | 
| 1134 | 
            -
                                    torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
         | 
| 1135 | 
            -
                                )
         | 
| 1136 | 
            -
                                inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
         | 
| 1137 | 
            -
                            else:
         | 
| 1138 | 
            -
                                inpaint_latents = None
         | 
| 1139 | 
            -
             | 
| 1140 | 
            -
                            mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
         | 
| 1141 | 
            -
                            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
         | 
| 1142 | 
            -
                    else:
         | 
| 1143 | 
            -
                        if num_channels_transformer != num_channels_latents:
         | 
| 1144 | 
            -
                            mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
         | 
| 1145 | 
            -
                            masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
         | 
| 1146 | 
            -
             | 
| 1147 | 
            -
                            mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
         | 
| 1148 | 
            -
                            masked_video_latents_input = (
         | 
| 1149 | 
            -
                                torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
         | 
| 1150 | 
            -
                            )
         | 
| 1151 | 
            -
                            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
         | 
| 1152 | 
            -
                        else:
         | 
| 1153 | 
            -
                            mask = torch.zeros_like(init_video[:, :1])
         | 
| 1154 | 
            -
                            mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
         | 
| 1155 | 
            -
                            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
         | 
| 1156 | 
            -
             | 
| 1157 | 
            -
                            inpaint_latents = None
         | 
| 1158 | 
            -
                    if comfyui_progressbar:
         | 
| 1159 | 
            -
                        pbar.update(1)
         | 
| 1160 | 
            -
             | 
| 1161 | 
            -
                    # Check that sizes of mask, masked image and latents match
         | 
| 1162 | 
            -
                    if num_channels_transformer != num_channels_latents:
         | 
| 1163 | 
            -
                        num_channels_mask = mask_latents.shape[1]
         | 
| 1164 | 
            -
                        num_channels_masked_image = masked_video_latents.shape[1]
         | 
| 1165 | 
            -
                        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
         | 
| 1166 | 
            -
                            raise ValueError(
         | 
| 1167 | 
            -
                                f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
         | 
| 1168 | 
            -
                                f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
         | 
| 1169 | 
            -
                                f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
         | 
| 1170 | 
            -
                                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
         | 
| 1171 | 
            -
                                " `pipeline.transformer` or your `mask_image` or `image` input."
         | 
| 1172 | 
            -
                            )
         | 
| 1173 | 
            -
             | 
| 1174 | 
            -
                    # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         | 
| 1175 | 
            -
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         | 
| 1176 | 
            -
             | 
| 1177 | 
            -
                    # 9 create image_rotary_emb, style embedding & time ids
         | 
| 1178 | 
            -
                    grid_height = height // 8 // self.transformer.config.patch_size
         | 
| 1179 | 
            -
                    grid_width = width // 8 // self.transformer.config.patch_size
         | 
| 1180 | 
            -
                    if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope":
         | 
| 1181 | 
            -
                        base_size_width = 720 // 8 // self.transformer.config.patch_size
         | 
| 1182 | 
            -
                        base_size_height = 480 // 8 // self.transformer.config.patch_size
         | 
| 1183 | 
            -
             | 
| 1184 | 
            -
                        grid_crops_coords = get_resize_crop_region_for_grid(
         | 
| 1185 | 
            -
                            (grid_height, grid_width), base_size_width, base_size_height
         | 
| 1186 | 
            -
                        )
         | 
| 1187 | 
            -
                        image_rotary_emb = get_3d_rotary_pos_embed(
         | 
| 1188 | 
            -
                            self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width),
         | 
| 1189 | 
            -
                            temporal_size=latents.size(2), use_real=True,
         | 
| 1190 | 
            -
                        )
         | 
| 1191 | 
            -
                    else:
         | 
| 1192 | 
            -
                        base_size = 512 // 8 // self.transformer.config.patch_size
         | 
| 1193 | 
            -
                        grid_crops_coords = get_resize_crop_region_for_grid(
         | 
| 1194 | 
            -
                            (grid_height, grid_width), base_size, base_size
         | 
| 1195 | 
            -
                        )
         | 
| 1196 | 
            -
                        image_rotary_emb = get_2d_rotary_pos_embed(
         | 
| 1197 | 
            -
                            self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width)
         | 
| 1198 | 
            -
                        )
         | 
| 1199 | 
            -
             | 
| 1200 | 
            -
                    # Get other hunyuan params
         | 
| 1201 | 
            -
                    style = torch.tensor([0], device=device)
         | 
| 1202 | 
            -
             | 
| 1203 | 
            -
                    target_size = target_size or (height, width)
         | 
| 1204 | 
            -
                    add_time_ids = list(original_size + target_size + crops_coords_top_left)
         | 
| 1205 | 
            -
                    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
         | 
| 1206 | 
            -
             | 
| 1207 | 
            -
                    if self.do_classifier_free_guidance:
         | 
| 1208 | 
            -
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
         | 
| 1209 | 
            -
                        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
         | 
| 1210 | 
            -
                        prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
         | 
| 1211 | 
            -
                        prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
         | 
| 1212 | 
            -
                        add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
         | 
| 1213 | 
            -
                        style = torch.cat([style] * 2, dim=0)
         | 
| 1214 | 
            -
             | 
| 1215 | 
            -
                    prompt_embeds = prompt_embeds.to(device=device)
         | 
| 1216 | 
            -
                    prompt_attention_mask = prompt_attention_mask.to(device=device)
         | 
| 1217 | 
            -
                    prompt_embeds_2 = prompt_embeds_2.to(device=device)
         | 
| 1218 | 
            -
                    prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
         | 
| 1219 | 
            -
                    add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
         | 
| 1220 | 
            -
                        batch_size * num_images_per_prompt, 1
         | 
| 1221 | 
            -
                    )
         | 
| 1222 | 
            -
                    style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
         | 
| 1223 | 
            -
             | 
| 1224 | 
            -
                    torch.cuda.empty_cache()
         | 
| 1225 | 
            -
                    if self.enable_autocast_float8_transformer_flag:
         | 
| 1226 | 
            -
                        origin_weight_dtype = self.transformer.dtype
         | 
| 1227 | 
            -
                        self.transformer = self.transformer.to(torch.float8_e4m3fn)
         | 
| 1228 | 
            -
                    # 10. Denoising loop
         | 
| 1229 | 
            -
                    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         | 
| 1230 | 
            -
                    self._num_timesteps = len(timesteps)
         | 
| 1231 | 
            -
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         | 
| 1232 | 
            -
                        for i, t in enumerate(timesteps):
         | 
| 1233 | 
            -
                            if self.interrupt:
         | 
| 1234 | 
            -
                                continue
         | 
| 1235 | 
            -
             | 
| 1236 | 
            -
                            # expand the latents if we are doing classifier free guidance
         | 
| 1237 | 
            -
                            latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
         | 
| 1238 | 
            -
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         | 
| 1239 | 
            -
             | 
| 1240 | 
            -
                            if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
         | 
| 1241 | 
            -
                                clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
         | 
| 1242 | 
            -
                                clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
         | 
| 1243 | 
            -
                            else:
         | 
| 1244 | 
            -
                                clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
         | 
| 1245 | 
            -
                                clip_attention_mask_actual_input = clip_attention_mask_input
         | 
| 1246 | 
            -
                            
         | 
| 1247 | 
            -
                            # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
         | 
| 1248 | 
            -
                            t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
         | 
| 1249 | 
            -
                                dtype=latent_model_input.dtype
         | 
| 1250 | 
            -
                            )
         | 
| 1251 | 
            -
             | 
| 1252 | 
            -
                            # predict the noise residual
         | 
| 1253 | 
            -
                            noise_pred = self.transformer(
         | 
| 1254 | 
            -
                                latent_model_input,
         | 
| 1255 | 
            -
                                t_expand,
         | 
| 1256 | 
            -
                                encoder_hidden_states=prompt_embeds,
         | 
| 1257 | 
            -
                                text_embedding_mask=prompt_attention_mask,
         | 
| 1258 | 
            -
                                encoder_hidden_states_t5=prompt_embeds_2,
         | 
| 1259 | 
            -
                                text_embedding_mask_t5=prompt_attention_mask_2,
         | 
| 1260 | 
            -
                                image_meta_size=add_time_ids,
         | 
| 1261 | 
            -
                                style=style,
         | 
| 1262 | 
            -
                                image_rotary_emb=image_rotary_emb,
         | 
| 1263 | 
            -
                                inpaint_latents=inpaint_latents,
         | 
| 1264 | 
            -
                                clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
         | 
| 1265 | 
            -
                                clip_attention_mask=clip_attention_mask_actual_input,
         | 
| 1266 | 
            -
                                return_dict=False,
         | 
| 1267 | 
            -
                            )[0]
         | 
| 1268 | 
            -
                            if noise_pred.size()[1] != self.vae.config.latent_channels:
         | 
| 1269 | 
            -
                                noise_pred, _ = noise_pred.chunk(2, dim=1)
         | 
| 1270 | 
            -
             | 
| 1271 | 
            -
                            # perform guidance
         | 
| 1272 | 
            -
                            if self.do_classifier_free_guidance:
         | 
| 1273 | 
            -
                                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 1274 | 
            -
                                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 1275 | 
            -
             | 
| 1276 | 
            -
                            if self.do_classifier_free_guidance and guidance_rescale > 0.0:
         | 
| 1277 | 
            -
                                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
         | 
| 1278 | 
            -
                                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
         | 
| 1279 | 
            -
             | 
| 1280 | 
            -
                            # compute the previous noisy sample x_t -> x_t-1
         | 
| 1281 | 
            -
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
         | 
| 1282 | 
            -
             | 
| 1283 | 
            -
                            if num_channels_transformer == 4:
         | 
| 1284 | 
            -
                                init_latents_proper = image_latents
         | 
| 1285 | 
            -
                                init_mask = mask
         | 
| 1286 | 
            -
                                if i < len(timesteps) - 1:
         | 
| 1287 | 
            -
                                    noise_timestep = timesteps[i + 1]
         | 
| 1288 | 
            -
                                    init_latents_proper = self.scheduler.add_noise(
         | 
| 1289 | 
            -
                                        init_latents_proper, noise, torch.tensor([noise_timestep])
         | 
| 1290 | 
            -
                                    )
         | 
| 1291 | 
            -
                                
         | 
| 1292 | 
            -
                                latents = (1 - init_mask) * init_latents_proper + init_mask * latents
         | 
| 1293 | 
            -
             | 
| 1294 | 
            -
                            if callback_on_step_end is not None:
         | 
| 1295 | 
            -
                                callback_kwargs = {}
         | 
| 1296 | 
            -
                                for k in callback_on_step_end_tensor_inputs:
         | 
| 1297 | 
            -
                                    callback_kwargs[k] = locals()[k]
         | 
| 1298 | 
            -
                                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
         | 
| 1299 | 
            -
             | 
| 1300 | 
            -
                                latents = callback_outputs.pop("latents", latents)
         | 
| 1301 | 
            -
                                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
         | 
| 1302 | 
            -
                                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
         | 
| 1303 | 
            -
                                prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
         | 
| 1304 | 
            -
                                negative_prompt_embeds_2 = callback_outputs.pop(
         | 
| 1305 | 
            -
                                    "negative_prompt_embeds_2", negative_prompt_embeds_2
         | 
| 1306 | 
            -
                                )
         | 
| 1307 | 
            -
             | 
| 1308 | 
            -
                            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         | 
| 1309 | 
            -
                                progress_bar.update()
         | 
| 1310 | 
            -
             | 
| 1311 | 
            -
                            if XLA_AVAILABLE:
         | 
| 1312 | 
            -
                                xm.mark_step()
         | 
| 1313 | 
            -
             | 
| 1314 | 
            -
                            if comfyui_progressbar:
         | 
| 1315 | 
            -
                                pbar.update(1)
         | 
| 1316 | 
            -
             | 
| 1317 | 
            -
                    if self.enable_autocast_float8_transformer_flag:
         | 
| 1318 | 
            -
                        self.transformer = self.transformer.to("cpu", origin_weight_dtype)
         | 
| 1319 | 
            -
             | 
| 1320 | 
            -
                    torch.cuda.empty_cache()
         | 
| 1321 | 
            -
                    # Post-processing
         | 
| 1322 | 
            -
                    video = self.decode_latents(latents)
         | 
| 1323 | 
            -
             | 
| 1324 | 
            -
                    # Convert to tensor
         | 
| 1325 | 
            -
                    if output_type == "latent":
         | 
| 1326 | 
            -
                        video = torch.from_numpy(video)
         | 
| 1327 | 
            -
             | 
| 1328 | 
            -
                    # Offload all models
         | 
| 1329 | 
            -
                    self.maybe_free_model_hooks()
         | 
| 1330 | 
            -
             | 
| 1331 | 
            -
                    if not return_dict:
         | 
| 1332 | 
            -
                        return video
         | 
| 1333 | 
            -
             | 
| 1334 | 
            -
                    return EasyAnimatePipelineOutput(videos=video)
         | 
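For reference, the guidance step inside the removed denoising loop above reduces to the standard classifier-free guidance update: the batch holds the unconditional and the text-conditioned noise predictions stacked along dim 0, and the two halves are recombined with the guidance scale. A minimal standalone sketch (the apply_cfg helper name is illustrative, not part of the repository):

    import torch

    def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # noise_pred has shape [2 * batch, ...]: first half is the unconditional
        # prediction, second half the text-conditioned prediction.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)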
        easyanimate/ui/ui.py
    CHANGED
    
@@ -17,41 +17,42 @@ import torch
 from diffusers import (AutoencoderKL, DDIMScheduler,
                        DPMSolverMultistepScheduler,
                        EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
-                       PNDMScheduler)
 from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
 from PIL import Image
 from safetensors import safe_open
 from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
-                          CLIPVisionModelWithProjection,
-                          T5EncoderModel,

-from
-from
                                name_to_transformer3d)
-from
-
-
-
-from
     EasyAnimateInpaintPipeline
-from
-
-from
-    EasyAnimatePipeline_Multi_Text_Encoder_Inpaint
-from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
-from easyanimate.utils.utils import (
     get_image_to_video_latent, get_video_to_video_latent,
     get_width_and_height_from_image_and_base_resolution, save_videos_grid)
-from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper

-
     "Euler": EulerDiscreteScheduler,
     "Euler A": EulerAncestralDiscreteScheduler,
     "DPM++": DPMSolverMultistepScheduler,
     "PNDM": PNDMScheduler,
     "DDIM": DDIMScheduler,
 }

 gradio_version = pkg_resources.get_distribution("gradio").version
 gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
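The context lines above keep the UI's mapping from scheduler labels to diffusers scheduler classes, which is later combined with Choosen_Scheduler.from_pretrained further down in this file. A minimal sketch of how such a mapping can be used to load a scheduler from a model directory (the load_scheduler helper and model_dir argument are illustrative, not part of the repository):

    from diffusers import (DDIMScheduler, DPMSolverMultistepScheduler,
                           EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                           PNDMScheduler)

    scheduler_dict = {
        "Euler": EulerDiscreteScheduler,
        "Euler A": EulerAncestralDiscreteScheduler,
        "DPM++": DPMSolverMultistepScheduler,
        "PNDM": PNDMScheduler,
        "DDIM": DDIMScheduler,
    }

    def load_scheduler(name: str, model_dir: str):
        # Look up the scheduler class by its UI label and load its config
        # from the model's "scheduler" subfolder.
        return scheduler_dict[name].from_pretrained(model_dir, subfolder="scheduler")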
@@ -98,8 +99,8 @@ class EasyAnimateController:
         self.GPU_memory_mode       = GPU_memory_mode

         self.weight_dtype          = weight_dtype
-        self.edition               = "v5"
-        self.inference_config      = OmegaConf.load(os.path.join(self.config_dir, "

     def refresh_diffusion_transformer(self):
         self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
@@ -121,26 +122,37 @@ class EasyAnimateController:
         if edition == "v1":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
                 gr.update(value=512, minimum=384, maximum=704, step=32), \
                 gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
         elif edition == "v2":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
         elif edition == "v3":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         elif edition == "v4":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         elif edition == "v5":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
             return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
                 gr.update(value=672, minimum=128, maximum=1344, step=16), \
                 gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)

@@ -170,33 +182,55 @@ class EasyAnimateController:
         self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
             diffusion_transformer_dropdown,
             subfolder="transformer",
-            transformer_additional_kwargs=transformer_additional_kwargs
-

         if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
             tokenizer = BertTokenizer.from_pretrained(
                 diffusion_transformer_dropdown, subfolder="tokenizer"
             )
-
-
-
         else:
-
-
-
             tokenizer_2 = None

         if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
             text_encoder = BertModel.from_pretrained(
                 diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
             )
-
-
-
-
-
-
-
             text_encoder_2 = None

         # Get pipeline
@@ -212,23 +246,18 @@ class EasyAnimateController:
             clip_image_processor = None

         # Get Scheduler
-
-            "
-
-            "
-            "PNDM": PNDMScheduler,
-            "DDIM": DDIMScheduler,
-        }["Euler"]
-
         scheduler = Choosen_Scheduler.from_pretrained(
             diffusion_transformer_dropdown,
             subfolder="scheduler"
         )

-        if self.
             if self.transformer.config.in_channels != self.vae.config.latent_channels:
-                self.pipeline =
-                    diffusion_transformer_dropdown,
                     text_encoder=text_encoder,
                     text_encoder_2=text_encoder_2,
                     tokenizer=tokenizer,
@@ -236,13 +265,11 @@ class EasyAnimateController:
                     vae=self.vae,
                     transformer=self.transformer,
                     scheduler=scheduler,
-                    torch_dtype=self.weight_dtype,
                     clip_image_encoder=clip_image_encoder,
                     clip_image_processor=clip_image_processor,
-                )
             else:
-                self.pipeline =
-                    diffusion_transformer_dropdown,
                     text_encoder=text_encoder,
                     text_encoder_2=text_encoder_2,
                     tokenizer=tokenizer,
@@ -250,40 +277,25 @@ class EasyAnimateController:
                     vae=self.vae,
                     transformer=self.transformer,
                     scheduler=scheduler,
-
-                )
         else:
-
-
-
-
-
-
-
-
-
-                    clip_image_encoder=clip_image_encoder,
-                    clip_image_processor=clip_image_processor,
-                )
-            else:
-                self.pipeline = EasyAnimatePipeline(
-                    diffusion_transformer_dropdown,
-                    text_encoder=text_encoder,
-                    tokenizer=tokenizer,
-                    vae=self.vae,
-                    transformer=self.transformer,
-                    scheduler=scheduler,
-                    torch_dtype=self.weight_dtype
-                )

         if self.GPU_memory_mode == "sequential_cpu_offload":
             self.pipeline.enable_sequential_cpu_offload()
         elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
             self.pipeline.enable_model_cpu_offload()
-            self.pipeline.enable_autocast_float8_transformer()
             convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
         else:
-            self.
         print("Update diffusion transformer done")
         return gr.update()
| @@ -374,8 +386,10 @@ class EasyAnimateController: |
| 374 |     if self.base_model_path != base_model_dropdown:
| 375 |         self.update_base_model(base_model_dropdown)
| 376 |
|     |
|     |
|     |
| 377 |     if self.lora_model_path != lora_model_dropdown:
| 378 | -       print("Update lora model")
| 379 |         self.update_lora_model(lora_model_dropdown)
| 380 |
| 381 |     if control_video is not None and self.model_type == "Inpaint":
| @@ -426,19 +440,21 @@ class EasyAnimateController: |
| 426 |         else:
| 427 |             raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
| 428 |
| 429 | -   fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
| 430 |     is_image = True if generation_method == "Image Generation" else False
| 431 |
| 432 | -   if 
|     |
|     |
| 433 |
| 434 | -
|     |
|     |
|     |
|     |
| 435 |     if self.lora_model_path != "none":
| 436 |         # lora part
| 437 |         self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
| 438 | -
| 439 | -   if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
| 440 | -   else: seed_textbox = np.random.randint(0, 1e10)
| 441 | -   generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
| 442 |
| 443 |     try:
| 444 |         if self.model_type == "Inpaint":
| @@ -480,7 +496,7 @@ class EasyAnimateController: |
| 480 |                     video        = input_video,
| 481 |                     mask_video   = input_video_mask,
| 482 |                     strength     = 1,
| 483 | -               ).
| 484 |
| 485 |                 if init_frames != 0:
| 486 |                     mix_ratio = torch.from_numpy(
| @@ -531,7 +547,7 @@ class EasyAnimateController: |
| 531 |                 video        = input_video,
| 532 |                 mask_video   = input_video_mask,
| 533 |                 strength     = strength,
| 534 | -           ).
| 535 |         else:
| 536 |             if self.vae.cache_mag_vae:
| 537 |                 length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
| @@ -547,7 +563,7 @@ class EasyAnimateController: |
| 547 |                 height              = height_slider,
| 548 |                 video_length        = length_slider if not is_image else 1,
| 549 |                 generator           = generator
| 550 | -           ).
| 551 |     else:
| 552 |         if self.vae.cache_mag_vae:
| 553 |             length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
| @@ -566,7 +582,7 @@ class EasyAnimateController: |
| 566 |             generator           = generator,
| 567 |
| 568 |             control_video = input_video,
| 569 | -       ).
| 570 |     except Exception as e:
| 571 |         gc.collect()
| 572 |         torch.cuda.empty_cache()
| @@ -676,8 +692,8 @@ def ui(GPU_memory_mode, weight_dtype): |
| 676 |         with gr.Row():
| 677 |             easyanimate_edition_dropdown = gr.Dropdown(
| 678 |                 label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
| 679 | -               choices=["v1", "v2", "v3", "v4", "v5"],
| 680 | -               value="v5",
| 681 |                 interactive=True,
| 682 |             )
| 683 |         gr.Markdown(
| @@ -751,13 +767,22 @@ def ui(GPU_memory_mode, weight_dtype): |
| 751 |             """
| 752 |         )
| 753 |
| 754 | -       prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful 
| 755 | -
|     |
|     |
|     |
|     |
|     |
|     |
| 756 |
| 757 |         with gr.Row():
| 758 |             with gr.Column():
| 759 |                 with gr.Row():
| 760 | -                   sampler_dropdown   = gr.Dropdown(
|     |
|     |
|     |
| 761 |                     sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
| 762 |
| 763 |                 resize_method = gr.Radio(
| @@ -794,11 +819,11 @@ def ui(GPU_memory_mode, weight_dtype): |
| 794 |                 template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
| 795 |                 def select_template(evt: gr.SelectData):
| 796 |                     text = {
| 797 | -                       "asset/1.png": "
| 798 | -                       "asset/2.png": "
| 799 | -                       "asset/3.png": "
| 800 | -                       "asset/4.png": "
| 801 | -                       "asset/5.png": "
| 802 |                     }[template_gallery_path[evt.index]]
| 803 |                     return template_gallery_path[evt.index], text
| 804 |
| @@ -838,6 +863,7 @@ def ui(GPU_memory_mode, weight_dtype): |
| 838 |                 gr.Markdown(
| 839 |                     """
| 840 |                     Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
|     |
| 841 |                     """
| 842 |                 )
| 843 |                 control_video = gr.Video(
| @@ -927,6 +953,7 @@ def ui(GPU_memory_mode, weight_dtype): |
| 927 |             diffusion_transformer_dropdown, 
| 928 |             motion_module_dropdown, 
| 929 |             motion_module_refresh_button, 
|     |
| 930 |             width_slider, 
| 931 |             height_slider, 
| 932 |             length_slider, 
| @@ -1003,33 +1030,55 @@ class EasyAnimateController_Modelscope: |
| 1003 |     self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
| 1004 |         model_name, 
| 1005 |         subfolder="transformer", 
| 1006 | -       transformer_additional_kwargs=transformer_additional_kwargs
| 1007 | -
|      |
|      |
| 1008 |
| 1009 |     if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
| 1010 |         tokenizer = BertTokenizer.from_pretrained(
| 1011 |             model_name, subfolder="tokenizer"
| 1012 |         )
| 1013 | -
| 1014 | -
| 1015 | -
|      |
|      |
|      |
|      |
|      |
| 1016 |     else:
| 1017 | -
| 1018 | -
| 1019 | -
|      |
|      |
|      |
|      |
|      |
| 1020 |         tokenizer_2 = None
| 1021 |
| 1022 |     if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
| 1023 |         text_encoder = BertModel.from_pretrained(
| 1024 |             model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
| 1025 |         )
| 1026 | -
| 1027 | -
| 1028 | -
| 1029 | -
| 1030 | -
| 1031 | -
| 1032 | -
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
| 1033 |         text_encoder_2 = None
| 1034 |
| 1035 |     # Get pipeline
| @@ -1045,23 +1094,18 @@ class EasyAnimateController_Modelscope: |
| 1045 |         clip_image_processor = None
| 1046 |
| 1047 |     # Get Scheduler
| 1048 | -
| 1049 | -       "
| 1050 | -
| 1051 | -       "
| 1052 | -       "PNDM": PNDMScheduler,
| 1053 | -       "DDIM": DDIMScheduler,
| 1054 | -   }["Euler"]
| 1055 | -
| 1056 |     scheduler = Choosen_Scheduler.from_pretrained(
| 1057 |         model_name, 
| 1058 |         subfolder="scheduler"
| 1059 |     )
| 1060 |
| 1061 | -   if 
| 1062 |         if self.transformer.config.in_channels != self.vae.config.latent_channels:
| 1063 | -           self.pipeline = 
| 1064 | -               model_name,
| 1065 |                 text_encoder=text_encoder,
| 1066 |                 text_encoder_2=text_encoder_2,
| 1067 |                 tokenizer=tokenizer,
| @@ -1069,51 +1113,34 @@ class EasyAnimateController_Modelscope: |
| 1069 |                 vae=self.vae,
| 1070 |                 transformer=self.transformer,
| 1071 |                 scheduler=scheduler,
| 1072 | -               torch_dtype=self.weight_dtype,
| 1073 |                 clip_image_encoder=clip_image_encoder,
| 1074 |                 clip_image_processor=clip_image_processor,
| 1075 | -           )
| 1076 |         else:
| 1077 | -           self.pipeline = 
| 1078 | -               model_name,
| 1079 |                 text_encoder=text_encoder,
| 1080 |                 text_encoder_2=text_encoder_2,
| 1081 |                 tokenizer=tokenizer,
| 1082 |                 tokenizer_2=tokenizer_2,
| 1083 |                 vae=self.vae,
| 1084 |                 transformer=self.transformer,
| 1085 | -               scheduler=scheduler
| 1086 | -
| 1087 | -           )
| 1088 |     else:
| 1089 | -
| 1090 | -
| 1091 | -
| 1092 | -
| 1093 | -
| 1094 | -
| 1095 | -
| 1096 | -
| 1097 | -
| 1098 | -               clip_image_encoder=clip_image_encoder,
| 1099 | -               clip_image_processor=clip_image_processor,
| 1100 | -           )
| 1101 | -       else:
| 1102 | -           self.pipeline = EasyAnimatePipeline(
| 1103 | -               model_name,
| 1104 | -               text_encoder=text_encoder,
| 1105 | -               tokenizer=tokenizer,
| 1106 | -               vae=self.vae, 
| 1107 | -               transformer=self.transformer,
| 1108 | -               scheduler=scheduler,
| 1109 | -               torch_dtype=self.weight_dtype
| 1110 | -           )
| 1111 |
| 1112 |     if GPU_memory_mode == "sequential_cpu_offload":
| 1113 |         self.pipeline.enable_sequential_cpu_offload()
| 1114 |     elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
| 1115 |         self.pipeline.enable_model_cpu_offload()
| 1116 | -       self.pipeline.enable_autocast_float8_transformer()
| 1117 |         convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype)
| 1118 |     else:
| 1119 |         GPU_memory_mode.enable_model_cpu_offload()
| @@ -1214,17 +1241,17 @@ class EasyAnimateController_Modelscope: |
| 1214 |         else:
| 1215 |             raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
| 1216 |
| 1217 | -   fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition]
| 1218 |     is_image = True if generation_method == "Image Generation" else False
| 1219 |
| 1220 | -   self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
| 1221 | -   if self.lora_model_path != "none":
| 1222 | -       # lora part
| 1223 | -       self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
| 1224 | -
| 1225 |     if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
| 1226 |     else: seed_textbox = np.random.randint(0, 1e10)
| 1227 |     generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
|      |
|      |
|      |
|      |
|      |
| 1228 |
| 1229 |     try:
| 1230 |         if self.model_type == "Inpaint":
| @@ -1254,7 +1281,7 @@ class EasyAnimateController_Modelscope: |
| 1254 |                 video        = input_video,
| 1255 |                 mask_video   = input_video_mask,
| 1256 |                 strength     = strength,
| 1257 | -           ).
| 1258 |         else:
| 1259 |             sample = self.pipeline(
| 1260 |                 prompt_textbox,
| @@ -1265,7 +1292,7 @@ class EasyAnimateController_Modelscope: |
| 1265 |                 height              = height_slider,
| 1266 |                 video_length        = length_slider if not is_image else 1,
| 1267 |                 generator           = generator
| 1268 | -           ).
| 1269 |     else:
| 1270 |         if self.vae.cache_mag_vae:
| 1271 |             length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
| @@ -1285,7 +1312,7 @@ class EasyAnimateController_Modelscope: |
| 1285 |             generator           = generator,
| 1286 |
| 1287 |             control_video = input_video,
| 1288 | -       ).
| 1289 |     except Exception as e:
| 1290 |         gc.collect()
| 1291 |         torch.cuda.empty_cache()
| @@ -1406,13 +1433,28 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, |
| 1406 |             """
| 1407 |         )
| 1408 |
| 1409 | -       prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful 
| 1410 | -
|      |
|      |
|      |
|      |
|      |
|      |
| 1411 |
| 1412 |         with gr.Row():
| 1413 |             with gr.Column():
| 1414 |                 with gr.Row():
| 1415 | -
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
| 1416 |                     sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
| 1417 |
| 1418 |                 if edition == "v1":
| @@ -1466,11 +1508,11 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, |
| 1466 |                     template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
| 1467 |                     def select_template(evt: gr.SelectData):
| 1468 |                         text = {
| 1469 | -                           "asset/1.png": "
| 1470 | -                           "asset/2.png": "
| 1471 | -                           "asset/3.png": "
| 1472 | -                           "asset/4.png": "
| 1473 | -                           "asset/5.png": "
| 1474 |                         }[template_gallery_path[evt.index]]
| 1475 |                         return template_gallery_path[evt.index], text
| 1476 |
| @@ -1510,6 +1552,7 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, |
| 1510 |                         gr.Markdown(
| 1511 |                             """
| 1512 |                             Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
|      |
| 1513 |                             """
| 1514 |                         )
| 1515 |                         control_video = gr.Video(
| @@ -1820,13 +1863,28 @@ def ui_eas(edition, config_path, model_name, savedir_sample): |
| 1820 |             """
| 1821 |         )
| 1822 |
| 1823 | -       prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful 
| 1824 | -
|      |
|      |
|      |
|      |
|      |
|      |
| 1825 |
| 1826 |         with gr.Row():
| 1827 |             with gr.Column():
| 1828 |                 with gr.Row():
| 1829 | -
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
|      |
| 1830 |                     sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False)
| 1831 |
| 1832 |                 if edition == "v1":
| @@ -1875,11 +1933,11 @@ def ui_eas(edition, config_path, model_name, savedir_sample): |
| 1875 |                     template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
| 1876 |                     def select_template(evt: gr.SelectData):
| 1877 |                         text = {
| 1878 | -                           "asset/1.png": "
| 1879 | -                           "asset/2.png": "
| 1880 | -                           "asset/3.png": "
| 1881 | -                           "asset/4.png": "
| 1882 | -                           "asset/5.png": "
| 1883 |                         }[template_gallery_path[evt.index]]
| 1884 |                         return template_gallery_path[evt.index], text
| 1885 |
|  | |
| 17 | 
             
            from diffusers import (AutoencoderKL, DDIMScheduler,
         | 
| 18 | 
             
                                   DPMSolverMultistepScheduler,
         | 
| 19 | 
             
                                   EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
         | 
| 20 | 
            +
                                   FlowMatchEulerDiscreteScheduler, PNDMScheduler)
         | 
| 21 | 
             
            from diffusers.utils.import_utils import is_xformers_available
         | 
| 22 | 
             
            from omegaconf import OmegaConf
         | 
| 23 | 
             
            from PIL import Image
         | 
| 24 | 
             
            from safetensors import safe_open
         | 
| 25 | 
             
            from transformers import (BertModel, BertTokenizer, CLIPImageProcessor,
         | 
| 26 | 
            +
                                      CLIPVisionModelWithProjection, Qwen2Tokenizer,
         | 
| 27 | 
            +
                                      Qwen2VLForConditionalGeneration, T5EncoderModel,
         | 
| 28 | 
            +
                                      T5Tokenizer)
         | 
| 29 |  | 
| 30 | 
            +
            from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
         | 
| 31 | 
            +
            from ..models import (name_to_autoencoder_magvit,
         | 
| 32 | 
             
                                            name_to_transformer3d)
         | 
| 33 | 
            +
            from ..pipeline.pipeline_easyanimate import \
         | 
| 34 | 
            +
                EasyAnimatePipeline
         | 
| 35 | 
            +
            from ..pipeline.pipeline_easyanimate_control import \
         | 
| 36 | 
            +
                EasyAnimateControlPipeline
         | 
| 37 | 
            +
            from ..pipeline.pipeline_easyanimate_inpaint import \
         | 
| 38 | 
             
                EasyAnimateInpaintPipeline
         | 
| 39 | 
            +
            from ..utils.fp8_optimization import convert_weight_dtype_wrapper
         | 
| 40 | 
            +
            from ..utils.lora_utils import merge_lora, unmerge_lora
         | 
| 41 | 
            +
            from ..utils.utils import (
         | 
|  | |
|  | |
|  | |
| 42 | 
             
                get_image_to_video_latent, get_video_to_video_latent,
         | 
| 43 | 
             
                get_width_and_height_from_image_and_base_resolution, save_videos_grid)
         | 
|  | |
| 44 |  | 
| 45 | 
            +
            ddpm_scheduler_dict = {
         | 
| 46 | 
             
                "Euler": EulerDiscreteScheduler,
         | 
| 47 | 
             
                "Euler A": EulerAncestralDiscreteScheduler,
         | 
| 48 | 
             
                "DPM++": DPMSolverMultistepScheduler, 
         | 
| 49 | 
             
                "PNDM": PNDMScheduler,
         | 
| 50 | 
             
                "DDIM": DDIMScheduler,
         | 
| 51 | 
             
            }
         | 
| 52 | 
            +
            flow_scheduler_dict = {
         | 
| 53 | 
            +
                "Flow": FlowMatchEulerDiscreteScheduler,
         | 
| 54 | 
            +
            }
         | 
| 55 | 
            +
            all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
         | 
| 56 |  | 
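Note on the two tables above: the DDPM-style editions keep the discrete schedulers, while v5.1 samples with flow matching and only exposes FlowMatchEulerDiscreteScheduler. A minimal sketch of how a sampler name picked in the UI is applied to an already-loaded pipeline (the `swap_sampler` helper and `pipeline` argument are illustrative, not part of this commit; the `from_config` call mirrors the one used further down in this file):

# Sketch: rebuild the chosen scheduler from the current scheduler's config,
# so sigma/timestep settings stay consistent with the loaded checkpoint.
def swap_sampler(pipeline, sampler_name: str):
    scheduler_cls = {**ddpm_scheduler_dict, **flow_scheduler_dict}[sampler_name]
    pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    return pipeline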
| 57 | 
             
            gradio_version = pkg_resources.get_distribution("gradio").version
         | 
| 58 | 
             
            gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
         | 
|  | |
| 99 | 
             
                    self.GPU_memory_mode       = GPU_memory_mode
         | 
| 100 |  | 
| 101 | 
             
                    self.weight_dtype          = weight_dtype
         | 
| 102 | 
            +
                    self.edition               = "v5.1"
         | 
| 103 | 
            +
                    self.inference_config      = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
         | 
| 104 |  | 
| 105 | 
             
                def refresh_diffusion_transformer(self):
         | 
| 106 | 
             
                    self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
         | 
|  | |
| 122 | 
             
                    if edition == "v1":
         | 
| 123 | 
             
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml"))
         | 
| 124 | 
             
                        return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
         | 
| 125 | 
            +
                            gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
         | 
| 126 | 
             
                            gr.update(value=512, minimum=384, maximum=704, step=32), \
         | 
| 127 | 
             
                            gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
         | 
| 128 | 
             
                    elif edition == "v2":
         | 
| 129 | 
             
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml"))
         | 
| 130 | 
             
                        return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
         | 
| 131 | 
            +
                            gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
         | 
| 132 | 
             
                            gr.update(value=672, minimum=128, maximum=1344, step=16), \
         | 
| 133 | 
             
                            gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
         | 
| 134 | 
             
                    elif edition == "v3":
         | 
| 135 | 
             
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml"))
         | 
| 136 | 
             
                        return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
         | 
| 137 | 
            +
                            gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
         | 
| 138 | 
             
                            gr.update(value=672, minimum=128, maximum=1344, step=16), \
         | 
| 139 | 
             
                            gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         | 
| 140 | 
             
                    elif edition == "v4":
         | 
| 141 | 
             
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml"))
         | 
| 142 | 
             
                        return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
         | 
| 143 | 
            +
                            gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
         | 
| 144 | 
             
                            gr.update(value=672, minimum=128, maximum=1344, step=16), \
         | 
| 145 | 
             
                            gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
         | 
| 146 | 
             
                    elif edition == "v5":
         | 
| 147 | 
             
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml"))
         | 
| 148 | 
             
                        return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
         | 
| 149 | 
            +
                            gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \
         | 
| 150 | 
            +
                            gr.update(value=672, minimum=128, maximum=1344, step=16), \
         | 
| 151 | 
            +
                            gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
         | 
| 152 | 
            +
                    elif edition == "v5.1":
         | 
| 153 | 
            +
                        self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml"))
         | 
| 154 | 
            +
                        return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
         | 
| 155 | 
            +
                            gr.update(choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]), \
         | 
| 156 | 
             
                            gr.update(value=672, minimum=128, maximum=1344, step=16), \
         | 
| 157 | 
             
                            gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4)
         | 
| 158 |  | 
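Each edition in the branches above maps to one YAML under `config/` and re-constrains the sliders (v5 and v5.1 cap the length slider at 49 frames in steps of 4). A compact sketch of the same lookup, assuming the file names used above; the `EDITION_CONFIGS` dict and `load_edition_config` helper are illustrative:

import os
from omegaconf import OmegaConf

# Illustrative edition -> config-file mapping, matching the branches above.
EDITION_CONFIGS = {
    "v1":   "easyanimate_video_v1_motion_module.yaml",
    "v2":   "easyanimate_video_v2_magvit_motion_module.yaml",
    "v3":   "easyanimate_video_v3_slicevae_motion_module.yaml",
    "v4":   "easyanimate_video_v4_slicevae_multi_text_encoder.yaml",
    "v5":   "easyanimate_video_v5_magvit_multi_text_encoder.yaml",
    "v5.1": "easyanimate_video_v5.1_magvit_qwen.yaml",
}

def load_edition_config(config_dir: str, edition: str):
    return OmegaConf.load(os.path.join(config_dir, EDITION_CONFIGS[edition]))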
|  | |
| 182 | 
             
                    self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
         | 
| 183 | 
             
                        diffusion_transformer_dropdown, 
         | 
| 184 | 
             
                        subfolder="transformer", 
         | 
| 185 | 
            +
                        transformer_additional_kwargs=transformer_additional_kwargs,
         | 
| 186 | 
            +
                        torch_dtype=torch.float8_e4m3fn if self.GPU_memory_mode == "model_cpu_offload_and_qfloat8" else self.weight_dtype,
         | 
| 187 | 
            +
                        low_cpu_mem_usage=True,
         | 
| 188 | 
            +
                    )
         | 
| 189 |  | 
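The transformer weights are loaded directly in `float8_e4m3fn` when the qfloat8 memory mode is selected, which shrinks the resident weight footprint before the fp8 wrapper is attached later. A small sketch of just the dtype choice, assuming torch >= 2.1 (the helper name is illustrative; `from_pretrained_2d` itself comes from the repo's model classes):

import torch

def pick_transformer_dtype(gpu_memory_mode: str, weight_dtype: torch.dtype) -> torch.dtype:
    # Store weights as float8 only in the quantized offload mode; compute still
    # runs in weight_dtype via convert_weight_dtype_wrapper later on.
    if gpu_memory_mode == "model_cpu_offload_and_qfloat8":
        return torch.float8_e4m3fn
    return weight_dtype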
| 190 | 
             
                    if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
         | 
| 191 | 
             
                        tokenizer = BertTokenizer.from_pretrained(
         | 
| 192 | 
             
                            diffusion_transformer_dropdown, subfolder="tokenizer"
         | 
| 193 | 
             
                        )
         | 
| 194 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 195 | 
            +
                            tokenizer_2 = Qwen2Tokenizer.from_pretrained(
         | 
| 196 | 
            +
                                os.path.join(diffusion_transformer_dropdown, "tokenizer_2")
         | 
| 197 | 
            +
                            )
         | 
| 198 | 
            +
                        else:
         | 
| 199 | 
            +
                            tokenizer_2 = T5Tokenizer.from_pretrained(
         | 
| 200 | 
            +
                                diffusion_transformer_dropdown, subfolder="tokenizer_2"
         | 
| 201 | 
            +
                            )
         | 
| 202 | 
             
                    else:
         | 
| 203 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 204 | 
            +
                            tokenizer = Qwen2Tokenizer.from_pretrained(
         | 
| 205 | 
            +
                                os.path.join(diffusion_transformer_dropdown, "tokenizer")
         | 
| 206 | 
            +
                            )
         | 
| 207 | 
            +
                        else:
         | 
| 208 | 
            +
                            tokenizer = T5Tokenizer.from_pretrained(
         | 
| 209 | 
            +
                                diffusion_transformer_dropdown, subfolder="tokenizer"
         | 
| 210 | 
            +
                            )
         | 
| 211 | 
             
                        tokenizer_2 = None
         | 
| 212 |  | 
| 213 | 
             
                    if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
         | 
| 214 | 
             
                        text_encoder = BertModel.from_pretrained(
         | 
| 215 | 
             
                            diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
         | 
| 216 | 
             
                        )
         | 
| 217 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 218 | 
            +
                            text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
         | 
| 219 | 
            +
                                os.path.join(diffusion_transformer_dropdown, "text_encoder_2")
         | 
| 220 | 
            +
                            )
         | 
| 221 | 
            +
                        else:
         | 
| 222 | 
            +
                            text_encoder_2 = T5EncoderModel.from_pretrained(
         | 
| 223 | 
            +
                                diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
         | 
| 224 | 
            +
                            )
         | 
| 225 | 
            +
                    else:  
         | 
| 226 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 227 | 
            +
                            text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
         | 
| 228 | 
            +
                                os.path.join(diffusion_transformer_dropdown, "text_encoder")
         | 
| 229 | 
            +
                            )
         | 
| 230 | 
            +
                        else:
         | 
| 231 | 
            +
                            text_encoder = T5EncoderModel.from_pretrained(
         | 
| 232 | 
            +
                                diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
         | 
| 233 | 
            +
                            )
         | 
| 234 | 
             
                        text_encoder_2 = None
         | 
| 235 |  | 
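V5.1 replaces the T5 text encoder with Qwen2-VL, so the four branches above reduce to two config flags (`enable_multi_text_encoder` and `replace_t5_to_llm`). A condensed sketch of the single-encoder case, assuming a Hugging Face-style model directory (the `load_primary_text_encoder` helper is illustrative):

import os
from transformers import (Qwen2Tokenizer, Qwen2VLForConditionalGeneration,
                          T5EncoderModel, T5Tokenizer)

def load_primary_text_encoder(model_dir: str, replace_t5_to_llm: bool, torch_dtype):
    # Qwen2-VL is loaded from an explicit sub-directory, T5 via subfolder=,
    # mirroring the branches in update_diffusion_transformer above.
    if replace_t5_to_llm:
        tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(model_dir, "tokenizer"))
        text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
            os.path.join(model_dir, "text_encoder")
        )
    else:
        tokenizer = T5Tokenizer.from_pretrained(model_dir, subfolder="tokenizer")
        text_encoder = T5EncoderModel.from_pretrained(
            model_dir, subfolder="text_encoder", torch_dtype=torch_dtype
        )
    return tokenizer, text_encoder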
| 236 | 
             
                    # Get pipeline
         | 
|  | |
| 246 | 
             
                        clip_image_processor = None
         | 
| 247 |  | 
| 248 | 
             
                    # Get Scheduler
         | 
| 249 | 
            +
                    if self.edition in ["v5.1"]:
         | 
| 250 | 
            +
                        Choosen_Scheduler = all_cheduler_dict["Flow"]
         | 
| 251 | 
            +
                    else:
         | 
| 252 | 
            +
                        Choosen_Scheduler = all_cheduler_dict["Euler"]
         | 
|  | |
|  | |
|  | |
|  | |
| 253 | 
             
                    scheduler = Choosen_Scheduler.from_pretrained(
         | 
| 254 | 
             
                        diffusion_transformer_dropdown, 
         | 
| 255 | 
             
                        subfolder="scheduler"
         | 
| 256 | 
             
                    )
         | 
| 257 |  | 
| 258 | 
            +
                    if self.model_type == "Inpaint":
         | 
| 259 | 
             
                        if self.transformer.config.in_channels != self.vae.config.latent_channels:
         | 
| 260 | 
            +
                            self.pipeline = EasyAnimateInpaintPipeline(
         | 
|  | |
| 261 | 
             
                                text_encoder=text_encoder,
         | 
| 262 | 
             
                                text_encoder_2=text_encoder_2,
         | 
| 263 | 
             
                                tokenizer=tokenizer,
         | 
|  | |
| 265 | 
             
                                vae=self.vae,
         | 
| 266 | 
             
                                transformer=self.transformer,
         | 
| 267 | 
             
                                scheduler=scheduler,
         | 
|  | |
| 268 | 
             
                                clip_image_encoder=clip_image_encoder,
         | 
| 269 | 
             
                                clip_image_processor=clip_image_processor,
         | 
| 270 | 
            +
                            ).to(self.weight_dtype)
         | 
| 271 | 
             
                        else:
         | 
| 272 | 
            +
                            self.pipeline = EasyAnimatePipeline(
         | 
|  | |
| 273 | 
             
                                text_encoder=text_encoder,
         | 
| 274 | 
             
                                text_encoder_2=text_encoder_2,
         | 
| 275 | 
             
                                tokenizer=tokenizer,
         | 
|  | |
| 277 | 
             
                                vae=self.vae,
         | 
| 278 | 
             
                                transformer=self.transformer,
         | 
| 279 | 
             
                                scheduler=scheduler,
         | 
| 280 | 
            +
                            ).to(self.weight_dtype)
         | 
|  | |
| 281 | 
             
                    else:
         | 
| 282 | 
            +
                        self.pipeline = EasyAnimateControlPipeline(
         | 
| 283 | 
            +
                            text_encoder=text_encoder,
         | 
| 284 | 
            +
                            text_encoder_2=text_encoder_2,
         | 
| 285 | 
            +
                            tokenizer=tokenizer,
         | 
| 286 | 
            +
                            tokenizer_2=tokenizer_2,
         | 
| 287 | 
            +
                            vae=self.vae,
         | 
| 288 | 
            +
                            transformer=self.transformer,
         | 
| 289 | 
            +
                            scheduler=scheduler,
         | 
| 290 | 
            +
                        ).to(self.weight_dtype)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 291 |  | 
| 292 | 
             
                    if self.GPU_memory_mode == "sequential_cpu_offload":
         | 
| 293 | 
             
                        self.pipeline.enable_sequential_cpu_offload()
         | 
| 294 | 
             
                    elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
         | 
| 295 | 
             
                        self.pipeline.enable_model_cpu_offload()
         | 
|  | |
| 296 | 
             
                        convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
         | 
| 297 | 
             
                    else:
         | 
| 298 | 
            +
                        self.pipeline.enable_model_cpu_offload()
         | 
| 299 | 
             
                    print("Update diffusion transformer done")
         | 
| 300 | 
             
                    return gr.update()
         | 
| 301 |  | 
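The three memory modes above trade speed for VRAM: sequential offload moves submodules to CPU after use, model offload moves whole models, and the qfloat8 variant additionally wraps the fp8-stored transformer so weights are cast back per forward pass. A minimal sketch of that branch, assuming a diffusers-style pipeline and the repo's fp8 helper (imported in this file as `from ..utils.fp8_optimization import convert_weight_dtype_wrapper`; the absolute import path below is an assumption about the package layout):

from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper  # assumed package path

def apply_gpu_memory_mode(pipeline, gpu_memory_mode: str, weight_dtype):
    if gpu_memory_mode == "sequential_cpu_offload":
        # Slowest option, smallest footprint: offload layer by layer.
        pipeline.enable_sequential_cpu_offload()
    elif gpu_memory_mode == "model_cpu_offload_and_qfloat8":
        # Offload whole sub-models; the fp8-stored transformer is wrapped so
        # its weights are cast to weight_dtype on the fly at inference time.
        pipeline.enable_model_cpu_offload()
        convert_weight_dtype_wrapper(pipeline.transformer, weight_dtype)
    else:
        pipeline.enable_model_cpu_offload()
    return pipeline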
|  | |
| 386 | 
             
                    if self.base_model_path != base_model_dropdown:
         | 
| 387 | 
             
                        self.update_base_model(base_model_dropdown)
         | 
| 388 |  | 
| 389 | 
            +
                    if self.motion_module_path != motion_module_dropdown:
         | 
| 390 | 
            +
                        self.update_motion_module(motion_module_dropdown)
         | 
| 391 | 
            +
             | 
| 392 | 
             
                    if self.lora_model_path != lora_model_dropdown:
         | 
|  | |
| 393 | 
             
                        self.update_lora_model(lora_model_dropdown)
         | 
| 394 |  | 
| 395 | 
             
                    if control_video is not None and self.model_type == "Inpaint":
         | 
|  | |
| 440 | 
             
                        else:
         | 
| 441 | 
             
                            raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
         | 
| 442 |  | 
| 443 | 
            +
                    fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
         | 
| 444 | 
             
                    is_image = True if generation_method == "Image Generation" else False
         | 
| 445 |  | 
| 446 | 
            +
                    if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
         | 
| 447 | 
            +
                    else: seed_textbox = np.random.randint(0, 1e10)
         | 
| 448 | 
            +
                    generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
         | 
| 449 |  | 
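Seeding happens twice above: the global torch seed and a dedicated CUDA `torch.Generator` passed to the pipeline, so a `-1` seed box becomes a concrete random seed that can be reproduced later. A small sketch of the same logic (the `make_generator` helper is illustrative):

import numpy as np
import torch

def make_generator(seed_textbox: str):
    # "-1" or an empty box means "pick a random seed", but the chosen value is
    # still returned so the run can be reproduced.
    if seed_textbox != "" and int(seed_textbox) != -1:
        seed = int(seed_textbox)
    else:
        seed = int(np.random.randint(0, 1e10))
    torch.manual_seed(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return seed, generator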
| 450 | 
            +
                    if is_xformers_available() \
         | 
| 451 | 
            +
                        and self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') == 'Transformer3DModel':
         | 
| 452 | 
            +
                        self.transformer.enable_xformers_memory_efficient_attention()
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                    self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
         | 
| 455 | 
             
                    if self.lora_model_path != "none":
         | 
| 456 | 
             
                        # lora part
         | 
| 457 | 
             
                        self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
         | 
|  | |
|  | |
|  | |
|  | |
| 458 |  | 
| 459 | 
             
                    try:
         | 
| 460 | 
             
                        if self.model_type == "Inpaint":
         | 
|  | |
| 496 | 
             
                                                video        = input_video,
         | 
| 497 | 
             
                                                mask_video   = input_video_mask,
         | 
| 498 | 
             
                                                strength     = 1,
         | 
| 499 | 
            +
                                            ).frames
         | 
| 500 |  | 
| 501 | 
             
                                        if init_frames != 0:
         | 
| 502 | 
             
                                            mix_ratio = torch.from_numpy(
         | 
|  | |
| 547 | 
             
                                        video        = input_video,
         | 
| 548 | 
             
                                        mask_video   = input_video_mask,
         | 
| 549 | 
             
                                        strength     = strength,
         | 
| 550 | 
            +
                                    ).frames
         | 
| 551 | 
             
                            else:
         | 
| 552 | 
             
                                if self.vae.cache_mag_vae:
         | 
| 553 | 
             
                                    length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
         | 
|  | |
| 563 | 
             
                                    height              = height_slider,
         | 
| 564 | 
             
                                    video_length        = length_slider if not is_image else 1,
         | 
| 565 | 
             
                                    generator           = generator
         | 
| 566 | 
            +
                                ).frames
         | 
| 567 | 
             
                        else:
         | 
| 568 | 
             
                            if self.vae.cache_mag_vae:
         | 
| 569 | 
             
                                length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
         | 
|  | |
| 582 | 
             
                                generator           = generator,
         | 
| 583 |  | 
| 584 | 
             
                                control_video = input_video,
         | 
| 585 | 
            +
                            ).frames
         | 
| 586 | 
             
                    except Exception as e:
         | 
| 587 | 
             
                        gc.collect()
         | 
| 588 | 
             
                        torch.cuda.empty_cache()
         | 
|  | |
| 692 | 
             
                        with gr.Row():
         | 
| 693 | 
             
                            easyanimate_edition_dropdown = gr.Dropdown(
         | 
| 694 | 
             
                                label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
         | 
| 695 | 
            +
                                choices=["v1", "v2", "v3", "v4", "v5", "v5.1"],
         | 
| 696 | 
            +
                                value="v5.1",
         | 
| 697 | 
             
                                interactive=True,
         | 
| 698 | 
             
                            )
         | 
| 699 | 
             
                        gr.Markdown(
         | 
|  | |
| 767 | 
             
                            """
         | 
| 768 | 
             
                        )
         | 
| 769 |  | 
| 770 | 
            +
                        prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
         | 
| 771 | 
            +
                        gr.Markdown(
         | 
| 772 | 
            +
                            """
         | 
| 773 | 
            +
                            Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism.   
         | 
| 774 | 
            +
                            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
         | 
| 775 | 
            +
                            """
         | 
| 776 | 
            +
                        )
         | 
| 777 | 
            +
                        negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
         | 
| 778 |  | 
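The guidance above feeds straight into generation: the two textboxes are handed to the pipeline as the positive and negative prompts. A hedged sketch of such a call, assuming an already-built pipeline and generator; keyword names follow the diffusers convention and the calls later in this file, so check the actual `EasyAnimatePipeline` signature before relying on them:

# Illustrative call; `pipeline` and `generator` are assumed to exist already.
video = pipeline(
    "A young woman with beautiful, clear eyes and blonde hair stands in the forest.",
    negative_prompt="Blurring, mutation, deformation, distortion, comics, text subtitles.",
    num_inference_steps=50,
    width=672,
    height=384,
    video_length=49,
    generator=generator,
).frames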
| 779 | 
             
                        with gr.Row():
         | 
| 780 | 
             
                            with gr.Column():
         | 
| 781 | 
             
                                with gr.Row():
         | 
| 782 | 
            +
                                    sampler_dropdown   = gr.Dropdown(
         | 
| 783 | 
            +
                                        label="Sampling method (采样器种类)", 
         | 
| 784 | 
            +
                                        choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
         | 
| 785 | 
            +
                                    )
         | 
| 786 | 
             
                                    sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1)
         | 
| 787 |  | 
| 788 | 
             
                                resize_method = gr.Radio(
         | 
|  | |
| 819 | 
             
                                    template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
         | 
| 820 | 
             
                                    def select_template(evt: gr.SelectData):
         | 
| 821 | 
             
                                        text = {
         | 
| 822 | 
            +
                                            "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", 
         | 
| 823 | 
            +
                                            "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", 
         | 
| 824 | 
            +
                                            "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", 
         | 
| 825 | 
            +
                                            "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", 
         | 
| 826 | 
            +
                                            "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", 
         | 
| 827 | 
             
                                        }[template_gallery_path[evt.index]]
         | 
| 828 | 
             
                                        return template_gallery_path[evt.index], text
         | 
| 829 |  | 
|  | |
| 863 | 
             
                                    gr.Markdown(
         | 
| 864 | 
             
                                        """
         | 
| 865 | 
             
                                        Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
         | 
| 866 | 
            +
                                        Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
         | 
| 867 | 
             
                                        """
         | 
| 868 | 
             
                                    )
         | 
| 869 | 
             
                                    control_video = gr.Video(
         | 
|  | |
| 953 | 
             
                                diffusion_transformer_dropdown, 
         | 
| 954 | 
             
                                motion_module_dropdown, 
         | 
| 955 | 
             
                                motion_module_refresh_button, 
         | 
| 956 | 
            +
                                sampler_dropdown, 
         | 
| 957 | 
             
                                width_slider, 
         | 
| 958 | 
             
                                height_slider, 
         | 
| 959 | 
             
                                length_slider, 
         | 
|  | |
| 1030 | 
             
                    self.transformer = Choosen_Transformer3DModel.from_pretrained_2d(
         | 
| 1031 | 
             
                        model_name, 
         | 
| 1032 | 
             
                        subfolder="transformer", 
         | 
| 1033 | 
            +
                        transformer_additional_kwargs=transformer_additional_kwargs,
         | 
| 1034 | 
            +
                        torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
         | 
| 1035 | 
            +
                        low_cpu_mem_usage=True,
         | 
| 1036 | 
            +
                    )
         | 
| 1037 |  | 
| 1038 | 
             
                    if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
         | 
| 1039 | 
             
                        tokenizer = BertTokenizer.from_pretrained(
         | 
| 1040 | 
             
                            model_name, subfolder="tokenizer"
         | 
| 1041 | 
             
                        )
         | 
| 1042 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 1043 | 
            +
                            tokenizer_2 = Qwen2Tokenizer.from_pretrained(
         | 
| 1044 | 
            +
                                os.path.join(model_name, "tokenizer_2")
         | 
| 1045 | 
            +
                            )
         | 
| 1046 | 
            +
                        else:
         | 
| 1047 | 
            +
                            tokenizer_2 = T5Tokenizer.from_pretrained(
         | 
| 1048 | 
            +
                                model_name, subfolder="tokenizer_2"
         | 
| 1049 | 
            +
                            )
         | 
| 1050 | 
             
                    else:
         | 
| 1051 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 1052 | 
            +
                            tokenizer = Qwen2Tokenizer.from_pretrained(
         | 
| 1053 | 
            +
                                os.path.join(model_name, "tokenizer")
         | 
| 1054 | 
            +
                            )
         | 
| 1055 | 
            +
                        else:
         | 
| 1056 | 
            +
                            tokenizer = T5Tokenizer.from_pretrained(
         | 
| 1057 | 
            +
                                model_name, subfolder="tokenizer"
         | 
| 1058 | 
            +
                            )
         | 
| 1059 | 
             
                        tokenizer_2 = None
         | 
| 1060 |  | 
| 1061 | 
             
                    if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False):
         | 
| 1062 | 
             
                        text_encoder = BertModel.from_pretrained(
         | 
| 1063 | 
             
                            model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
         | 
| 1064 | 
             
                        )
         | 
| 1065 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 1066 | 
            +
                            text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained(
         | 
| 1067 | 
            +
                                os.path.join(model_name, "text_encoder_2"), torch_dtype=self.weight_dtype
         | 
| 1068 | 
            +
                            )
         | 
| 1069 | 
            +
                        else:
         | 
| 1070 | 
            +
                            text_encoder_2 = T5EncoderModel.from_pretrained(
         | 
| 1071 | 
            +
                                model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype
         | 
| 1072 | 
            +
                            )
         | 
| 1073 | 
            +
                    else:  
         | 
| 1074 | 
            +
                        if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False):
         | 
| 1075 | 
            +
                            text_encoder = Qwen2VLForConditionalGeneration.from_pretrained(
         | 
| 1076 | 
            +
                                os.path.join(model_name, "text_encoder"), torch_dtype=self.weight_dtype
         | 
| 1077 | 
            +
                            )
         | 
| 1078 | 
            +
                        else:
         | 
| 1079 | 
            +
                            text_encoder = T5EncoderModel.from_pretrained(
         | 
| 1080 | 
            +
                                model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
         | 
| 1081 | 
            +
                            )
         | 
| 1082 | 
             
                        text_encoder_2 = None
         | 
| 1083 |  | 
| 1084 | 
             
                    # Get pipeline
         | 
|  | |
| 1094 | 
             
                        clip_image_processor = None
         | 
| 1095 |  | 
| 1096 | 
             
                    # Get Scheduler
         | 
| 1097 | 
            +
                    if self.edition in ["v5.1"]:
         | 
| 1098 | 
            +
                        Choosen_Scheduler = all_cheduler_dict["Flow"]
         | 
| 1099 | 
            +
                    else:
         | 
| 1100 | 
            +
                        Choosen_Scheduler = all_cheduler_dict["Euler"]
         | 
| 1101 | 
             
                    scheduler = Choosen_Scheduler.from_pretrained(
         | 
| 1102 | 
             
                        model_name, 
         | 
| 1103 | 
             
                        subfolder="scheduler"
         | 
| 1104 | 
             
                    )
         | 
| 1105 |  | 
| 1106 | 
            +
                    if model_type == "Inpaint":
         | 
| 1107 | 
             
                        if self.transformer.config.in_channels != self.vae.config.latent_channels:
         | 
| 1108 | 
            +
                            self.pipeline = EasyAnimateInpaintPipeline(
         | 
|  | |
| 1109 | 
             
                                text_encoder=text_encoder,
         | 
| 1110 | 
             
                                text_encoder_2=text_encoder_2,
         | 
| 1111 | 
             
                                tokenizer=tokenizer,
         | 
|  | |
| 1113 | 
             
                                vae=self.vae,
         | 
| 1114 | 
             
                                transformer=self.transformer,
         | 
| 1115 | 
             
                                scheduler=scheduler,
         | 
|  | |
| 1116 | 
             
                                clip_image_encoder=clip_image_encoder,
         | 
| 1117 | 
             
                                clip_image_processor=clip_image_processor,
         | 
| 1118 | 
            +
                            ).to(weight_dtype)
         | 
| 1119 | 
             
                        else:
         | 
| 1120 | 
            +
                            self.pipeline = EasyAnimatePipeline(
         | 
|  | |
| 1121 | 
             
                                text_encoder=text_encoder,
         | 
| 1122 | 
             
                                text_encoder_2=text_encoder_2,
         | 
| 1123 | 
             
                                tokenizer=tokenizer,
         | 
| 1124 | 
             
                                tokenizer_2=tokenizer_2,
         | 
| 1125 | 
             
                                vae=self.vae,
         | 
| 1126 | 
             
                                transformer=self.transformer,
         | 
| 1127 | 
            +
                                scheduler=scheduler
         | 
| 1128 | 
            +
                            ).to(weight_dtype)
         | 
|  | |
| 1129 | 
             
                    else:
         | 
| 1130 | 
            +
                        self.pipeline = EasyAnimateControlPipeline(
         | 
| 1131 | 
            +
                            text_encoder=text_encoder,
         | 
| 1132 | 
            +
                            text_encoder_2=text_encoder_2,
         | 
| 1133 | 
            +
                            tokenizer=tokenizer,
         | 
| 1134 | 
            +
                            tokenizer_2=tokenizer_2,
         | 
| 1135 | 
            +
                            vae=self.vae,
         | 
| 1136 | 
            +
                            transformer=self.transformer,
         | 
| 1137 | 
            +
                            scheduler=scheduler,
         | 
| 1138 | 
            +
                        ).to(weight_dtype)
         | 
| 1139 |  | 
| 1140 | 
             
                    if GPU_memory_mode == "sequential_cpu_offload":
         | 
| 1141 | 
             
                        self.pipeline.enable_sequential_cpu_offload()
         | 
| 1142 | 
             
                    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
         | 
| 1143 | 
             
                        self.pipeline.enable_model_cpu_offload()
         | 
|  | |
| 1144 | 
             
                        convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype)
         | 
| 1145 | 
             
                    else:
         | 
| 1146 | 
             
                        self.pipeline.enable_model_cpu_offload()
         | 
|  | |
| 1241 | 
             
                        else:
         | 
| 1242 | 
             
                            raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
         | 
| 1243 |  | 
| 1244 | 
            +
                    fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition]
         | 
| 1245 | 
             
                    is_image = True if generation_method == "Image Generation" else False
         | 
| 1246 |  | 
| 1247 | 
             
                    if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
         | 
| 1248 | 
             
                    else: seed_textbox = np.random.randint(0, 1e10)
         | 
| 1249 | 
             
                    generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
         | 
| 1250 | 
            +
             | 
| 1251 | 
            +
                    self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
         | 
| 1252 | 
            +
                    if self.lora_model_path != "none":
         | 
| 1253 | 
            +
                        # lora part
         | 
| 1254 | 
            +
                        self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
         | 
| 1255 |  | 
| 1256 | 
             
                    try:
         | 
| 1257 | 
             
                        if self.model_type == "Inpaint":
         | 
|  | |
| 1281 | 
             
                                    video        = input_video,
         | 
| 1282 | 
             
                                    mask_video   = input_video_mask,
         | 
| 1283 | 
             
                                    strength     = strength,
         | 
| 1284 | 
            +
                                ).frames
         | 
| 1285 | 
             
                            else:
         | 
| 1286 | 
             
                                sample = self.pipeline(
         | 
| 1287 | 
             
                                    prompt_textbox,
         | 
|  | |
| 1292 | 
             
                                    height              = height_slider,
         | 
| 1293 | 
             
                                    video_length        = length_slider if not is_image else 1,
         | 
| 1294 | 
             
                                    generator           = generator
         | 
| 1295 | 
            +
                                ).frames
         | 
| 1296 | 
             
                        else:
         | 
| 1297 | 
             
                            if self.vae.cache_mag_vae:
         | 
| 1298 | 
             
                                length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1
         | 
|  | |
| 1312 | 
             
                                generator           = generator,
         | 
| 1313 |  | 
| 1314 | 
             
                                control_video = input_video,
         | 
| 1315 | 
            +
                            ).frames
         | 
| 1316 | 
             
                    except Exception as e:
         | 
| 1317 | 
             
                        gc.collect()
         | 
| 1318 | 
             
                        torch.cuda.empty_cache()
         | 
|  | |
| 1433 | 
             
                            """
         | 
| 1434 | 
             
                        )
         | 
| 1435 |  | 
| 1436 | 
            +
                        prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
         | 
| 1437 | 
            +
                        gr.Markdown(
         | 
| 1438 | 
            +
                            """
         | 
| 1439 | 
            +
                            Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and static, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, static" to the negative prompt can increase dynamism.
         | 
| 1440 | 
            +
                            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
         | 
| 1441 | 
            +
                            """
         | 
| 1442 | 
            +
                        )
         | 
| 1443 | 
            +
                        negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
         | 
| 1444 |  | 
| 1445 | 
             
                        with gr.Row():
         | 
| 1446 | 
             
                            with gr.Column():
         | 
| 1447 | 
             
                                with gr.Row():
         | 
| 1448 | 
            +
                                    if edition in ["v5.1"]:
         | 
| 1449 | 
            +
                                        sampler_dropdown   = gr.Dropdown(
         | 
| 1450 | 
            +
                                            label="Sampling method (采样器种类)", 
         | 
| 1451 | 
            +
                                            choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
         | 
| 1452 | 
            +
                                        )
         | 
| 1453 | 
            +
                                    else:
         | 
| 1454 | 
            +
                                        sampler_dropdown   = gr.Dropdown(
         | 
| 1455 | 
            +
                                            label="Sampling method (采样器种类)", 
         | 
| 1456 | 
            +
                                            choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
         | 
| 1457 | 
            +
                                        )
         | 
| 1458 | 
             
                                    sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False)
         | 
| 1459 |  | 
| 1460 | 
             
                                if edition == "v1":
         | 
|  | |
| 1508 | 
             
                                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
         | 
| 1509 | 
             
                                        def select_template(evt: gr.SelectData):
         | 
| 1510 | 
             
                                            text = {
         | 
| 1511 | 
            +
                                                "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", 
         | 
| 1512 | 
            +
                                                "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", 
         | 
| 1513 | 
            +
                                                "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", 
         | 
| 1514 | 
            +
                                                "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", 
         | 
| 1515 | 
            +
                                                "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", 
         | 
| 1516 | 
             
                                            }[template_gallery_path[evt.index]]
         | 
| 1517 | 
             
                                            return template_gallery_path[evt.index], text
         | 
| 1518 |  | 
|  | |
| 1552 | 
             
                                        gr.Markdown(
         | 
| 1553 | 
             
                                            """
         | 
| 1554 | 
             
                                            Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
         | 
| 1555 | 
            +
                                            Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui.
         | 
| 1556 | 
             
                                            """
         | 
| 1557 | 
             
                                        )
         | 
| 1558 | 
             
                                        control_video = gr.Video(
         | 
|  | |
| 1863 | 
             
                            """
         | 
| 1864 | 
             
                        )
         | 
| 1865 |  | 
| 1866 | 
            +
                        prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.")
         | 
| 1867 | 
            +
                        gr.Markdown(
         | 
| 1868 | 
            +
                            """
         | 
| 1869 | 
            +
                            Using a longer negative prompt such as "Blurring, mutation, deformation, distortion, dark and static, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, static" to the negative prompt can increase dynamism.
         | 
| 1870 | 
            +
                            使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。
         | 
| 1871 | 
            +
                            """
         | 
| 1872 | 
            +
                        )
         | 
| 1873 | 
            +
                        negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." )
         | 
| 1874 |  | 
| 1875 | 
             
                        with gr.Row():
         | 
| 1876 | 
             
                            with gr.Column():
         | 
| 1877 | 
             
                                with gr.Row():
         | 
| 1878 | 
            +
                                    if edition in ["v5.1"]:
         | 
| 1879 | 
            +
                                        sampler_dropdown   = gr.Dropdown(
         | 
| 1880 | 
            +
                                            label="Sampling method (采样器种类)", 
         | 
| 1881 | 
            +
                                            choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]
         | 
| 1882 | 
            +
                                        )
         | 
| 1883 | 
            +
                                    else:
         | 
| 1884 | 
            +
                                        sampler_dropdown   = gr.Dropdown(
         | 
| 1885 | 
            +
                                            label="Sampling method (采样器种类)", 
         | 
| 1886 | 
            +
                                            choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]
         | 
| 1887 | 
            +
                                        )
         | 
| 1888 | 
             
                                    sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False)
         | 
| 1889 |  | 
| 1890 | 
             
                                if edition == "v1":
         | 
|  | |
| 1933 | 
             
                                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
         | 
| 1934 | 
             
                                        def select_template(evt: gr.SelectData):
         | 
| 1935 | 
             
                                            text = {
         | 
| 1936 | 
            +
                                                "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", 
         | 
| 1937 | 
            +
                                                "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", 
         | 
| 1938 | 
            +
                                                "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", 
         | 
| 1939 | 
            +
                                                "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", 
         | 
| 1940 | 
            +
                                                "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", 
         | 
| 1941 | 
             
                                            }[template_gallery_path[evt.index]]
         | 
| 1942 | 
             
                                            return template_gallery_path[evt.index], text
         | 
| 1943 |  | 
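
Note on the edition checks in the hunks above: both the default sampler dropdown and the fps lookup branch on the model edition, with v5.1 checkpoints (flow-matching models) getting the Flow samplers and 8 fps, while earlier editions keep the DDPM-family samplers. The following is a minimal sketch of that selection logic only; the dictionary contents are illustrative stand-ins for flow_scheduler_dict / ddpm_scheduler_dict defined elsewhere in ui.py, not the repository's actual code.

# Illustrative stand-ins for the scheduler dictionaries in ui.py.
from diffusers import DDIMScheduler, EulerDiscreteScheduler, FlowMatchEulerDiscreteScheduler

flow_scheduler_dict = {"Flow": FlowMatchEulerDiscreteScheduler}
ddpm_scheduler_dict = {"Euler": EulerDiscreteScheduler, "DDIM": DDIMScheduler}

def default_sampler_and_fps(edition: str):
    # v5.1 models are flow-matching, earlier editions use DDPM-family samplers.
    scheduler_dict = flow_scheduler_dict if edition in ["v5.1"] else ddpm_scheduler_dict
    fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[edition]
    return list(scheduler_dict.keys())[0], fps

print(default_sampler_and_fps("v5.1"))  # -> ('Flow', 8)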
    	
        easyanimate/utils/lora_utils.py
    CHANGED
    
    | @@ -369,7 +369,6 @@ def create_network( | |
| 369 | 
             
            def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
         | 
| 370 | 
             
                LORA_PREFIX_TRANSFORMER = "lora_unet"
         | 
| 371 | 
             
                LORA_PREFIX_TEXT_ENCODER = "lora_te"
         | 
| 372 | 
            -
                SPECIAL_LAYER_NAME = ["text_proj_t5"]
         | 
| 373 | 
             
                if state_dict is None:
         | 
| 374 | 
             
                    state_dict = load_file(lora_path, device=device)
         | 
| 375 | 
             
                else:
         | 
| @@ -410,20 +409,25 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3 | |
| 410 | 
             
                                else:
         | 
| 411 | 
             
                                    temp_name = layer_infos.pop(0)
         | 
| 412 |  | 
| 413 | 
            -
                     | 
| 414 | 
            -
                     | 
| 415 | 
             
                    if 'alpha' in elems.keys():
         | 
| 416 | 
             
                        alpha = elems['alpha'].item() / weight_up.shape[1]
         | 
| 417 | 
             
                    else:
         | 
| 418 | 
             
                        alpha = 1.0
         | 
| 419 |  | 
| 420 | 
            -
                    curr_layer.weight.data = curr_layer.weight.data.to(device)
         | 
| 421 | 
             
                    if len(weight_up.shape) == 4:
         | 
| 422 | 
            -
                        curr_layer.weight.data += multiplier * alpha * torch.mm( | 
| 423 | 
            -
             | 
| 424 | 
            -
             | 
| 425 | 
             
                    else:
         | 
| 426 | 
             
                        curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
         | 
|  | |
| 427 |  | 
| 428 | 
             
                return pipeline
         | 
| 429 |  | 
| @@ -448,35 +452,43 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl | |
| 448 | 
             
                        layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
         | 
| 449 | 
             
                        curr_layer = pipeline.transformer
         | 
| 450 |  | 
| 451 | 
            -
                     | 
| 452 | 
            -
             | 
| 453 | 
            -
                     | 
| 454 | 
            -
                         | 
| 455 | 
            -
             | 
| 456 | 
            -
                             | 
| 457 | 
            -
                                 | 
| 458 | 
            -
             | 
| 459 | 
            -
             | 
| 460 | 
            -
             | 
| 461 | 
            -
             | 
| 462 | 
            -
             | 
| 463 | 
            -
             | 
| 464 | 
            -
             | 
| 465 | 
            -
             | 
| 466 | 
            -
             | 
| 467 | 
            -
             | 
| 468 | 
            -
             | 
| 469 | 
            -
             | 
| 470 | 
             
                    if 'alpha' in elems.keys():
         | 
| 471 | 
             
                        alpha = elems['alpha'].item() / weight_up.shape[1]
         | 
| 472 | 
             
                    else:
         | 
| 473 | 
             
                        alpha = 1.0
         | 
| 474 |  | 
| 475 | 
            -
                    curr_layer.weight.data = curr_layer.weight.data.to(device)
         | 
| 476 | 
             
                    if len(weight_up.shape) == 4:
         | 
| 477 | 
            -
                        curr_layer.weight.data -= multiplier * alpha * torch.mm( | 
| 478 | 
            -
             | 
|  | |
| 479 | 
             
                    else:
         | 
| 480 | 
             
                        curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
         | 
|  | |
| 481 |  | 
| 482 | 
            -
                return pipeline
         | 
|  | |
| 369 | 
             
            def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
         | 
| 370 | 
             
                LORA_PREFIX_TRANSFORMER = "lora_unet"
         | 
| 371 | 
             
                LORA_PREFIX_TEXT_ENCODER = "lora_te"
         | 
|  | |
| 372 | 
             
                if state_dict is None:
         | 
| 373 | 
             
                    state_dict = load_file(lora_path, device=device)
         | 
| 374 | 
             
                else:
         | 
|  | |
| 409 | 
             
                                else:
         | 
| 410 | 
             
                                    temp_name = layer_infos.pop(0)
         | 
| 411 |  | 
| 412 | 
            +
                    origin_dtype = curr_layer.weight.data.dtype
         | 
| 413 | 
            +
                    origin_device = curr_layer.weight.data.device
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    curr_layer = curr_layer.to(device, dtype)
         | 
| 416 | 
            +
                    weight_up = elems['lora_up.weight'].to(device, dtype)
         | 
| 417 | 
            +
                    weight_down = elems['lora_down.weight'].to(device, dtype)
         | 
| 418 | 
            +
                    
         | 
| 419 | 
             
                    if 'alpha' in elems.keys():
         | 
| 420 | 
             
                        alpha = elems['alpha'].item() / weight_up.shape[1]
         | 
| 421 | 
             
                    else:
         | 
| 422 | 
             
                        alpha = 1.0
         | 
| 423 |  | 
|  | |
| 424 | 
             
                    if len(weight_up.shape) == 4:
         | 
| 425 | 
            +
                        curr_layer.weight.data += multiplier * alpha * torch.mm(
         | 
| 426 | 
            +
                            weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
         | 
| 427 | 
            +
                        ).unsqueeze(2).unsqueeze(3)
         | 
| 428 | 
             
                    else:
         | 
| 429 | 
             
                        curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
         | 
| 430 | 
            +
                    curr_layer = curr_layer.to(origin_device, origin_dtype)
         | 
| 431 |  | 
| 432 | 
             
                return pipeline
         | 
| 433 |  | 
|  | |
| 452 | 
             
                        layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
         | 
| 453 | 
             
                        curr_layer = pipeline.transformer
         | 
| 454 |  | 
| 455 | 
            +
                    try:
         | 
| 456 | 
            +
                        curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
         | 
| 457 | 
            +
                    except Exception:
         | 
| 458 | 
            +
                        temp_name = layer_infos.pop(0)
         | 
| 459 | 
            +
                        while len(layer_infos) > -1:
         | 
| 460 | 
            +
                            try:
         | 
| 461 | 
            +
                                curr_layer = curr_layer.__getattr__(temp_name)
         | 
| 462 | 
            +
                                if len(layer_infos) > 0:
         | 
| 463 | 
            +
                                    temp_name = layer_infos.pop(0)
         | 
| 464 | 
            +
                                elif len(layer_infos) == 0:
         | 
| 465 | 
            +
                                    break
         | 
| 466 | 
            +
                            except Exception:
         | 
| 467 | 
            +
                                if len(layer_infos) == 0:
         | 
| 468 | 
            +
                                    print('Error loading layer')
         | 
| 469 | 
            +
                                if len(temp_name) > 0:
         | 
| 470 | 
            +
                                    temp_name += "_" + layer_infos.pop(0)
         | 
| 471 | 
            +
                                else:
         | 
| 472 | 
            +
                                    temp_name = layer_infos.pop(0)
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    origin_dtype = curr_layer.weight.data.dtype
         | 
| 475 | 
            +
                    origin_device = curr_layer.weight.data.device
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    curr_layer = curr_layer.to(device, dtype)
         | 
| 478 | 
            +
                    weight_up = elems['lora_up.weight'].to(device, dtype)
         | 
| 479 | 
            +
                    weight_down = elems['lora_down.weight'].to(device, dtype)
         | 
| 480 | 
            +
                    
         | 
| 481 | 
             
                    if 'alpha' in elems.keys():
         | 
| 482 | 
             
                        alpha = elems['alpha'].item() / weight_up.shape[1]
         | 
| 483 | 
             
                    else:
         | 
| 484 | 
             
                        alpha = 1.0
         | 
| 485 |  | 
|  | |
| 486 | 
             
                    if len(weight_up.shape) == 4:
         | 
| 487 | 
            +
                        curr_layer.weight.data -= multiplier * alpha * torch.mm(
         | 
| 488 | 
            +
                            weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
         | 
| 489 | 
            +
                        ).unsqueeze(2).unsqueeze(3)
         | 
| 490 | 
             
                    else:
         | 
| 491 | 
             
                        curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
         | 
| 492 | 
            +
                    curr_layer = curr_layer.to(origin_device, origin_dtype)
         | 
| 493 |  | 
| 494 | 
            +
                return pipeline
         | 
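
The updated merge_lora/unmerge_lora record the original dtype and device of each target layer, do the LoRA matmul on a working device/dtype, and then move the layer back, which lets merging work with fp8 or CPU-offloaded weights. The sketch below illustrates only the per-layer update W' = W + multiplier * alpha * (lora_up @ lora_down) under those assumptions; the function name and arguments are generic, not the repository's API.

import torch
import torch.nn as nn

def merge_lora_into_linear(layer, lora_up, lora_down, multiplier=1.0, alpha=None,
                           device="cpu", dtype=torch.float32):
    # Remember the original storage so the layer can be restored afterwards.
    origin_dtype, origin_device = layer.weight.dtype, layer.weight.device
    layer.to(device=device, dtype=dtype)
    up, down = lora_up.to(device, dtype), lora_down.to(device, dtype)
    # alpha is divided by the LoRA rank, mirroring elems['alpha'] / weight_up.shape[1].
    scale = multiplier * (alpha / up.shape[1] if alpha is not None else 1.0)
    layer.weight.data += scale * torch.mm(up, down)
    layer.to(device=origin_device, dtype=origin_dtype)

layer = nn.Linear(8, 8, bias=False)
rank = 4
merge_lora_into_linear(layer, torch.randn(8, rank), torch.randn(rank, 8),
                       multiplier=0.8, alpha=4.0)

Unmerging subtracts the same term with identical bookkeeping, as in the second hunk above.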
    	
        easyanimate/utils/utils.py
    CHANGED
    
    | @@ -169,47 +169,67 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide | |
| 169 | 
             
                return  input_video, input_video_mask, clip_image
         | 
| 170 |  | 
| 171 | 
             
            def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
         | 
| 172 | 
            -
                if  | 
| 173 | 
            -
                     | 
| 174 | 
            -
             | 
|  | |
| 175 |  | 
| 176 | 
            -
             | 
| 177 | 
            -
             | 
| 178 |  | 
| 179 | 
            -
             | 
| 180 |  | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
             | 
| 184 | 
            -
             | 
| 185 |  | 
| 186 | 
            -
             | 
| 187 | 
            -
             | 
| 188 | 
            -
             | 
| 189 |  | 
| 190 | 
            -
             | 
| 191 |  | 
| 192 | 
            -
             | 
| 193 | 
            -
             | 
| 194 | 
            -
             | 
|  | |
|  | |
|  | |
| 195 |  | 
| 196 | 
            -
             | 
| 197 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 198 |  | 
| 199 | 
             
                if ref_image is not None:
         | 
| 200 | 
            -
                     | 
| 201 | 
            -
             | 
| 202 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 203 |  | 
| 204 | 
            -
             | 
| 205 | 
            -
             | 
| 206 | 
            -
                     | 
| 207 | 
            -
             | 
| 208 | 
            -
             | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
| 213 | 
            -
             | 
| 214 |  | 
| 215 | 
            -
                return | 
|  | |
| 169 | 
             
                return  input_video, input_video_mask, clip_image
         | 
| 170 |  | 
| 171 | 
             
            def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None):
         | 
| 172 | 
            +
                if input_video_path is not None:
         | 
| 173 | 
            +
                    if isinstance(input_video_path, str):
         | 
| 174 | 
            +
                        cap = cv2.VideoCapture(input_video_path)
         | 
| 175 | 
            +
                        input_video = []
         | 
| 176 |  | 
| 177 | 
            +
                        original_fps = cap.get(cv2.CAP_PROP_FPS)
         | 
| 178 | 
            +
                        frame_skip = 1 if fps is None else int(original_fps // fps)
         | 
| 179 |  | 
| 180 | 
            +
                        frame_count = 0
         | 
| 181 |  | 
| 182 | 
            +
                        while True:
         | 
| 183 | 
            +
                            ret, frame = cap.read()
         | 
| 184 | 
            +
                            if not ret:
         | 
| 185 | 
            +
                                break
         | 
| 186 |  | 
| 187 | 
            +
                            if frame_count % frame_skip == 0:
         | 
| 188 | 
            +
                                frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
         | 
| 189 | 
            +
                                input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
         | 
| 190 |  | 
| 191 | 
            +
                            frame_count += 1
         | 
| 192 |  | 
| 193 | 
            +
                        cap.release()
         | 
| 194 | 
            +
                    else:
         | 
| 195 | 
            +
                        input_video = input_video_path
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    input_video = torch.from_numpy(np.array(input_video))[:video_length]
         | 
| 198 | 
            +
                    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
         | 
| 199 |  | 
| 200 | 
            +
                    if validation_video_mask is not None:
         | 
| 201 | 
            +
                        validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
         | 
| 202 | 
            +
                        input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
         | 
| 203 | 
            +
                        
         | 
| 204 | 
            +
                        input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
         | 
| 205 | 
            +
                        input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
         | 
| 206 | 
            +
                        input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        input_video_mask = torch.zeros_like(input_video[:, :1])
         | 
| 209 | 
            +
                        input_video_mask[:, :, :] = 255
         | 
| 210 | 
            +
                else:
         | 
| 211 | 
            +
                    input_video, input_video_mask = None, None
         | 
| 212 |  | 
| 213 | 
             
                if ref_image is not None:
         | 
| 214 | 
            +
                    if isinstance(ref_image, str):
         | 
| 215 | 
            +
                        ref_image = Image.open(ref_image).convert("RGB")
         | 
| 216 | 
            +
                        ref_image = ref_image.resize((sample_size[1], sample_size[0]))
         | 
| 217 | 
            +
                        ref_image = torch.from_numpy(np.array(ref_image))
         | 
| 218 | 
            +
                        ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
         | 
| 219 | 
            +
                    else:
         | 
| 220 | 
            +
                        ref_image = torch.from_numpy(np.array(ref_image))
         | 
| 221 | 
            +
                        ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
         | 
| 222 | 
            +
                return input_video, input_video_mask, ref_image
         | 
| 223 |  | 
| 224 | 
            +
            def get_image_latent(ref_image=None, sample_size=None):
         | 
| 225 | 
            +
                if ref_image is not None:
         | 
| 226 | 
            +
                    if isinstance(ref_image, str):
         | 
| 227 | 
            +
                        ref_image = Image.open(ref_image).convert("RGB")
         | 
| 228 | 
            +
                        ref_image = ref_image.resize((sample_size[1], sample_size[0]))
         | 
| 229 | 
            +
                        ref_image = torch.from_numpy(np.array(ref_image))
         | 
| 230 | 
            +
                        ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
         | 
| 231 | 
            +
                    else:
         | 
| 232 | 
            +
                        ref_image = torch.from_numpy(np.array(ref_image))
         | 
| 233 | 
            +
                        ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255
         | 
| 234 |  | 
| 235 | 
            +
                return ref_image
         | 
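
get_video_to_video_latent and get_image_latent above both normalise their inputs into the same 5-D layout the pipelines expect. The following is a small self-contained sketch of that reshaping only, with random frames standing in for the cv2/PIL loading done in the real helpers.

import numpy as np
import torch

def frames_to_video_tensor(frames):
    # frames: list of (H, W, 3) uint8 arrays.
    video = torch.from_numpy(np.stack(frames))       # (T, H, W, C), uint8
    video = video.permute(3, 0, 1, 2).unsqueeze(0)   # (1, C, T, H, W)
    return video / 255                               # float tensor in [0, 1]

frames = [np.random.randint(0, 256, (384, 672, 3), dtype=np.uint8) for _ in range(8)]
print(frames_to_video_tensor(frames).shape)          # torch.Size([1, 3, 8, 384, 672])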
    	
        easyanimate/vae/ldm/models/autoencoder.py
    CHANGED
    
    | @@ -126,13 +126,13 @@ class AutoencoderKLMagvit(pl.LightningModule): | |
| 126 |  | 
| 127 | 
             
                def configure_optimizers(self):
         | 
| 128 | 
             
                    lr = self.learning_rate
         | 
| 129 | 
            -
                    opt_ae = torch.optim. | 
| 130 | 
             
                                              list(self.decoder.parameters())+
         | 
| 131 | 
             
                                              list(self.quant_conv.parameters())+
         | 
| 132 | 
             
                                              list(self.post_quant_conv.parameters()),
         | 
| 133 | 
            -
                                              lr=lr, betas=(0. | 
| 134 | 
            -
                    opt_disc = torch.optim. | 
| 135 | 
            -
                                                lr=lr, betas=(0. | 
| 136 | 
             
                    return [opt_ae, opt_disc], []
         | 
| 137 |  | 
| 138 | 
             
                def get_last_layer(self):
         | 
|  | |
| 126 |  | 
| 127 | 
             
                def configure_optimizers(self):
         | 
| 128 | 
             
                    lr = self.learning_rate
         | 
| 129 | 
            +
                    opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
         | 
| 130 | 
             
                                              list(self.decoder.parameters())+
         | 
| 131 | 
             
                                              list(self.quant_conv.parameters())+
         | 
| 132 | 
             
                                              list(self.post_quant_conv.parameters()),
         | 
| 133 | 
            +
                                              lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         | 
| 134 | 
            +
                    opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
         | 
| 135 | 
            +
                                                lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         | 
| 136 | 
             
                    return [opt_ae, opt_disc], []
         | 
| 137 |  | 
| 138 | 
             
                def get_last_layer(self):
         | 
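
The VAE modules in this and the following files all switch their optimizers from Adam to AdamW with betas=(0.9, 0.999) and weight_decay=5e-2, applied to both the autoencoder parameters and the discriminator. Below is a minimal sketch of that two-optimizer setup; the tiny linear modules are placeholders for the real encoder/decoder/discriminator, not the repository's classes.

import torch
import torch.nn as nn

encoder, decoder = nn.Linear(8, 4), nn.Linear(4, 8)
discriminator = nn.Linear(8, 1)
lr = 1e-4

opt_ae = torch.optim.AdamW(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=lr, betas=(0.9, 0.999), weight_decay=5e-2,
)
opt_disc = torch.optim.AdamW(
    discriminator.parameters(),
    lr=lr, betas=(0.9, 0.999), weight_decay=5e-2,
)
# A LightningModule's configure_optimizers would then return [opt_ae, opt_disc], []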
    	
        easyanimate/vae/ldm/models/casual3dcnn.py
    CHANGED
    
    | @@ -279,13 +279,13 @@ class AutoencoderKL(pl.LightningModule): | |
| 279 |  | 
| 280 | 
             
                def configure_optimizers(self):
         | 
| 281 | 
             
                    lr = self.learning_rate
         | 
| 282 | 
            -
                    opt_ae = torch.optim. | 
| 283 | 
             
                                              list(self.decoder.parameters())+
         | 
| 284 | 
             
                                              list(self.quant_conv.parameters())+
         | 
| 285 | 
            -
                                              list(self.post_quant_conv.parameters()),
         | 
| 286 | 
            -
                                              lr=lr, betas=(0. | 
| 287 | 
            -
                    opt_disc = torch.optim. | 
| 288 | 
            -
                                                lr=lr, betas=(0. | 
| 289 | 
             
                    return [opt_ae, opt_disc], []
         | 
| 290 |  | 
| 291 | 
             
                def get_last_layer(self):
         | 
|  | |
| 279 |  | 
| 280 | 
             
                def configure_optimizers(self):
         | 
| 281 | 
             
                    lr = self.learning_rate
         | 
| 282 | 
            +
                    opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+
         | 
| 283 | 
             
                                              list(self.decoder.parameters())+
         | 
| 284 | 
             
                                              list(self.quant_conv.parameters())+
         | 
| 285 | 
            +
                                              list(self.post_quant_conv.parameters()), \
         | 
| 286 | 
            +
                                              lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         | 
| 287 | 
            +
                    opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(),
         | 
| 288 | 
            +
                                                lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         | 
| 289 | 
             
                    return [opt_ae, opt_disc], []
         | 
| 290 |  | 
| 291 | 
             
                def get_last_layer(self):
         | 
    	
easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py
CHANGED

@@ -277,23 +277,23 @@ class AutoencoderKLMagvit_CogVideoX(pl.LightningModule):
                 training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
             else:
                 training_list = list(self.decoder.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         elif self.train_encoder_only:
             if self.quant_conv is not None:
                 training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
             else:
                 training_list = list(self.encoder.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         else:
             training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
             if self.quant_conv is not None:
                 training_list = training_list + list(self.quant_conv.parameters())
             if self.post_quant_conv is not None:
                 training_list = training_list + list(self.post_quant_conv.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
-        opt_disc = torch.optim.
+        opt_disc = torch.optim.AdamW(
             list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
-            lr=lr, betas=(0.
+            lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
         )
         return [opt_ae, opt_disc], []
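Both VAE optimizers now use AdamW with betas=(0.9, 0.999) and decoupled weight decay in place of the previous setup. A minimal sketch of the resulting two-optimizer return value, assuming `autoencoder_params` and `discriminator_params` stand in for the parameter lists collected above (the helper name and the dummy modules are illustrative, not part of the repo):

import torch

def build_vae_optimizers(autoencoder_params, discriminator_params, lr=1e-4):
    # AdamW applies weight decay decoupled from the gradient update.
    opt_ae = torch.optim.AdamW(autoencoder_params, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
    opt_disc = torch.optim.AdamW(discriminator_params, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
    # PyTorch Lightning's configure_optimizers convention: ([optimizers], [schedulers]).
    return [opt_ae, opt_disc], []

dummy_ae = torch.nn.Linear(4, 4)
dummy_disc = torch.nn.Linear(4, 1)
optimizers, schedulers = build_vae_optimizers(dummy_ae.parameters(), dummy_disc.parameters())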
    	
easyanimate/vae/ldm/models/omnigen_casual3dcnn.py
CHANGED

@@ -95,6 +95,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
         out_channels: int = 3,
         ch =  128,
         ch_mult = [ 1,2,4,4 ],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         down_block_types: tuple = None,
         up_block_types: tuple = None,
@@ -129,8 +130,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             in_channels=in_channels,
             out_channels=latent_channels,
             down_block_types=down_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -144,6 +146,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             slice_mag_vae=slice_mag_vae,
             slice_compression_vae=slice_compression_vae,
             cache_compression_vae=cache_compression_vae,
+            cache_mag_vae=cache_mag_vae,
             spatial_group_norm=spatial_group_norm,
             mini_batch_encoder=mini_batch_encoder,
         )
@@ -152,8 +155,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
             in_channels=latent_channels,
             out_channels=out_channels,
             up_block_types=up_block_types,
-            ch
-            ch_mult
+            ch=ch,
+            ch_mult=ch_mult,
+            block_out_channels=block_out_channels,
             use_gc_blocks=use_gc_blocks,
             mid_block_type=mid_block_type,
             mid_block_use_attention=mid_block_use_attention,
@@ -292,23 +296,23 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule):
                 training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters())
             else:
                 training_list = list(self.decoder.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         elif self.train_encoder_only:
             if self.quant_conv is not None:
                 training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters())
             else:
                 training_list = list(self.encoder.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
         else:
             training_list = list(self.encoder.parameters()) + list(self.decoder.parameters())
             if self.quant_conv is not None:
                 training_list = training_list + list(self.quant_conv.parameters())
             if self.post_quant_conv is not None:
                 training_list = training_list + list(self.post_quant_conv.parameters())
-            opt_ae = torch.optim.
+            opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2)
-        opt_disc = torch.optim.
+        opt_disc = torch.optim.AdamW(
             list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()),
-            lr=lr, betas=(0.
+            lr=lr, betas=(0.9, 0.999), weight_decay=5e-2
         )
         return [opt_ae, opt_disc], []
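The new `block_out_channels` argument spells out the per-stage channel widths that were previously derived from `ch` and `ch_mult`. A quick illustration of the equivalence, assuming the usual multiplicative convention (the helper below is illustrative, not repo code):

def derive_block_out_channels(ch, ch_mult):
    # Width of stage i = base channels * multiplier for that stage.
    return [ch * m for m in ch_mult]

# Matches the new default block_out_channels = [128, 256, 512, 512].
assert derive_block_out_channels(128, [1, 2, 4, 4]) == [128, 256, 512, 512]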
    	
easyanimate/vae/ldm/models/omnigen_enc_dec.py
CHANGED

@@ -58,6 +58,7 @@ class Encoder(nn.Module):
         down_block_types = ("SpatialDownBlock3D",),
         ch = 128,
         ch_mult = [1,2,4,4,],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         mid_block_type: str = "MidBlock3D",
         mid_block_use_attention: bool = True,
@@ -77,7 +78,8 @@ class Encoder(nn.Module):
         verbose = False,
     ):
         super().__init__()
-        block_out_channels
+        if block_out_channels is None:
+            block_out_channels = [ch * i for i in ch_mult]
         assert len(down_block_types) == len(block_out_channels), (
             "Number of down block types must match number of block output channels."
         )
@@ -364,6 +366,7 @@ class Decoder(nn.Module):
         up_block_types  = ("SpatialUpBlock3D",),
         ch = 128,
         ch_mult = [1,2,4,4,],
+        block_out_channels = [128, 256, 512, 512],
         use_gc_blocks = None,
         mid_block_type: str = "MidBlock3D",
         mid_block_use_attention: bool = True,
@@ -382,7 +385,8 @@ class Decoder(nn.Module):
         verbose = False,
     ):
         super().__init__()
-        block_out_channels
+        if block_out_channels is None:
+            block_out_channels = [ch * i for i in ch_mult]
         assert len(up_block_types) == len(block_out_channels), (
             "Number of up block types must match number of block output channels."
         )
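Inside the Encoder and Decoder the explicit list takes precedence; the `ch * ch_mult` product is only a fallback for callers that pass `block_out_channels=None`. A small standalone sketch of that guard (the function name is illustrative):

def resolve_block_out_channels(block_out_channels, ch=128, ch_mult=(1, 2, 4, 4)):
    # Mirror of the fallback added to Encoder/Decoder: derive the widths only
    # when the caller did not provide them explicitly.
    if block_out_channels is None:
        block_out_channels = [ch * i for i in ch_mult]
    return block_out_channels

print(resolve_block_out_channels(None))                 # [128, 256, 512, 512]
print(resolve_block_out_channels([64, 128, 256, 256]))  # explicit widths win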
    	
easyanimate/vae/ldm/modules/losses/contperceptual.py
CHANGED

@@ -9,7 +9,8 @@ from ..vaemodules.discriminator import Discriminator3D
 class LPIPSWithDiscriminator(nn.Module):
     def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                  disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                  perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 outlier_penalty_loss_r=3.0, outlier_penalty_loss_weight=1e5,
                  disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0):

         super().__init__()
@@ -34,6 +35,8 @@ class LPIPSWithDiscriminator(nn.Module):
         self.disc_factor = disc_factor
         self.discriminator_weight = disc_weight
         self.disc_conditional = disc_conditional
+        self.outlier_penalty_loss_r = outlier_penalty_loss_r
+        self.outlier_penalty_loss_weight = outlier_penalty_loss_weight
         self.l1_loss_weight = l1_loss_weight
         self.l2_loss_weight = l2_loss_weight
@@ -50,6 +53,18 @@ class LPIPSWithDiscriminator(nn.Module):
         d_weight = d_weight * self.discriminator_weight
         return d_weight

+    def outlier_penalty_loss(self, posteriors, r):
+        batch_size, channels, frames, height, width = posteriors.shape
+        mean_X = posteriors.mean(dim=(3, 4), keepdim=True)
+        std_X = posteriors.std(dim=(3, 4), keepdim=True)
+
+        diff = torch.abs(posteriors - mean_X)
+        penalty = torch.maximum(diff - r * std_X, torch.zeros_like(diff))
+
+        opl = penalty.sum(dim=(3, 4)) / (height * width)
+        opl_final = opl.mean(dim=(0, 1, 2))
+        return opl_final
+
     def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                 global_step, last_layer=None, cond=None, split="train",
                 weights=None):
@@ -86,6 +101,8 @@ class LPIPSWithDiscriminator(nn.Module):
         kl_loss = posteriors.kl()
         kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

+        outlier_penalty_loss = self.outlier_penalty_loss(posteriors.mode(), self.outlier_penalty_loss_r) * self.outlier_penalty_loss_weight
+
         # now the GAN part
         if optimizer_idx == 0:
             # generator update
@@ -102,13 +119,13 @@ class LPIPSWithDiscriminator(nn.Module):
                 try:
                     d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                 except RuntimeError:
-                    assert not self.training
+                    # assert not self.training
                     d_weight = torch.tensor(0.0)
             else:
                 d_weight = torch.tensor(0.0)

             disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
-            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
+            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + outlier_penalty_loss

             log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
                    "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
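The new outlier penalty term only penalizes latent values that stray more than r spatial standard deviations from the per-sample, per-channel, per-frame mean; in `forward` it is computed on `posteriors.mode()`, scaled by `outlier_penalty_loss_weight` (1e5 by default), and added to the generator loss. A standalone sketch of the same computation on a dummy latent (the shape convention and the default r=3.0 follow the diff above; the function and variable names here are illustrative):

import torch

def outlier_penalty_loss(latents: torch.Tensor, r: float = 3.0) -> torch.Tensor:
    # latents: (B, C, T, H, W), e.g. the mode of the VAE posterior.
    _, _, _, height, width = latents.shape
    mean = latents.mean(dim=(3, 4), keepdim=True)  # spatial mean per (b, c, t)
    std = latents.std(dim=(3, 4), keepdim=True)    # spatial std per (b, c, t)
    # Only the portion of |x - mean| beyond r standard deviations is penalized.
    penalty = torch.clamp(torch.abs(latents - mean) - r * std, min=0.0)
    # Normalize over space, then average over batch, channels and frames.
    return (penalty.sum(dim=(3, 4)) / (height * width)).mean()

dummy_latents = torch.randn(1, 16, 4, 32, 32)
print(outlier_penalty_loss(dummy_latents))  # close to 0 for well-behaved latents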
    	
        easyanimate/vae/ldm/modules/vaemodules/__init__.py
    CHANGED
    
File without changes
    	
        easyanimate/vae/ldm/modules/vaemodules/activations.py
    CHANGED
    
File without changes
    	
easyanimate/vae/ldm/modules/vaemodules/common.py
CHANGED

@@ -8,6 +8,17 @@ from einops import rearrange, repeat
 from .activations import get_activation


+try:
+    current_version = torch.__version__
+    version_numbers = [int(x) for x in current_version.split('.')[:2]]
+    if version_numbers[0] < 2 or (version_numbers[0] == 2 and version_numbers[1] < 2):
+        need_to_float = True
+    else:
+        need_to_float = False
+except Exception as e:
+    print("Encountered an error with Torch version. Set the data type to float in the VAE. ")
+    need_to_float = False
+
 def cast_tuple(t, length = 1):
     return t if isinstance(t, tuple) else ((t,) * length)
@@ -66,10 +77,15 @@ class CausalConv3d(nn.Conv3d):
             **kwargs,
         )

+    def _clear_conv_cache(self):
+        del self.prev_features
+        self.prev_features = None
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # x: (B, C, T, H, W)
         dtype = x.dtype
-
+        if need_to_float:
+            x = x.float()
         if self.padding_flag == 0:
             x = F.pad(
                 x,
@@ -85,7 +101,11 @@ class CausalConv3d(nn.Conv3d):
                 mode="replicate",     # TODO: check if this is necessary
             )
             x = x.to(dtype=dtype)

+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()

             b, c, f, h, w = x.size()
             outputs = []
@@ -105,7 +125,11 @@ class CausalConv3d(nn.Conv3d):
                 [self.prev_features, x], dim = 2
             )
             x = x.to(dtype=dtype)

+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()

             b, c, f, h, w = x.size()
             outputs = []
@@ -122,7 +146,12 @@ class CausalConv3d(nn.Conv3d):
                 mode="replicate",     # TODO: check if this is necessary
             )
             x = x.to(dtype=dtype)

+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()
+
             return super().forward(x)
         elif self.padding_flag == 6:
             if self.t_stride == 2:
@@ -133,7 +162,12 @@ class CausalConv3d(nn.Conv3d):
             x = torch.concat(
                 [self.prev_features, x], dim = 2
             )

+            # Clear cache before
+            self._clear_conv_cache()
+            # We could move these to the cpu for a lower VRAM
+            self.prev_features = x[:, :, -self.temporal_padding:].clone()
+
             x = x.to(dtype=dtype)
             return super().forward(x)
         else:
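The module-level `need_to_float` flag makes `CausalConv3d.forward` upcast its input to float32 on PyTorch builds older than 2.2 and restore the original dtype afterwards, while `_clear_conv_cache` releases the cached `prev_features` before the next chunk is stored so the streaming cache does not grow. A standalone sketch of the version gate only (error handling simplified; the 2.2 cutoff and the False fallback follow the diff; the function name is illustrative):

import torch

def needs_float_upcast(version: str = torch.__version__) -> bool:
    # True for torch < 2.2: the VAE then casts conv inputs to float32.
    try:
        major, minor = (int(x) for x in version.split(".")[:2])
        return (major, minor) < (2, 2)
    except (ValueError, IndexError):
        # Unparseable version string: keep the original dtype, matching the
        # diff's fallback of need_to_float = False.
        return False

print(needs_float_upcast("2.1.2+cu118"))  # True
print(needs_float_upcast("2.4.0"))        # False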
    	
        easyanimate/vae/ldm/modules/vaemodules/down_blocks.py
    CHANGED
    
File without changes
    	
        easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py
    CHANGED
    
File without changes
    	
        easyanimate/vae/ldm/modules/vaemodules/up_blocks.py
    CHANGED
    
File without changes
    	
requirements.txt
CHANGED

@@ -6,7 +6,6 @@ tomesd
 torch>=2.1.2
 torchdiffeq
 torchsde
-xformers
 decord
 datasets
 numpy
@@ -21,8 +20,6 @@ tensorboard
 beautifulsoup4
 ftfy
 func_timeout
-deepspeed
 accelerate>=0.25.0
-
-
-transformers>=4.37.2
+diffusers==0.30.1
+transformers==4.46.2