Upload 35 files

- configs/ltxv-13b-0.9.7-dev.yaml +34 -0
- configs/ltxv-2b-0.9.1.yaml +17 -0
- configs/ltxv-2b-0.9.5.yaml +17 -0
- configs/ltxv-2b-0.9.6-dev.yaml +17 -0
- configs/ltxv-2b-0.9.6-distilled.yaml +16 -0
- configs/ltxv-2b-0.9.yaml +17 -0
- inference.py +778 -0
- ltx_video/__init__.py +0 -0
- ltx_video/models/__init__.py +0 -0
- ltx_video/models/autoencoders/__init__.py +0 -0
- ltx_video/models/autoencoders/causal_conv3d.py +63 -0
- ltx_video/models/autoencoders/causal_video_autoencoder.py +1403 -0
- ltx_video/models/autoencoders/conv_nd_factory.py +90 -0
- ltx_video/models/autoencoders/dual_conv3d.py +217 -0
- ltx_video/models/autoencoders/latent_upsampler.py +203 -0
- ltx_video/models/autoencoders/pixel_norm.py +12 -0
- ltx_video/models/autoencoders/pixel_shuffle.py +33 -0
- ltx_video/models/autoencoders/vae.py +380 -0
- ltx_video/models/autoencoders/vae_encode.py +247 -0
- ltx_video/models/autoencoders/video_autoencoder.py +1045 -0
- ltx_video/models/transformers/__init__.py +0 -0
- ltx_video/models/transformers/attention.py +1265 -0
- ltx_video/models/transformers/embeddings.py +129 -0
- ltx_video/models/transformers/symmetric_patchifier.py +84 -0
- ltx_video/models/transformers/transformer3d.py +507 -0
- ltx_video/pipelines/__init__.py +0 -0
- ltx_video/pipelines/crf_compressor.py +50 -0
- ltx_video/pipelines/pipeline_ltx_video.py +1845 -0
- ltx_video/schedulers/__init__.py +0 -0
- ltx_video/schedulers/rf.py +386 -0
- ltx_video/utils/__init__.py +0 -0
- ltx_video/utils/diffusers_config_mapping.py +174 -0
- ltx_video/utils/prompt_enhance_utils.py +226 -0
- ltx_video/utils/skip_layer_strategy.py +8 -0
- ltx_video/utils/torch_utils.py +25 -0
    	
configs/ltxv-13b-0.9.7-dev.yaml  ADDED

@@ -0,0 +1,34 @@
+pipeline_type: multi-scale
+checkpoint_path: "ltxv-13b-0.9.7-dev.safetensors"
+downscale_factor: 0.6666666
+spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
+
+first_pass:
+  guidance_scale: [1, 1, 6, 8, 6, 1, 1]
+  stg_scale: [0, 0, 4, 4, 4, 2, 1]
+  rescaling_scale: [1, 1, 0.5, 0.5, 1, 1, 1]
+  guidance_timesteps: [1.0, 0.996, 0.9933, 0.9850, 0.9767, 0.9008, 0.6180]
+  skip_block_list: [[], [11, 25, 35, 39], [22, 35, 39], [28], [28], [28], [28]]
+  num_inference_steps: 30
+  skip_final_inference_steps: 3
+  cfg_star_rescale: true
+
+second_pass:
+  guidance_scale: [1]
+  stg_scale: [1]
+  rescaling_scale: [1]
+  guidance_timesteps: [1.0]
+  skip_block_list: [27]
+  num_inference_steps: 30
+  skip_initial_inference_steps: 17
+  cfg_star_rescale: true
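This config drives the multi-scale pipeline: judging by `downscale_factor` and `spatial_upscaler_model_path`, the guided first pass runs at reduced resolution and the second pass refines after spatial upsampling of the latents. As a quick orientation, here is a minimal sketch (not part of the commit) that loads the file with `yaml.safe_load`, the same way `inference.py` further down does, and prints the per-timestep guidance schedule:

```python
import yaml

# Illustrative sketch (not part of the commit): read the multi-scale config
# the same way inference.py does, via yaml.safe_load.
with open("configs/ltxv-13b-0.9.7-dev.yaml", "r") as f:
    cfg = yaml.safe_load(f)

print(cfg["pipeline_type"], cfg["checkpoint_path"], cfg["downscale_factor"])

# The first-pass guidance values are lists paired with guidance_timesteps
# (presumably each value applies from the listed timestep until the next entry).
fp = cfg["first_pass"]
for t, g, s, blocks in zip(
    fp["guidance_timesteps"],
    fp["guidance_scale"],
    fp["stg_scale"],
    fp["skip_block_list"],
):
    print(f"t >= {t}: guidance={g}, stg={s}, skipped_blocks={blocks}")
```

The second pass uses a single guidance value and skips the first 17 of its 30 steps (`skip_initial_inference_steps`), complementing the first pass, which skips its final 3.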
    	
configs/ltxv-2b-0.9.1.yaml  ADDED

@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.1.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
    	
configs/ltxv-2b-0.9.5.yaml  ADDED

@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.5.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
    	
configs/ltxv-2b-0.9.6-dev.yaml  ADDED

@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltxv-2b-0.9.6-dev-04-25.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
    	
configs/ltxv-2b-0.9.6-distilled.yaml  ADDED

@@ -0,0 +1,16 @@
+pipeline_type: base
+checkpoint_path: "ltxv-2b-0.9.6-distilled-04-25.safetensors"
+guidance_scale: 1
+stg_scale: 0
+rescaling_scale: 1
+num_inference_steps: 8
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: true
    	
configs/ltxv-2b-0.9.yaml  ADDED

@@ -0,0 +1,17 @@
+pipeline_type: base
+checkpoint_path: "ltx-video-2b-v0.9.safetensors"
+guidance_scale: 3
+stg_scale: 1
+rescaling_scale: 0.7
+skip_block_list: [19]
+num_inference_steps: 40
+stg_mode: "attention_values" # options: "attention_values", "attention_skip", "residual", "transformer_block"
+decode_timestep: 0.05
+decode_noise_scale: 0.025
+text_encoder_model_name_or_path: "PixArt-alpha/PixArt-XL-2-1024-MS"
+precision: "bfloat16"
+sampler: "from_checkpoint" # options: "uniform", "linear-quadratic", "from_checkpoint"
+prompt_enhancement_words_threshold: 120
+prompt_enhancer_image_caption_model_name_or_path: "MiaoshouAI/Florence-2-large-PromptGen-v2.0"
+prompt_enhancer_llm_model_name_or_path: "unsloth/Llama-3.2-3B-Instruct"
+stochastic_sampling: false
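The five 2B configs above share the same structure and differ mainly in the checkpoint they target; the distilled variant additionally changes the sampling settings (8 steps, guidance_scale: 1, stg_scale: 0, stochastic_sampling: true, and no skip_block_list) where the others use 40 guided steps. A small sketch (not part of the commit) that summarizes those differences by loading everything in configs/:

```python
from pathlib import Path

import yaml

# Illustrative sketch: tabulate the key fields that differ across the configs
# added in this commit. Top-level fields missing from the multi-scale config
# (which defines them per pass) fall back to "per-pass".
for path in sorted(Path("configs").glob("*.yaml")):
    cfg = yaml.safe_load(path.read_text())
    print(
        f"{path.name:35s} type={cfg['pipeline_type']:11s} "
        f"steps={cfg.get('num_inference_steps', 'per-pass')!s:9s} "
        f"cfg={cfg.get('guidance_scale', 'per-pass')!s:9s} "
        f"stochastic={cfg['stochastic_sampling']}"
    )
```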
    	
inference.py  ADDED

@@ -0,0 +1,778 @@
+import argparse
+import os
+import random
+from datetime import datetime
+from pathlib import Path
+from diffusers.utils import logging
+from typing import Optional, List, Union
+import yaml
+
+import imageio
+import json
+import numpy as np
+import torch
+from safetensors import safe_open
+from PIL import Image
+from transformers import (
+    T5EncoderModel,
+    T5Tokenizer,
+    AutoModelForCausalLM,
+    AutoProcessor,
+    AutoTokenizer,
+)
+from huggingface_hub import hf_hub_download
+
+from ltx_video.models.autoencoders.causal_video_autoencoder import (
+    CausalVideoAutoencoder,
+)
+from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier
+from ltx_video.models.transformers.transformer3d import Transformer3DModel
+from ltx_video.pipelines.pipeline_ltx_video import (
+    ConditioningItem,
+    LTXVideoPipeline,
+    LTXMultiScalePipeline,
+)
+from ltx_video.schedulers.rf import RectifiedFlowScheduler
+from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
+from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
+
+MAX_HEIGHT = 720
+MAX_WIDTH = 1280
+MAX_NUM_FRAMES = 257
+
+logger = logging.get_logger("LTX-Video")
+
+
+def get_total_gpu_memory():
+    if torch.cuda.is_available():
+        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+        return total_memory
+    return 0
+
+
+def get_device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif torch.backends.mps.is_available():
+        return "mps"
+    return "cpu"
+
+
+def load_image_to_tensor_with_resize_and_crop(
+    image_input: Union[str, Image.Image],
+    target_height: int = 512,
+    target_width: int = 768,
+    just_crop: bool = False,
+) -> torch.Tensor:
+    """Load and process an image into a tensor.
+
+    Args:
+        image_input: Either a file path (str) or a PIL Image object
+        target_height: Desired height of output tensor
+        target_width: Desired width of output tensor
+        just_crop: If True, only crop the image to the target size without resizing
+    """
+    if isinstance(image_input, str):
+        image = Image.open(image_input).convert("RGB")
+    elif isinstance(image_input, Image.Image):
+        image = image_input
+    else:
+        raise ValueError("image_input must be either a file path or a PIL Image object")
+
+    input_width, input_height = image.size
+    aspect_ratio_target = target_width / target_height
+    aspect_ratio_frame = input_width / input_height
+    if aspect_ratio_frame > aspect_ratio_target:
+        new_width = int(input_height * aspect_ratio_target)
+        new_height = input_height
+        x_start = (input_width - new_width) // 2
+        y_start = 0
+    else:
+        new_width = input_width
+        new_height = int(input_width / aspect_ratio_target)
+        x_start = 0
+        y_start = (input_height - new_height) // 2
+
+    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
+    if not just_crop:
+        image = image.resize((target_width, target_height))
+    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
+    frame_tensor = (frame_tensor / 127.5) - 1.0
+    # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width)
+    return frame_tensor.unsqueeze(0).unsqueeze(2)
+
+
+def calculate_padding(
+    source_height: int, source_width: int, target_height: int, target_width: int
+) -> tuple[int, int, int, int]:
+
+    # Calculate total padding needed
+    pad_height = target_height - source_height
+    pad_width = target_width - source_width
+
+    # Calculate padding for each side
+    pad_top = pad_height // 2
+    pad_bottom = pad_height - pad_top  # Handles odd padding
+    pad_left = pad_width // 2
+    pad_right = pad_width - pad_left  # Handles odd padding
+
+    # Return padded tensor
+    # Padding format is (left, right, top, bottom)
+    padding = (pad_left, pad_right, pad_top, pad_bottom)
+    return padding
+
+
+def convert_prompt_to_filename(text: str, max_len: int = 20) -> str:
+    # Remove non-letters and convert to lowercase
+    clean_text = "".join(
+        char.lower() for char in text if char.isalpha() or char.isspace()
+    )
+
+    # Split into words
+    words = clean_text.split()
+
+    # Build result string keeping track of length
+    result = []
+    current_length = 0
+
+    for word in words:
+        # Add word length plus 1 for underscore (except for first word)
+        new_length = current_length + len(word)
+
+        if new_length <= max_len:
+            result.append(word)
+            current_length += len(word)
+        else:
+            break
+
+    return "-".join(result)
+
+
+# Generate output video name
+def get_unique_filename(
+    base: str,
+    ext: str,
+    prompt: str,
+    seed: int,
+    resolution: tuple[int, int, int],
+    dir: Path,
+    endswith=None,
+    index_range=1000,
+) -> Path:
+    base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}"
+    for i in range(index_range):
+        filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}"
+        if not os.path.exists(filename):
+            return filename
+    raise FileExistsError(
+        f"Could not find a unique filename after {index_range} attempts."
+    )
+
+
+def seed_everething(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    if torch.backends.mps.is_available():
+        torch.mps.manual_seed(seed)
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Load models from separate directories and run the pipeline."
+    )
+
+    # Directories
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default=None,
+        help="Path to the folder to save output video, if None will save in outputs/ directory.",
+    )
+    parser.add_argument("--seed", type=int, default="171198")
+
+    # Pipeline parameters
+    parser.add_argument(
+        "--num_images_per_prompt",
+        type=int,
+        default=1,
+        help="Number of images per prompt",
+    )
+    parser.add_argument(
+        "--image_cond_noise_scale",
+        type=float,
+        default=0.15,
+        help="Amount of noise to add to the conditioned image",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=704,
+        help="Height of the output video frames. Optional if an input image provided.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=1216,
+        help="Width of the output video frames. If None will infer from input image.",
+    )
+    parser.add_argument(
+        "--num_frames",
+        type=int,
+        default=121,
+        help="Number of frames to generate in the output video",
+    )
+    parser.add_argument(
+        "--frame_rate", type=int, default=30, help="Frame rate for the output video"
+    )
+    parser.add_argument(
+        "--device",
+        default=None,
+        help="Device to run inference on. If not specified, will automatically detect and use CUDA or MPS if available, else CPU.",
+    )
+    parser.add_argument(
+        "--pipeline_config",
+        type=str,
+        default="configs/ltxv-13b-0.9.7-dev.yaml",
+        help="The path to the config file for the pipeline, which contains the parameters for the pipeline",
+    )
+
+    # Prompts
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        help="Text prompt to guide generation",
+    )
+    parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default="worst quality, inconsistent motion, blurry, jittery, distorted",
+        help="Negative prompt for undesired features",
+    )
+
+    parser.add_argument(
+        "--offload_to_cpu",
+        action="store_true",
+        help="Offloading unnecessary computations to CPU.",
+    )
+
+    # video-to-video arguments:
+    parser.add_argument(
+        "--input_media_path",
+        type=str,
+        default=None,
+        help="Path to the input video (or image) to be modified using the video-to-video pipeline",
+    )
+
+    parser.add_argument(
+        "--strength",
+        type=float,
+        default=1.0,
+        help="Editing strength (noising level) for video-to-video pipeline.",
+    )
+
+    # Conditioning arguments
+    parser.add_argument(
+        "--conditioning_media_paths",
+        type=str,
+        nargs="*",
+        help="List of paths to conditioning media (images or videos). Each path will be used as a conditioning item.",
+    )
+    parser.add_argument(
+        "--conditioning_strengths",
+        type=float,
+        nargs="*",
+        help="List of conditioning strengths (between 0 and 1) for each conditioning item. Must match the number of conditioning items.",
+    )
+    parser.add_argument(
+        "--conditioning_start_frames",
+        type=int,
+        nargs="*",
+        help="List of frame indices where each conditioning item should be applied. Must match the number of conditioning items.",
+    )
+
+    args = parser.parse_args()
+    logger.warning(f"Running generation with arguments: {args}")
+    infer(**vars(args))
+
+
+def create_ltx_video_pipeline(
+    ckpt_path: str,
+    precision: str,
+    text_encoder_model_name_or_path: str,
+    sampler: Optional[str] = None,
+    device: Optional[str] = None,
+    enhance_prompt: bool = False,
+    prompt_enhancer_image_caption_model_name_or_path: Optional[str] = None,
+    prompt_enhancer_llm_model_name_or_path: Optional[str] = None,
+) -> LTXVideoPipeline:
+    ckpt_path = Path(ckpt_path)
+    assert os.path.exists(
+        ckpt_path
+    ), f"Ckpt path provided (--ckpt_path) {ckpt_path} does not exist"
+
+    with safe_open(ckpt_path, framework="pt") as f:
+        metadata = f.metadata()
+        config_str = metadata.get("config")
+        configs = json.loads(config_str)
+        allowed_inference_steps = configs.get("allowed_inference_steps", None)
+
+    vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
+    transformer = Transformer3DModel.from_pretrained(ckpt_path)
+
+    # Use constructor if sampler is specified, otherwise use from_pretrained
+    if sampler == "from_checkpoint" or not sampler:
+        scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)
+    else:
+        scheduler = RectifiedFlowScheduler(
+            sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
+        )
+
+    text_encoder = T5EncoderModel.from_pretrained(
+        text_encoder_model_name_or_path, subfolder="text_encoder"
+    )
+    patchifier = SymmetricPatchifier(patch_size=1)
+    tokenizer = T5Tokenizer.from_pretrained(
+        text_encoder_model_name_or_path, subfolder="tokenizer"
+    )
+
+    transformer = transformer.to(device)
+    vae = vae.to(device)
+    text_encoder = text_encoder.to(device)
+
+    if enhance_prompt:
+        prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(
+            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+        )
+        prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(
+            prompt_enhancer_image_caption_model_name_or_path, trust_remote_code=True
+        )
+        prompt_enhancer_llm_model = AutoModelForCausalLM.from_pretrained(
+            prompt_enhancer_llm_model_name_or_path,
+            torch_dtype="bfloat16",
+        )
+        prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(
+            prompt_enhancer_llm_model_name_or_path,
+        )
+    else:
+        prompt_enhancer_image_caption_model = None
+        prompt_enhancer_image_caption_processor = None
+        prompt_enhancer_llm_model = None
+        prompt_enhancer_llm_tokenizer = None
+
+    vae = vae.to(torch.bfloat16)
+    if precision == "bfloat16" and transformer.dtype != torch.bfloat16:
+        transformer = transformer.to(torch.bfloat16)
+    text_encoder = text_encoder.to(torch.bfloat16)
+
+    # Use submodels for the pipeline
+    submodel_dict = {
+        "transformer": transformer,
+        "patchifier": patchifier,
+        "text_encoder": text_encoder,
+        "tokenizer": tokenizer,
+        "scheduler": scheduler,
+        "vae": vae,
+        "prompt_enhancer_image_caption_model": prompt_enhancer_image_caption_model,
+        "prompt_enhancer_image_caption_processor": prompt_enhancer_image_caption_processor,
+        "prompt_enhancer_llm_model": prompt_enhancer_llm_model,
+        "prompt_enhancer_llm_tokenizer": prompt_enhancer_llm_tokenizer,
+        "allowed_inference_steps": allowed_inference_steps,
+    }
+
+    pipeline = LTXVideoPipeline(**submodel_dict)
+    pipeline = pipeline.to(device)
+    return pipeline
+
+
+def create_latent_upsampler(latent_upsampler_model_path: str, device: str):
+    latent_upsampler = LatentUpsampler.from_pretrained(latent_upsampler_model_path)
+    latent_upsampler.to(device)
+    latent_upsampler.eval()
+    return latent_upsampler
+
+
| 397 | 
            +
            def infer(
         | 
| 398 | 
            +
                output_path: Optional[str],
         | 
| 399 | 
            +
                seed: int,
         | 
| 400 | 
            +
                pipeline_config: str,
         | 
| 401 | 
            +
                image_cond_noise_scale: float,
         | 
| 402 | 
            +
                height: Optional[int],
         | 
| 403 | 
            +
                width: Optional[int],
         | 
| 404 | 
            +
                num_frames: int,
         | 
| 405 | 
            +
                frame_rate: int,
         | 
| 406 | 
            +
                prompt: str,
         | 
| 407 | 
            +
                negative_prompt: str,
         | 
| 408 | 
            +
                offload_to_cpu: bool,
         | 
| 409 | 
            +
                input_media_path: Optional[str] = None,
         | 
| 410 | 
            +
                strength: Optional[float] = 1.0,
         | 
| 411 | 
            +
                conditioning_media_paths: Optional[List[str]] = None,
         | 
| 412 | 
            +
                conditioning_strengths: Optional[List[float]] = None,
         | 
| 413 | 
            +
                conditioning_start_frames: Optional[List[int]] = None,
         | 
| 414 | 
            +
                device: Optional[str] = None,
         | 
| 415 | 
            +
                **kwargs,
         | 
| 416 | 
            +
            ):
         | 
| 417 | 
            +
                # check if pipeline_config is a file
         | 
| 418 | 
            +
                if not os.path.isfile(pipeline_config):
         | 
| 419 | 
            +
                    raise ValueError(f"Pipeline config file {pipeline_config} does not exist")
         | 
| 420 | 
            +
                with open(pipeline_config, "r") as f:
         | 
| 421 | 
            +
                    pipeline_config = yaml.safe_load(f)
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                models_dir = "MODEL_DIR"
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                #ltxv_model_name_or_path = pipeline_config["checkpoint_path"]
         | 
| 426 | 
            +
                ltxv_model_name_or_path = "ltxv-13b-0.9.7-distilled-rc3.safetensors"
         | 
| 427 | 
            +
                if not os.path.isfile(ltxv_model_name_or_path):
         | 
| 428 | 
            +
                    ltxv_model_path = hf_hub_download(
         | 
| 429 | 
            +
                        repo_id="LTX-Colab/LTX-Video-Preview",
         | 
| 430 | 
            +
                        #repo_id="Lightricks/LTX-Video",
         | 
| 431 | 
            +
                        filename=ltxv_model_name_or_path,
         | 
| 432 | 
            +
                        local_dir=models_dir,
         | 
| 433 | 
            +
                        repo_type="model",
         | 
| 434 | 
            +
                    )
         | 
| 435 | 
            +
                else:
         | 
| 436 | 
            +
                    ltxv_model_path = ltxv_model_name_or_path
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                spatial_upscaler_model_name_or_path = pipeline_config.get(
         | 
| 439 | 
            +
                    "spatial_upscaler_model_path"
         | 
| 440 | 
    )
    if spatial_upscaler_model_name_or_path and not os.path.isfile(
        spatial_upscaler_model_name_or_path
    ):
        spatial_upscaler_model_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=spatial_upscaler_model_name_or_path,
            local_dir=models_dir,
            repo_type="model",
        )
    else:
        spatial_upscaler_model_path = spatial_upscaler_model_name_or_path

    if kwargs.get("input_image_path", None):
        logger.warning(
            "Please use conditioning_media_paths instead of input_image_path."
        )
        assert not conditioning_media_paths and not conditioning_start_frames
        conditioning_media_paths = [kwargs["input_image_path"]]
        conditioning_start_frames = [0]

    # Validate conditioning arguments
    if conditioning_media_paths:
        # Use default strengths of 1.0
        if not conditioning_strengths:
            conditioning_strengths = [1.0] * len(conditioning_media_paths)
        if not conditioning_start_frames:
            raise ValueError(
                "If `conditioning_media_paths` is provided, "
                "`conditioning_start_frames` must also be provided"
            )
        if len(conditioning_media_paths) != len(conditioning_strengths) or len(
            conditioning_media_paths
        ) != len(conditioning_start_frames):
            raise ValueError(
                "`conditioning_media_paths`, `conditioning_strengths`, "
                "and `conditioning_start_frames` must have the same length"
            )
        if any(s < 0 or s > 1 for s in conditioning_strengths):
            raise ValueError("All conditioning strengths must be between 0 and 1")
        if any(f < 0 or f >= num_frames for f in conditioning_start_frames):
            raise ValueError(
                f"All conditioning start frames must be between 0 and {num_frames-1}"
            )

    seed_everething(seed)
    if offload_to_cpu and not torch.cuda.is_available():
        logger.warning(
            "offload_to_cpu is set to True, but offloading will not occur since the model is already running on CPU."
        )
        offload_to_cpu = False
    else:
        offload_to_cpu = offload_to_cpu and get_total_gpu_memory() < 30

    output_dir = (
        Path(output_path)
        if output_path
        else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1)
    height_padded = ((height - 1) // 32 + 1) * 32
    width_padded = ((width - 1) // 32 + 1) * 32
    num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1

    padding = calculate_padding(height, width, height_padded, width_padded)

    logger.warning(
        f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}"
    )

    prompt_enhancement_words_threshold = pipeline_config[
        "prompt_enhancement_words_threshold"
    ]

    prompt_word_count = len(prompt.split())
    enhance_prompt = (
        prompt_enhancement_words_threshold > 0
        and prompt_word_count < prompt_enhancement_words_threshold
    )

    if prompt_enhancement_words_threshold > 0 and not enhance_prompt:
        logger.info(
            f"Prompt has {prompt_word_count} words, which exceeds the threshold of {prompt_enhancement_words_threshold}. Prompt enhancement disabled."
        )

    precision = pipeline_config["precision"]
    text_encoder_model_name_or_path = pipeline_config["text_encoder_model_name_or_path"]
    sampler = pipeline_config["sampler"]
    prompt_enhancer_image_caption_model_name_or_path = pipeline_config[
        "prompt_enhancer_image_caption_model_name_or_path"
    ]
    prompt_enhancer_llm_model_name_or_path = pipeline_config[
        "prompt_enhancer_llm_model_name_or_path"
    ]

    pipeline = create_ltx_video_pipeline(
        ckpt_path=ltxv_model_path,
        precision=precision,
        text_encoder_model_name_or_path=text_encoder_model_name_or_path,
        sampler=sampler,
        device=kwargs.get("device", get_device()),
        enhance_prompt=enhance_prompt,
        prompt_enhancer_image_caption_model_name_or_path=prompt_enhancer_image_caption_model_name_or_path,
        prompt_enhancer_llm_model_name_or_path=prompt_enhancer_llm_model_name_or_path,
    )

    if pipeline_config.get("pipeline_type", None) == "multi-scale":
        if not spatial_upscaler_model_path:
            raise ValueError(
                "spatial upscaler model path is missing from pipeline config file and is required for multi-scale rendering"
            )
        latent_upsampler = create_latent_upsampler(
            spatial_upscaler_model_path, pipeline.device
        )
        pipeline = LTXMultiScalePipeline(pipeline, latent_upsampler=latent_upsampler)

    media_item = None
    if input_media_path:
        media_item = load_media_file(
            media_path=input_media_path,
            height=height,
            width=width,
            max_frames=num_frames_padded,
            padding=padding,
        )

    conditioning_items = (
        prepare_conditioning(
            conditioning_media_paths=conditioning_media_paths,
            conditioning_strengths=conditioning_strengths,
            conditioning_start_frames=conditioning_start_frames,
            height=height,
            width=width,
            num_frames=num_frames,
            padding=padding,
            pipeline=pipeline,
        )
        if conditioning_media_paths
        else None
    )

    stg_mode = pipeline_config.get("stg_mode", "attention_values")
    del pipeline_config["stg_mode"]
    if stg_mode.lower() == "stg_av" or stg_mode.lower() == "attention_values":
        skip_layer_strategy = SkipLayerStrategy.AttentionValues
    elif stg_mode.lower() == "stg_as" or stg_mode.lower() == "attention_skip":
        skip_layer_strategy = SkipLayerStrategy.AttentionSkip
    elif stg_mode.lower() == "stg_r" or stg_mode.lower() == "residual":
        skip_layer_strategy = SkipLayerStrategy.Residual
    elif stg_mode.lower() == "stg_t" or stg_mode.lower() == "transformer_block":
        skip_layer_strategy = SkipLayerStrategy.TransformerBlock
    else:
        raise ValueError(f"Invalid spatiotemporal guidance mode: {stg_mode}")

    # Prepare input for the pipeline
    sample = {
        "prompt": prompt,
        "prompt_attention_mask": None,
        "negative_prompt": negative_prompt,
        "negative_prompt_attention_mask": None,
    }

    device = device or get_device()
    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipeline(
        **pipeline_config,
        skip_layer_strategy=skip_layer_strategy,
        generator=generator,
        output_type="pt",
        callback_on_step_end=None,
        height=height_padded,
        width=width_padded,
        num_frames=num_frames_padded,
        frame_rate=frame_rate,
        **sample,
        media_items=media_item,
        strength=strength,
        conditioning_items=conditioning_items,
        is_video=True,
        vae_per_channel_normalize=True,
        image_cond_noise_scale=image_cond_noise_scale,
        mixed_precision=(precision == "mixed_precision"),
        offload_to_cpu=offload_to_cpu,
        device=device,
        enhance_prompt=enhance_prompt,
    ).images

    # Crop the padded images to the desired resolution and number of frames
    (pad_left, pad_right, pad_top, pad_bottom) = padding
    pad_bottom = -pad_bottom
    pad_right = -pad_right
    if pad_bottom == 0:
        pad_bottom = images.shape[3]
    if pad_right == 0:
        pad_right = images.shape[4]
    images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right]

    for i in range(images.shape[0]):
        # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C
        video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy()
        # Unnormalizing images to [0, 255] range
        video_np = (video_np * 255).astype(np.uint8)
        fps = frame_rate
        height, width = video_np.shape[1:3]
        # In case a single image is generated
        if video_np.shape[0] == 1:
            output_filename = get_unique_filename(
                f"image_output_{i}",
                ".png",
                prompt=prompt,
                seed=seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )
            imageio.imwrite(output_filename, video_np[0])
        else:
            output_filename = get_unique_filename(
                f"video_output_{i}",
                ".mp4",
                prompt=prompt,
                seed=seed,
                resolution=(height, width, num_frames),
                dir=output_dir,
            )

            # Write video
            with imageio.get_writer(output_filename, fps=fps) as video:
                for frame in video_np:
                    video.append_data(frame)

        logger.warning(f"Output saved to {output_filename}")


def prepare_conditioning(
    conditioning_media_paths: List[str],
    conditioning_strengths: List[float],
    conditioning_start_frames: List[int],
    height: int,
    width: int,
    num_frames: int,
    padding: tuple[int, int, int, int],
    pipeline: LTXVideoPipeline,
) -> Optional[List[ConditioningItem]]:
    """Prepare conditioning items based on input media paths and their parameters.

    Args:
        conditioning_media_paths: List of paths to conditioning media (images or videos)
        conditioning_strengths: List of conditioning strengths for each media item
        conditioning_start_frames: List of frame indices where each item should be applied
        height: Height of the output frames
        width: Width of the output frames
        num_frames: Number of frames in the output video
        padding: Padding to apply to the frames
        pipeline: LTXVideoPipeline object used for condition video trimming

    Returns:
        A list of ConditioningItem objects.
    """
    conditioning_items = []
    for path, strength, start_frame in zip(
        conditioning_media_paths, conditioning_strengths, conditioning_start_frames
    ):
        num_input_frames = orig_num_input_frames = get_media_num_frames(path)
        if hasattr(pipeline, "trim_conditioning_sequence") and callable(
            getattr(pipeline, "trim_conditioning_sequence")
        ):
            num_input_frames = pipeline.trim_conditioning_sequence(
                start_frame, orig_num_input_frames, num_frames
            )
        if num_input_frames < orig_num_input_frames:
            logger.warning(
                f"Trimming conditioning video {path} from {orig_num_input_frames} to {num_input_frames} frames."
            )

        media_tensor = load_media_file(
            media_path=path,
            height=height,
            width=width,
            max_frames=num_input_frames,
            padding=padding,
            just_crop=True,
        )
        conditioning_items.append(ConditioningItem(media_tensor, start_frame, strength))
    return conditioning_items


def get_media_num_frames(media_path: str) -> int:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    num_frames = 1
    if is_video:
        reader = imageio.get_reader(media_path)
        num_frames = reader.count_frames()
        reader.close()
    return num_frames


def load_media_file(
    media_path: str,
    height: int,
    width: int,
    max_frames: int,
    padding: tuple[int, int, int, int],
    just_crop: bool = False,
) -> torch.Tensor:
    is_video = any(
        media_path.lower().endswith(ext) for ext in [".mp4", ".avi", ".mov", ".mkv"]
    )
    if is_video:
        reader = imageio.get_reader(media_path)
        num_input_frames = min(reader.count_frames(), max_frames)

        # Read and preprocess the relevant frames from the video file.
        frames = []
        for i in range(num_input_frames):
            frame = Image.fromarray(reader.get_data(i))
            frame_tensor = load_image_to_tensor_with_resize_and_crop(
                frame, height, width, just_crop=just_crop
            )
            frame_tensor = torch.nn.functional.pad(frame_tensor, padding)
            frames.append(frame_tensor)
        reader.close()

        # Stack frames along the temporal dimension
        media_tensor = torch.cat(frames, dim=2)
    else:  # Input image
        media_tensor = load_image_to_tensor_with_resize_and_crop(
            media_path, height, width, just_crop=just_crop
        )
        media_tensor = torch.nn.functional.pad(media_tensor, padding)
    return media_tensor


if __name__ == "__main__":
    main()
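Editor's note: the dimension handling in inference.py above rounds height and width up to multiples of 32 and forces the frame count into the form 8 * K + 1 before running the pipeline. The following standalone sketch only illustrates that arithmetic with made-up values; it is not part of the uploaded files.

# Illustration of the padding rules used in inference.py (example values are made up).
height, width, num_frames = 480, 704, 121

height_padded = ((height - 1) // 32 + 1) * 32            # 480 (already a multiple of 32)
width_padded = ((width - 1) // 32 + 1) * 32              # 704
num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1  # 121 (already of the form 8*K + 1)
print(height_padded, width_padded, num_frames_padded)

# For a non-aligned height such as 500, the target becomes 512; the extra 12 rows are
# handled by calculate_padding() before generation and cropped away again afterwards.
print(((500 - 1) // 32 + 1) * 32)  # 512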
    	
ltx_video/__init__.py ADDED (file without changes)
ltx_video/models/__init__.py ADDED (file without changes)
ltx_video/models/autoencoders/__init__.py ADDED (file without changes)

ltx_video/models/autoencoders/causal_conv3d.py ADDED
@@ -0,0 +1,63 @@
from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        return self.conv.weight
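Editor's note: a minimal shape check for the module above, assuming CausalConv3d as defined in the uploaded file is in scope; the tensor sizes are made up. With stride 1 and causal=True, the temporal length is preserved because the first frame is repeated (time_kernel_size - 1) times before the temporally un-padded convolution, so no future frame leaks into the output.

import torch

conv = CausalConv3d(in_channels=3, out_channels=8, kernel_size=3)
x = torch.randn(1, 3, 9, 32, 32)   # (B, C, F, H, W)
y = conv(x, causal=True)
print(y.shape)                      # expected: torch.Size([1, 8, 9, 32, 32])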
    	
ltx_video/models/autoencoders/causal_video_autoencoder.py ADDED
@@ -0,0 +1,1403 @@
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            from functools import partial
         | 
| 4 | 
            +
            from types import SimpleNamespace
         | 
| 5 | 
            +
            from typing import Any, Mapping, Optional, Tuple, Union, List
         | 
| 6 | 
            +
            from pathlib import Path
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            from einops import rearrange
         | 
| 11 | 
            +
            from torch import nn
         | 
| 12 | 
            +
            from diffusers.utils import logging
         | 
| 13 | 
            +
            import torch.nn.functional as F
         | 
| 14 | 
            +
            from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
         | 
| 15 | 
            +
            from safetensors import safe_open
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
         | 
| 19 | 
            +
            from ltx_video.models.autoencoders.pixel_norm import PixelNorm
         | 
| 20 | 
            +
            from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND
         | 
| 21 | 
            +
            from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
         | 
| 22 | 
            +
            from ltx_video.models.transformers.attention import Attention
         | 
| 23 | 
            +
            from ltx_video.utils.diffusers_config_mapping import (
         | 
| 24 | 
            +
                diffusers_and_ours_config_mapping,
         | 
| 25 | 
            +
                make_hashable_key,
         | 
| 26 | 
            +
                VAE_KEYS_RENAME_DICT,
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
         | 
| 30 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class CausalVideoAutoencoder(AutoencoderKLWrapper):
         | 
| 34 | 
            +
                @classmethod
         | 
| 35 | 
            +
                def from_pretrained(
         | 
| 36 | 
            +
                    cls,
         | 
| 37 | 
            +
                    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
         | 
| 38 | 
            +
                    *args,
         | 
| 39 | 
            +
                    **kwargs,
         | 
| 40 | 
            +
                ):
         | 
| 41 | 
            +
                    pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
         | 
| 42 | 
            +
                    if (
         | 
| 43 | 
            +
                        pretrained_model_name_or_path.is_dir()
         | 
| 44 | 
            +
                        and (pretrained_model_name_or_path / "autoencoder.pth").exists()
         | 
| 45 | 
            +
                    ):
         | 
| 46 | 
            +
                        config_local_path = pretrained_model_name_or_path / "config.json"
         | 
| 47 | 
            +
                        config = cls.load_config(config_local_path, **kwargs)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                        model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
         | 
| 50 | 
            +
                        state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                        statistics_local_path = (
         | 
| 53 | 
            +
                            pretrained_model_name_or_path / "per_channel_statistics.json"
         | 
| 54 | 
            +
                        )
         | 
| 55 | 
            +
                        if statistics_local_path.exists():
         | 
| 56 | 
            +
                            with open(statistics_local_path, "r") as file:
         | 
| 57 | 
            +
                                data = json.load(file)
         | 
| 58 | 
            +
                            transposed_data = list(zip(*data["data"]))
         | 
| 59 | 
            +
                            data_dict = {
         | 
| 60 | 
            +
                                col: torch.tensor(vals)
         | 
| 61 | 
            +
                                for col, vals in zip(data["columns"], transposed_data)
         | 
| 62 | 
            +
                            }
         | 
| 63 | 
            +
                            std_of_means = data_dict["std-of-means"]
         | 
| 64 | 
            +
                            mean_of_means = data_dict.get(
         | 
| 65 | 
            +
                                "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
         | 
| 66 | 
            +
                            )
         | 
| 67 | 
            +
                            state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
         | 
| 68 | 
            +
                                std_of_means
         | 
| 69 | 
            +
                            )
         | 
| 70 | 
            +
                            state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
         | 
| 71 | 
            +
                                mean_of_means
         | 
| 72 | 
            +
                            )
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    elif pretrained_model_name_or_path.is_dir():
         | 
| 75 | 
            +
                        config_path = pretrained_model_name_or_path / "vae" / "config.json"
         | 
| 76 | 
            +
                        with open(config_path, "r") as f:
         | 
| 77 | 
            +
                            config = make_hashable_key(json.load(f))
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                        assert config in diffusers_and_ours_config_mapping, (
         | 
| 80 | 
            +
                            "Provided diffusers checkpoint config for VAE is not suppported. "
         | 
| 81 | 
            +
                            "We only support diffusers configs found in Lightricks/LTX-Video."
         | 
| 82 | 
            +
                        )
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                        config = diffusers_and_ours_config_mapping[config]
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                        state_dict_path = (
         | 
| 87 | 
            +
                            pretrained_model_name_or_path
         | 
| 88 | 
            +
                            / "vae"
         | 
| 89 | 
            +
                            / "diffusion_pytorch_model.safetensors"
         | 
| 90 | 
            +
                        )
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                        state_dict = {}
         | 
| 93 | 
            +
                        with safe_open(state_dict_path, framework="pt", device="cpu") as f:
         | 
| 94 | 
            +
                            for k in f.keys():
         | 
| 95 | 
            +
                                state_dict[k] = f.get_tensor(k)
         | 
| 96 | 
            +
                        for key in list(state_dict.keys()):
         | 
| 97 | 
            +
                            new_key = key
         | 
| 98 | 
            +
                            for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
         | 
| 99 | 
            +
                                new_key = new_key.replace(replace_key, rename_key)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                            state_dict[new_key] = state_dict.pop(key)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    elif pretrained_model_name_or_path.is_file() and str(
         | 
| 104 | 
            +
                        pretrained_model_name_or_path
         | 
| 105 | 
            +
                    ).endswith(".safetensors"):
         | 
| 106 | 
            +
                        state_dict = {}
         | 
| 107 | 
            +
                        with safe_open(
         | 
| 108 | 
            +
                            pretrained_model_name_or_path, framework="pt", device="cpu"
         | 
| 109 | 
            +
                        ) as f:
         | 
| 110 | 
            +
                            metadata = f.metadata()
         | 
| 111 | 
            +
                            for k in f.keys():
         | 
| 112 | 
            +
                                state_dict[k] = f.get_tensor(k)
         | 
| 113 | 
            +
                        configs = json.loads(metadata["config"])
         | 
| 114 | 
            +
                        config = configs["vae"]
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    video_vae = cls.from_config(config)
         | 
| 117 | 
            +
                    if "torch_dtype" in kwargs:
         | 
| 118 | 
            +
                        video_vae.to(kwargs["torch_dtype"])
         | 
| 119 | 
            +
                    video_vae.load_state_dict(state_dict)
         | 
| 120 | 
            +
                    return video_vae
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                @staticmethod
         | 
| 123 | 
            +
                def from_config(config):
         | 
| 124 | 
            +
                    assert (
         | 
| 125 | 
            +
                        config["_class_name"] == "CausalVideoAutoencoder"
         | 
| 126 | 
            +
                    ), "config must have _class_name=CausalVideoAutoencoder"
         | 
| 127 | 
            +
                    if isinstance(config["dims"], list):
         | 
| 128 | 
            +
                        config["dims"] = tuple(config["dims"])
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    double_z = config.get("double_z", True)
         | 
| 133 | 
            +
                    latent_log_var = config.get(
         | 
| 134 | 
            +
                        "latent_log_var", "per_channel" if double_z else "none"
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
                    use_quant_conv = config.get("use_quant_conv", True)
         | 
| 137 | 
            +
                    normalize_latent_channels = config.get("normalize_latent_channels", False)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    if use_quant_conv and latent_log_var in ["uniform", "constant"]:
         | 
| 140 | 
            +
                        raise ValueError(
         | 
| 141 | 
            +
                            f"latent_log_var={latent_log_var} requires use_quant_conv=False"
         | 
| 142 | 
            +
                        )
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    encoder = Encoder(
         | 
| 145 | 
            +
                        dims=config["dims"],
         | 
| 146 | 
            +
                        in_channels=config.get("in_channels", 3),
         | 
| 147 | 
            +
                        out_channels=config["latent_channels"],
         | 
| 148 | 
            +
                        blocks=config.get("encoder_blocks", config.get("blocks")),
         | 
| 149 | 
            +
                        patch_size=config.get("patch_size", 1),
         | 
| 150 | 
            +
                        latent_log_var=latent_log_var,
         | 
| 151 | 
            +
                        norm_layer=config.get("norm_layer", "group_norm"),
         | 
| 152 | 
            +
                        base_channels=config.get("encoder_base_channels", 128),
         | 
| 153 | 
            +
                        spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
         | 
| 154 | 
            +
                    )
         | 
| 155 | 
            +
        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_res",
                    "compress_space_res",
                ]
            ]
        )

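    # Illustrative sketch (hypothetical values, not taken from any released config):
    # for an encoder built with patch_size=4 and
    # blocks = [("res_x", 4), ("compress_all", 1), ("compress_all", 1), ("compress_all", 1)],
    # three entries match the spatial list above, so spatial_downscale_factor = 2**3 * 4 = 32,
    # and the same three match the temporal list, so temporal_downscale_factor = 2**3 = 8.
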
| 242 | 
            +
             | 
| 243 | 
            +
                def to_json_string(self) -> str:
         | 
| 244 | 
            +
                    import json
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    return json.dumps(self.config.__dict__)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         | 
| 249 | 
            +
                    if any([key.startswith("vae.") for key in state_dict.keys()]):
         | 
| 250 | 
            +
                        state_dict = {
         | 
| 251 | 
            +
                            key.replace("vae.", ""): value
         | 
| 252 | 
            +
                            for key, value in state_dict.items()
         | 
| 253 | 
            +
                            if key.startswith("vae.")
         | 
| 254 | 
            +
                        }
         | 
| 255 | 
            +
                    ckpt_state_dict = {
         | 
| 256 | 
            +
                        key: value
         | 
| 257 | 
            +
                        for key, value in state_dict.items()
         | 
| 258 | 
            +
                        if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
         | 
| 259 | 
            +
                    }
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    model_keys = set(name for name, _ in self.named_modules())
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                    key_mapping = {
         | 
| 264 | 
            +
                        ".resnets.": ".res_blocks.",
         | 
| 265 | 
            +
                        "downsamplers.0": "downsample",
         | 
| 266 | 
            +
                        "upsamplers.0": "upsample",
         | 
| 267 | 
            +
                    }
         | 
| 268 | 
            +
                    converted_state_dict = {}
         | 
| 269 | 
            +
                    for key, value in ckpt_state_dict.items():
         | 
| 270 | 
            +
                        for k, v in key_mapping.items():
         | 
| 271 | 
            +
                            key = key.replace(k, v)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                        key_prefix = ".".join(key.split(".")[:-1])
         | 
| 274 | 
            +
                        if "norm" in key and key_prefix not in model_keys:
         | 
| 275 | 
            +
                            logger.info(
         | 
| 276 | 
            +
                                f"Removing key {key} from state_dict as it is not present in the model"
         | 
| 277 | 
            +
                            )
         | 
| 278 | 
            +
                            continue
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                        converted_state_dict[key] = value
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    super().load_state_dict(converted_state_dict, strict=strict)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                    data_dict = {
         | 
| 285 | 
            +
                        key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
         | 
| 286 | 
            +
                        for key, value in state_dict.items()
         | 
| 287 | 
            +
                        if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
         | 
| 288 | 
            +
                    }
         | 
| 289 | 
            +
                    if len(data_dict) > 0:
         | 
| 290 | 
            +
                        self.register_buffer("std_of_means", data_dict["std-of-means"])
         | 
| 291 | 
            +
                        self.register_buffer(
         | 
| 292 | 
            +
                            "mean_of_means",
         | 
| 293 | 
            +
                            data_dict.get(
         | 
| 294 | 
            +
                                "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
         | 
| 295 | 
            +
                            ),
         | 
| 296 | 
            +
                        )
         | 
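
    # Illustrative sketch of the key remapping performed in load_state_dict above, using a
    # hypothetical checkpoint key (not taken from any real checkpoint):
    # "vae.decoder.up_blocks.0.resnets.1.conv1.weight" first has its "vae." prefix stripped,
    # then ".resnets." is rewritten to ".res_blocks.", yielding
    # "decoder.up_blocks.0.res_blocks.1.conv1.weight" before the weights are loaded.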

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
    """
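
    # Illustrative sketch of a `blocks` argument this encoder accepts, restricted to block
    # names handled in __init__ below; entries may carry an int (number of layers) or a dict
    # of per-block parameters. This is a hypothetical layout, not the configuration of any
    # released LTX-Video checkpoint:
    #
    #     example_encoder_blocks = [
    #         ("res_x", 4),
    #         ("compress_all_res", {"multiplier": 2}),
    #         ("res_x", 6),
    #         ("compress_all_res", {"multiplier": 2}),
    #         ("res_x", 6),
    #     ]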

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown block: {block_name}")

            self.down_blocks.append(block)

        # out
        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()

        conv_out_channels = out_channels
        if latent_log_var == "per_channel":
            conv_out_channels *= 2
        elif latent_log_var == "uniform":
            conv_out_channels += 1
        elif latent_log_var == "constant":
            conv_out_channels += 1
        elif latent_log_var != "none":
            raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            conv_out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""

        sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        sample = self.conv_in(sample)

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        for down_block in self.down_blocks:
            sample = checkpoint_fn(down_block)(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if self.latent_log_var == "uniform":
            last_channel = sample[:, -1:, ...]
            num_dims = sample.dim()

            if num_dims == 4:
                # For shape (B, C, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            elif num_dims == 5:
                # For shape (B, C, F, H, W)
                repeated_last_channel = last_channel.repeat(
                    1, sample.shape[1] - 2, 1, 1, 1
                )
                sample = torch.cat([sample, repeated_last_channel], dim=1)
            else:
                raise ValueError(f"Invalid input shape: {sample.shape}")
        elif self.latent_log_var == "constant":
            sample = sample[:, :-1, ...]
            approx_ln_0 = (
                -30
            )  # this is the minimal clamp value in DiagonalGaussianDistribution objects
            sample = torch.cat(
                [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
                dim=1,
            )

        return sample


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        causal (`bool`, *optional*, defaults to `True`):
            Whether to use causal convolutions or not.
    """
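
    # Illustrative sketch of a decoder `blocks` argument, restricted to names handled in
    # __init__ below ("res_x", "attn_res_x", "res_x_y", "compress_time", "compress_space",
    # "compress_all"); note that __init__ iterates the list in reverse when building
    # up_blocks. This is a hypothetical layout, not the configuration of any released
    # checkpoint:
    #
    #     example_decoder_blocks = [
    #         ("res_x", {"num_layers": 4, "inject_noise": False}),
    #         ("compress_all", {"residual": True, "multiplier": 2}),
    #         ("res_x", {"num_layers": 6, "inject_noise": False}),
    #     ]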

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        causal: bool = True,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.layers_per_block = layers_per_block
        out_channels = out_channels * patch_size**2
        self.causal = causal
        self.blocks_desc = blocks

        # Compute output channel to be product of all channel-multiplier blocks
        output_channel = base_channels
        for block_name, block_params in list(reversed(blocks)):
            block_params = block_params if isinstance(block_params, dict) else {}
            if block_name == "res_x_y":
                output_channel = output_channel * block_params.get("multiplier", 2)
            if block_name == "compress_all":
                output_channel = output_channel * block_params.get("multiplier", 1)

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.up_blocks = nn.ModuleList([])

        for block_name, block_params in list(reversed(blocks)):
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "attn_res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=timestep_conditioning,
                    attention_head_dim=block_params["attention_head_dim"],
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = output_channel // block_params.get("multiplier", 2)
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    inject_noise=block_params.get("inject_noise", False),
                    timestep_conditioning=False,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 1, 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                output_channel = output_channel // block_params.get("multiplier", 1)
                block = DepthToSpaceUpsample(
                    dims=dims,
                    in_channels=input_channel,
                    stride=(2, 2, 2),
                    residual=block_params.get("residual", False),
                    out_channels_reduction_factor=block_params.get("multiplier", 1),
                    spatial_padding_mode=spatial_padding_mode,
                )
            else:
                raise ValueError(f"unknown layer: {block_name}")

            self.up_blocks.append(block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()
        elif norm_layer == "layer_norm":
            self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims,
            output_channel,
            out_channels,
            3,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.gradient_checkpointing = False

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.timestep_scale_multiplier = nn.Parameter(
                torch.tensor(1000.0, dtype=torch.float32)
            )
            self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
                output_channel * 2, 0
            )
            self.last_scale_shift_table = nn.Parameter(
                torch.randn(2, output_channel) / output_channel**0.5
            )

    def forward(
        self,
        sample: torch.FloatTensor,
        target_shape,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        batch_size = sample.shape[0]

        sample = self.conv_in(sample, causal=self.causal)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = sample.to(upscale_dtype)

        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            scaled_timestep = timestep * self.timestep_scale_multiplier

        for up_block in self.up_blocks:
            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
                sample = checkpoint_fn(up_block)(
                    sample, causal=self.causal, timestep=scaled_timestep
                )
            else:
                sample = checkpoint_fn(up_block)(sample, causal=self.causal)

        sample = self.conv_norm_out(sample)

        if self.timestep_conditioning:
            embedded_timestep = self.last_time_embedder(
                timestep=scaled_timestep.flatten(),
                resolution=None,
                aspect_ratio=None,
                batch_size=sample.shape[0],
                hidden_dtype=sample.dtype,
            )
            embedded_timestep = embedded_timestep.view(
                batch_size, embedded_timestep.shape[-1], 1, 1, 1
            )
            ada_values = self.last_scale_shift_table[
                None, ..., None, None, None
            ] + embedded_timestep.reshape(
                batch_size,
                2,
                -1,
                embedded_timestep.shape[-3],
                embedded_timestep.shape[-2],
                embedded_timestep.shape[-1],
            )
            shift, scale = ada_values.unbind(dim=1)
            sample = sample * (1 + scale) + shift

        sample = self.conv_act(sample)
        sample = self.conv_out(sample, causal=self.causal)

        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)

        return sample


class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        inject_noise (`bool`, *optional*, defaults to `False`):
            Whether to inject noise into the hidden states.
        timestep_conditioning (`bool`, *optional*, defaults to `False`):
            Whether to condition the hidden states on the timestep.
        attention_head_dim (`int`, *optional*, defaults to -1):
            The dimension of the attention head. If -1, no attention is used.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.

    """
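
    # Illustrative sketch: in __init__ below, attention blocks are created only when
    # attention_head_dim > 0, with heads = in_channels // attention_head_dim per block.
    # For hypothetical values in_channels = 512 and attention_head_dim = 64 (not taken
    # from any released config), that yields 8 attention heads of dimension 64 per layer.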
            +
             | 
| 829 | 
            +
                def __init__(
         | 
| 830 | 
            +
                    self,
         | 
| 831 | 
            +
                    dims: Union[int, Tuple[int, int]],
         | 
| 832 | 
            +
                    in_channels: int,
         | 
| 833 | 
            +
                    dropout: float = 0.0,
         | 
| 834 | 
            +
                    num_layers: int = 1,
         | 
| 835 | 
            +
                    resnet_eps: float = 1e-6,
         | 
| 836 | 
            +
                    resnet_groups: int = 32,
         | 
| 837 | 
            +
                    norm_layer: str = "group_norm",
         | 
| 838 | 
            +
                    inject_noise: bool = False,
         | 
| 839 | 
            +
                    timestep_conditioning: bool = False,
         | 
| 840 | 
            +
                    attention_head_dim: int = -1,
         | 
| 841 | 
            +
                    spatial_padding_mode: str = "zeros",
         | 
| 842 | 
            +
                ):
         | 
| 843 | 
            +
                    super().__init__()
         | 
| 844 | 
            +
                    resnet_groups = (
         | 
| 845 | 
            +
                        resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         | 
| 846 | 
            +
                    )
         | 
| 847 | 
            +
                    self.timestep_conditioning = timestep_conditioning
         | 
| 848 | 
            +
             | 
| 849 | 
            +
                    if timestep_conditioning:
         | 
| 850 | 
            +
                        self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
         | 
| 851 | 
            +
                            in_channels * 4, 0
         | 
| 852 | 
            +
                        )
         | 
| 853 | 
            +
             | 
| 854 | 
            +
                    self.res_blocks = nn.ModuleList(
         | 
| 855 | 
            +
                        [
         | 
| 856 | 
            +
                            ResnetBlock3D(
         | 
| 857 | 
            +
                                dims=dims,
         | 
| 858 | 
            +
                                in_channels=in_channels,
         | 
| 859 | 
            +
                                out_channels=in_channels,
         | 
| 860 | 
            +
                                eps=resnet_eps,
         | 
| 861 | 
            +
                                groups=resnet_groups,
         | 
| 862 | 
            +
                                dropout=dropout,
         | 
| 863 | 
            +
                                norm_layer=norm_layer,
         | 
| 864 | 
            +
                                inject_noise=inject_noise,
         | 
| 865 | 
            +
                                timestep_conditioning=timestep_conditioning,
         | 
| 866 | 
            +
                                spatial_padding_mode=spatial_padding_mode,
         | 
| 867 | 
            +
                            )
         | 
| 868 | 
            +
                            for _ in range(num_layers)
         | 
| 869 | 
            +
                        ]
         | 
| 870 | 
            +
                    )
         | 
| 871 | 
            +
             | 
| 872 | 
            +
                    self.attention_blocks = None
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                    if attention_head_dim > 0:
         | 
| 875 | 
            +
                        if attention_head_dim > in_channels:
         | 
| 876 | 
            +
                            raise ValueError(
         | 
| 877 | 
            +
                                "attention_head_dim must be less than or equal to in_channels"
         | 
| 878 | 
            +
                            )
         | 
| 879 | 
            +
             | 
| 880 | 
            +
                        self.attention_blocks = nn.ModuleList(
         | 
| 881 | 
            +
                            [
         | 
| 882 | 
            +
                                Attention(
         | 
| 883 | 
            +
                                    query_dim=in_channels,
         | 
| 884 | 
            +
                                    heads=in_channels // attention_head_dim,
         | 
| 885 | 
            +
                                    dim_head=attention_head_dim,
         | 
| 886 | 
            +
                                    bias=True,
         | 
| 887 | 
            +
                                    out_bias=True,
         | 
| 888 | 
            +
                                    qk_norm="rms_norm",
         | 
| 889 | 
            +
                                    residual_connection=True,
         | 
| 890 | 
            +
                                )
         | 
| 891 | 
            +
                                for _ in range(num_layers)
         | 
| 892 | 
            +
                            ]
         | 
| 893 | 
            +
                        )
         | 
| 894 | 
            +
             | 
| 895 | 
            +
                def forward(
         | 
| 896 | 
            +
                    self,
         | 
| 897 | 
            +
                    hidden_states: torch.FloatTensor,
         | 
| 898 | 
            +
                    causal: bool = True,
         | 
| 899 | 
            +
                    timestep: Optional[torch.Tensor] = None,
         | 
| 900 | 
            +
                ) -> torch.FloatTensor:
         | 
| 901 | 
            +
                    timestep_embed = None
         | 
| 902 | 
            +
                    if self.timestep_conditioning:
         | 
| 903 | 
            +
                        assert (
         | 
| 904 | 
            +
                            timestep is not None
         | 
| 905 | 
            +
                        ), "should pass timestep with timestep_conditioning=True"
         | 
| 906 | 
            +
                        batch_size = hidden_states.shape[0]
         | 
| 907 | 
            +
                        timestep_embed = self.time_embedder(
         | 
| 908 | 
            +
                            timestep=timestep.flatten(),
         | 
| 909 | 
            +
                            resolution=None,
         | 
| 910 | 
            +
                            aspect_ratio=None,
         | 
| 911 | 
            +
                            batch_size=batch_size,
         | 
| 912 | 
            +
                            hidden_dtype=hidden_states.dtype,
         | 
| 913 | 
            +
                        )
         | 
| 914 | 
            +
                        timestep_embed = timestep_embed.view(
         | 
| 915 | 
            +
                            batch_size, timestep_embed.shape[-1], 1, 1, 1
         | 
| 916 | 
            +
                        )
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                    if self.attention_blocks:
         | 
| 919 | 
            +
                        for resnet, attention in zip(self.res_blocks, self.attention_blocks):
         | 
| 920 | 
            +
                            hidden_states = resnet(
         | 
| 921 | 
            +
                                hidden_states, causal=causal, timestep=timestep_embed
         | 
| 922 | 
            +
                            )
         | 
| 923 | 
            +
             | 
| 924 | 
            +
                            # Reshape the hidden states to be (batch_size, frames * height * width, channel)
         | 
| 925 | 
            +
                            batch_size, channel, frames, height, width = hidden_states.shape
         | 
| 926 | 
            +
                            hidden_states = hidden_states.view(
         | 
| 927 | 
            +
                                batch_size, channel, frames * height * width
         | 
| 928 | 
            +
                            ).transpose(1, 2)
         | 
| 929 | 
            +
             | 
| 930 | 
            +
                            if attention.use_tpu_flash_attention:
         | 
| 931 | 
            +
                                # Pad the second dimension to be divisible by block_k_major (block in flash attention)
         | 
| 932 | 
            +
                                seq_len = hidden_states.shape[1]
         | 
| 933 | 
            +
                                block_k_major = 512
         | 
| 934 | 
            +
                                pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
         | 
| 935 | 
            +
                                if pad_len > 0:
         | 
| 936 | 
            +
                                    hidden_states = F.pad(
         | 
| 937 | 
            +
                                        hidden_states, (0, 0, 0, pad_len), "constant", 0
         | 
| 938 | 
            +
                                    )
         | 
| 939 | 
            +
             | 
| 940 | 
            +
                                # Create a mask with ones for the original sequence length and zeros for the padded indexes
         | 
| 941 | 
            +
                                mask = torch.ones(
         | 
| 942 | 
            +
                                    (hidden_states.shape[0], seq_len),
         | 
| 943 | 
            +
                                    device=hidden_states.device,
         | 
| 944 | 
            +
                                    dtype=hidden_states.dtype,
         | 
| 945 | 
            +
                                )
         | 
| 946 | 
            +
                                if pad_len > 0:
         | 
| 947 | 
            +
                                    mask = F.pad(mask, (0, pad_len), "constant", 0)
         | 
| 948 | 
            +
             | 
| 949 | 
            +
                            hidden_states = attention(
         | 
| 950 | 
            +
                                hidden_states,
         | 
| 951 | 
            +
                                attention_mask=(
         | 
| 952 | 
            +
                                    None if not attention.use_tpu_flash_attention else mask
         | 
| 953 | 
            +
                                ),
         | 
| 954 | 
            +
                            )
         | 
| 955 | 
            +
             | 
| 956 | 
            +
                            if attention.use_tpu_flash_attention:
         | 
| 957 | 
            +
                                # Remove the padding
         | 
| 958 | 
            +
                                if pad_len > 0:
         | 
| 959 | 
            +
                                    hidden_states = hidden_states[:, :-pad_len, :]
         | 
| 960 | 
            +
             | 
| 961 | 
            +
                            # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel)
         | 
| 962 | 
            +
                            hidden_states = hidden_states.transpose(-1, -2).reshape(
         | 
| 963 | 
            +
                                batch_size, channel, frames, height, width
         | 
| 964 | 
            +
                            )
         | 
| 965 | 
            +
                    else:
         | 
| 966 | 
            +
                        for resnet in self.res_blocks:
         | 
| 967 | 
            +
                            hidden_states = resnet(
         | 
| 968 | 
            +
                                hidden_states, causal=causal, timestep=timestep_embed
         | 
| 969 | 
            +
                            )
         | 
| 970 | 
            +
             | 
| 971 | 
            +
                    return hidden_states
         | 
| 972 | 
            +
             | 
| 973 | 
            +
             | 
| 974 | 
            +
class SpaceToDepthDownsample(nn.Module):
    def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
        super().__init__()
        self.stride = stride
        self.group_size = in_channels * np.prod(stride) // out_channels
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels // np.prod(stride),
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

    def forward(self, x, causal: bool = True):
        if self.stride[0] == 2:
            x = torch.cat(
                [x[:, :, :1, :, :], x], dim=2
            )  # duplicate first frames for padding

        # skip connection
        x_in = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )
        x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
        x_in = x_in.mean(dim=2)

        # conv
        x = self.conv(x, causal=causal)
        x = rearrange(
            x,
            "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
            p1=self.stride[0],
            p2=self.stride[1],
            p3=self.stride[2],
        )

        x = x + x_in

        return x


class DepthToSpaceUpsample(nn.Module):
    def __init__(
        self,
        dims,
        in_channels,
        stride,
        residual=False,
        out_channels_reduction_factor=1,
        spatial_padding_mode="zeros",
    ):
        super().__init__()
        self.stride = stride
        self.out_channels = (
            np.prod(stride) * in_channels // out_channels_reduction_factor
        )
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=self.out_channels,
            kernel_size=3,
            stride=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )
        self.pixel_shuffle = PixelShuffleND(dims=dims, upscale_factors=stride)
        self.residual = residual
        self.out_channels_reduction_factor = out_channels_reduction_factor

    def forward(self, x, causal: bool = True):
        if self.residual:
            # Reshape and duplicate the input to match the output shape
            x_in = self.pixel_shuffle(x)
            num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
            x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
            if self.stride[0] == 2:
                x_in = x_in[:, :, 1:, :, :]
        x = self.conv(x, causal=causal)
        x = self.pixel_shuffle(x)
        if self.stride[0] == 2:
            x = x[:, :, 1:, :, :]
        if self.residual:
            x = x + x_in
        return x


class LayerNorm(nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        x = rearrange(x, "b c d h w -> b d h w c")
        x = self.norm(x)
        x = rearrange(x, "b d h w c -> b c d h w")
        return x


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
        inject_noise: bool = False,
        timestep_conditioning: bool = False,
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.inject_noise = inject_noise

        if norm_layer == "group_norm":
            self.norm1 = nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        if norm_layer == "group_norm":
            self.norm2 = nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()
        elif norm_layer == "layer_norm":
            self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims,
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        if inject_noise:
            self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

        self.norm3 = (
            LayerNorm(in_channels, eps=eps, elementwise_affine=True)
            if in_channels != out_channels
            else nn.Identity()
        )

        self.timestep_conditioning = timestep_conditioning

        if timestep_conditioning:
            self.scale_shift_table = nn.Parameter(
                torch.randn(4, in_channels) / in_channels**0.5
            )

    def _feed_spatial_noise(
        self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
    ) -> torch.FloatTensor:
        spatial_shape = hidden_states.shape[-2:]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # similar to the "explicit noise inputs" method in style-gan
        spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
        scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
        hidden_states = hidden_states + scaled_noise

        return hidden_states

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        causal: bool = True,
        timestep: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor
        batch_size = hidden_states.shape[0]

        hidden_states = self.norm1(hidden_states)
        if self.timestep_conditioning:
            assert (
                timestep is not None
            ), "should pass timestep with timestep_conditioning=True"
            ada_values = self.scale_shift_table[
                None, ..., None, None, None
            ] + timestep.reshape(
                batch_size,
                4,
                -1,
                timestep.shape[-3],
                timestep.shape[-2],
                timestep.shape[-1],
            )
            shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)

            hidden_states = hidden_states * (1 + scale1) + shift1

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale1
            )

        hidden_states = self.norm2(hidden_states)

        if self.timestep_conditioning:
            hidden_states = hidden_states * (1 + scale2) + shift2

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states, causal=causal)

        if self.inject_noise:
            hidden_states = self._feed_spatial_noise(
                hidden_states, self.per_channel_scale2
            )

        input_tensor = self.norm3(input_tensor)

        batch_size = input_tensor.shape[0]

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


def patchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x


def create_video_autoencoder_demo_config(
    latent_channels: int = 64,
):
    encoder_blocks = [
        ("res_x", {"num_layers": 2}),
        ("compress_space_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 2}),
        ("compress_time_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
        ("compress_all_res", {"multiplier": 2}),
        ("res_x", {"num_layers": 1}),
    ]
    decoder_blocks = [
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
        ("compress_all", {"residual": True, "multiplier": 2}),
        ("res_x", {"num_layers": 2, "inject_noise": False}),
    ]
    return {
        "_class_name": "CausalVideoAutoencoder",
        "dims": 3,
        "encoder_blocks": encoder_blocks,
        "decoder_blocks": decoder_blocks,
        "latent_channels": latent_channels,
        "norm_layer": "pixel_norm",
        "patch_size": 4,
        "latent_log_var": "uniform",
        "use_quant_conv": False,
        "causal_decoder": False,
        "timestep_conditioning": True,
        "spatial_padding_mode": "replicate",
    }


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_demo_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = CausalVideoAutoencoder.from_config(config)

    print(video_autoencoder)
    video_autoencoder.eval()
    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 17, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")

    timestep = torch.ones(input_videos.shape[0]) * 0.1
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape, timestep=timestep
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Validate that a single image gets treated the same way as the first frame
    input_image = input_videos[:, :, :1, :, :]
    image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
    _ = video_autoencoder.decode(
        image_latent, target_shape=image_latent.shape, timestep=timestep
    ).sample

    first_frame_latent = latent[:, :, :1, :, :]

    assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
    # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
    # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
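The snippet below is an illustrative shape check, not part of the uploaded file. It assumes the ltx_video package from this upload (with its torch/einops dependencies) is importable; the channel counts and stride are arbitrary example values. SpaceToDepthDownsample folds each stride window into channels (halving time and space here), while DepthToSpaceUpsample expands channels with a causal convolution and pixel-shuffles them back to the original resolution.

import torch

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    DepthToSpaceUpsample,
    SpaceToDepthDownsample,
)

# Illustrative values only: 32 -> 64 channels, stride 2 in time and space.
down = SpaceToDepthDownsample(
    dims=3, in_channels=32, out_channels=64, stride=(2, 2, 2), spatial_padding_mode="zeros"
)
up = DepthToSpaceUpsample(
    dims=3, in_channels=64, stride=(2, 2, 2), residual=True, spatial_padding_mode="zeros"
)

x = torch.randn(1, 32, 9, 32, 32)  # (batch, channels, 1 + 8 frames, height, width)
y = down(x)  # expected: time and space halved, 64 channels
z = up(y)    # expected: back to the input's temporal/spatial size, 64 channels
print(x.shape, y.shape, z.shape)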
    	
        ltx_video/models/autoencoders/conv_nd_factory.py
    ADDED
    
from typing import Tuple, Union

import torch

from ltx_video.models.autoencoders.dual_conv3d import DualConv3d
from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d


def make_conv_nd(
    dims: Union[int, Tuple[int, int]],
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    stride=1,
    padding=0,
    dilation=1,
    groups=1,
    bias=True,
    causal=False,
    spatial_padding_mode="zeros",
    temporal_padding_mode="zeros",
):
    if not (spatial_padding_mode == temporal_padding_mode or causal):
        raise NotImplementedError("spatial and temporal padding modes must be equal")
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == 3:
        if causal:
            return CausalConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
                spatial_padding_mode=spatial_padding_mode,
            )
        return torch.nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    elif dims == (2, 1):
        return DualConv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
            padding_mode=spatial_padding_mode,
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")


def make_linear_nd(
    dims: int,
    in_channels: int,
    out_channels: int,
    bias=True,
):
    if dims == 2:
        return torch.nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    elif dims == 3 or dims == (2, 1):
        return torch.nn.Conv3d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
        )
    else:
        raise ValueError(f"unsupported dimensions: {dims}")
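As a quick illustration of the factory above (not part of the uploaded file), the dims argument selects the layer type: 2 gives a plain Conv2d, 3 gives a CausalConv3d or Conv3d depending on causal, (2, 1) gives the factored DualConv3d, and make_linear_nd builds a kernel-size-1 projection. A minimal sketch, assuming the ltx_video package is importable and using arbitrary channel counts:

from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd

conv2d = make_conv_nd(dims=2, in_channels=8, out_channels=16, kernel_size=3, padding=1)
causal3d = make_conv_nd(dims=3, in_channels=8, out_channels=16, kernel_size=3, causal=True)
dual3d = make_conv_nd(dims=(2, 1), in_channels=8, out_channels=16, kernel_size=3, padding=1)
proj = make_linear_nd(dims=3, in_channels=16, out_channels=8)  # kernel-size-1 Conv3d

print(type(conv2d).__name__)    # Conv2d
print(type(causal3d).__name__)  # CausalConv3d
print(type(dual3d).__name__)    # DualConv3d
print(type(proj).__name__)      # Conv3d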
    	
        ltx_video/models/autoencoders/dual_conv3d.py
    ADDED
    
import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DualConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
        padding_mode="zeros",
    ):
        super(DualConv3d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.padding_mode = padding_mode
        # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        if kernel_size == (1, 1, 1):
            raise ValueError(
                "kernel_size must be greater than 1. Use make_linear_nd instead."
            )
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        if isinstance(padding, int):
            padding = (padding, padding, padding)
        if isinstance(dilation, int):
            dilation = (dilation, dilation, dilation)

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Define the size of the channels after the first convolution
        intermediate_channels = (
            out_channels if in_channels < out_channels else in_channels
        )

        # Define parameters for the first convolution
        self.weight1 = nn.Parameter(
            torch.Tensor(
                intermediate_channels,
                in_channels // groups,
                1,
                kernel_size[1],
                kernel_size[2],
            )
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Define parameters for the second convolution
        self.weight2 = nn.Parameter(
            torch.Tensor(
                out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
            )
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter("bias2", None)

        # Initialize weights and biases
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
| 100 | 
            +
                    else:
         | 
| 101 | 
            +
                        return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def forward_with_3d(self, x, skip_time_conv):
         | 
| 104 | 
            +
                    # First convolution
         | 
| 105 | 
            +
                    x = F.conv3d(
         | 
| 106 | 
            +
                        x,
         | 
| 107 | 
            +
                        self.weight1,
         | 
| 108 | 
            +
                        self.bias1,
         | 
| 109 | 
            +
                        self.stride1,
         | 
| 110 | 
            +
                        self.padding1,
         | 
| 111 | 
            +
                        self.dilation1,
         | 
| 112 | 
            +
                        self.groups,
         | 
| 113 | 
            +
                        padding_mode=self.padding_mode,
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    if skip_time_conv:
         | 
| 117 | 
            +
                        return x
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    # Second convolution
         | 
| 120 | 
            +
                    x = F.conv3d(
         | 
| 121 | 
            +
                        x,
         | 
| 122 | 
            +
                        self.weight2,
         | 
| 123 | 
            +
                        self.bias2,
         | 
| 124 | 
            +
                        self.stride2,
         | 
| 125 | 
            +
                        self.padding2,
         | 
| 126 | 
            +
                        self.dilation2,
         | 
| 127 | 
            +
                        self.groups,
         | 
| 128 | 
            +
                        padding_mode=self.padding_mode,
         | 
| 129 | 
            +
                    )
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    return x
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def forward_with_2d(self, x, skip_time_conv):
         | 
| 134 | 
            +
                    b, c, d, h, w = x.shape
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    # First 2D convolution
         | 
| 137 | 
            +
                    x = rearrange(x, "b c d h w -> (b d) c h w")
         | 
| 138 | 
            +
                    # Squeeze the depth dimension out of weight1 since it's 1
         | 
| 139 | 
            +
                    weight1 = self.weight1.squeeze(2)
         | 
| 140 | 
            +
                    # Select stride, padding, and dilation for the 2D convolution
         | 
| 141 | 
            +
                    stride1 = (self.stride1[1], self.stride1[2])
         | 
| 142 | 
            +
                    padding1 = (self.padding1[1], self.padding1[2])
         | 
| 143 | 
            +
                    dilation1 = (self.dilation1[1], self.dilation1[2])
         | 
| 144 | 
            +
                    x = F.conv2d(
         | 
| 145 | 
            +
                        x,
         | 
| 146 | 
            +
                        weight1,
         | 
| 147 | 
            +
                        self.bias1,
         | 
| 148 | 
            +
                        stride1,
         | 
| 149 | 
            +
                        padding1,
         | 
| 150 | 
            +
                        dilation1,
         | 
| 151 | 
            +
                        self.groups,
         | 
| 152 | 
            +
                        padding_mode=self.padding_mode,
         | 
| 153 | 
            +
                    )
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    _, _, h, w = x.shape
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    if skip_time_conv:
         | 
| 158 | 
            +
                        x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
         | 
| 159 | 
            +
                        return x
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
         | 
| 162 | 
            +
                    x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    # Reshape weight2 to match the expected dimensions for conv1d
         | 
| 165 | 
            +
                    weight2 = self.weight2.squeeze(-1).squeeze(-1)
         | 
| 166 | 
            +
                    # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
         | 
| 167 | 
            +
                    stride2 = self.stride2[0]
         | 
| 168 | 
            +
                    padding2 = self.padding2[0]
         | 
| 169 | 
            +
                    dilation2 = self.dilation2[0]
         | 
| 170 | 
            +
                    x = F.conv1d(
         | 
| 171 | 
            +
                        x,
         | 
| 172 | 
            +
                        weight2,
         | 
| 173 | 
            +
                        self.bias2,
         | 
| 174 | 
            +
                        stride2,
         | 
| 175 | 
            +
                        padding2,
         | 
| 176 | 
            +
                        dilation2,
         | 
| 177 | 
            +
                        self.groups,
         | 
| 178 | 
            +
                        padding_mode=self.padding_mode,
         | 
| 179 | 
            +
                    )
         | 
| 180 | 
            +
                    x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    return x
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                @property
         | 
| 185 | 
            +
                def weight(self):
         | 
| 186 | 
            +
                    return self.weight2
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            def test_dual_conv3d_consistency():
         | 
| 190 | 
            +
                # Initialize parameters
         | 
| 191 | 
            +
                in_channels = 3
         | 
| 192 | 
            +
                out_channels = 5
         | 
| 193 | 
            +
                kernel_size = (3, 3, 3)
         | 
| 194 | 
            +
                stride = (2, 2, 2)
         | 
| 195 | 
            +
                padding = (1, 1, 1)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                # Create an instance of the DualConv3d class
         | 
| 198 | 
            +
                dual_conv3d = DualConv3d(
         | 
| 199 | 
            +
                    in_channels=in_channels,
         | 
| 200 | 
            +
                    out_channels=out_channels,
         | 
| 201 | 
            +
                    kernel_size=kernel_size,
         | 
| 202 | 
            +
                    stride=stride,
         | 
| 203 | 
            +
                    padding=padding,
         | 
| 204 | 
            +
                    bias=True,
         | 
| 205 | 
            +
                )
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                # Example input tensor
         | 
| 208 | 
            +
                test_input = torch.randn(1, 3, 10, 10, 10)
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                # Perform forward passes with both 3D and 2D settings
         | 
| 211 | 
            +
                output_conv3d = dual_conv3d(test_input, use_conv3d=True)
         | 
| 212 | 
            +
                output_2d = dual_conv3d(test_input, use_conv3d=False)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                # Assert that the outputs from both methods are sufficiently close
         | 
| 215 | 
            +
                assert torch.allclose(
         | 
| 216 | 
            +
                    output_conv3d, output_2d, atol=1e-6
         | 
| 217 | 
            +
                ), "Outputs are not consistent between 3D and 2D convolutions."
         | 
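As a quick aside (not part of the uploaded file): the factorized (2+1)D layout above stores a 1 x kH x kW spatial kernel plus a kT x 1 x 1 temporal kernel instead of a dense kT x kH x kW one, which noticeably reduces the parameter count. A minimal sketch comparing it against nn.Conv3d, assuming the package is importable from this path:

import torch.nn as nn

from ltx_video.models.autoencoders.dual_conv3d import DualConv3d

# Factorized (1x3x3 followed by 3x1x1) kernels vs. a dense 3x3x3 kernel.
dual = DualConv3d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
dense = nn.Conv3d(128, 128, kernel_size=3, padding=1)

print(sum(p.numel() for p in dual.parameters()))   # 196,864
print(sum(p.numel() for p in dense.parameters()))  # 442,496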
    	
ltx_video/models/autoencoders/latent_upsampler.py
ADDED
@@ -0,0 +1,203 @@
from typing import Optional, Union
from pathlib import Path
import os
import json

import torch
import torch.nn as nn
from einops import rearrange
from diffusers import ConfigMixin, ModelMixin
from safetensors.torch import safe_open

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND


class ResBlock(nn.Module):
    def __init__(
        self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
    ):
        super().__init__()
        if mid_channels is None:
            mid_channels = channels

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(32, mid_channels)
        self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)
        self.activation = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activation(x + residual)
        return x


class LatentUpsampler(ModelMixin, ConfigMixin):
    """
    Model to spatially upsample VAE latents.

    Args:
        in_channels (`int`): Number of channels in the input latent
        mid_channels (`int`): Number of channels in the middle layers
        num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
        dims (`int`): Number of dimensions for convolutions (2 or 3)
        spatial_upsample (`bool`): Whether to spatially upsample the latent
        temporal_upsample (`bool`): Whether to temporally upsample the latent
    """

    def __init__(
        self,
        in_channels: int = 128,
        mid_channels: int = 512,
        num_blocks_per_stage: int = 4,
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.mid_channels = mid_channels
        self.num_blocks_per_stage = num_blocks_per_stage
        self.dims = dims
        self.spatial_upsample = spatial_upsample
        self.temporal_upsample = temporal_upsample

        Conv = nn.Conv2d if dims == 2 else nn.Conv3d

        self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
        self.initial_norm = nn.GroupNorm(32, mid_channels)
        self.initial_activation = nn.SiLU()

        self.res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        if spatial_upsample and temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(2),
            )
        elif temporal_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
                PixelShuffleND(1),
            )
        else:
            raise ValueError(
                "Either spatial_upsample or temporal_upsample must be True"
            )

        self.post_upsample_res_blocks = nn.ModuleList(
            [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
        )

        self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, latent: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latent.shape

        if self.dims == 2:
            x = rearrange(latent, "b c f h w -> (b f) c h w")
            x = self.initial_conv(x)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            x = self.upsampler(x)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)
            x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
        else:
            x = self.initial_conv(latent)
            x = self.initial_norm(x)
            x = self.initial_activation(x)

            for block in self.res_blocks:
                x = block(x)

            if self.temporal_upsample:
                x = self.upsampler(x)
                x = x[:, :, 1:, :, :]
            else:
                x = rearrange(x, "b c f h w -> (b f) c h w")
                x = self.upsampler(x)
                x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)

            for block in self.post_upsample_res_blocks:
                x = block(x)

            x = self.final_conv(x)

        return x

    @classmethod
    def from_config(cls, config):
        return cls(
            in_channels=config.get("in_channels", 4),
            mid_channels=config.get("mid_channels", 128),
            num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
            dims=config.get("dims", 2),
            spatial_upsample=config.get("spatial_upsample", True),
            temporal_upsample=config.get("temporal_upsample", False),
        )

    def config(self):
        return {
            "_class_name": "LatentUpsampler",
            "in_channels": self.in_channels,
            "mid_channels": self.mid_channels,
            "num_blocks_per_stage": self.num_blocks_per_stage,
            "dims": self.dims,
            "spatial_upsample": self.spatial_upsample,
            "temporal_upsample": self.temporal_upsample,
        }

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            config = json.loads(metadata["config"])
            with torch.device("meta"):
                latent_upsampler = LatentUpsampler.from_config(config)
            latent_upsampler.load_state_dict(state_dict, assign=True)
        return latent_upsampler


if __name__ == "__main__":
    latent_upsampler = LatentUpsampler(num_blocks_per_stage=4, dims=3)
    print(latent_upsampler)
    total_params = sum(p.numel() for p in latent_upsampler.parameters())
    print(f"Total number of parameters: {total_params:,}")
    latent = torch.randn(1, 128, 9, 16, 16)
    upsampled_latent = latent_upsampler(latent)
    print(f"Upsampled latent shape: {upsampled_latent.shape}")
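For orientation (not part of the uploaded file): with the default spatial_upsample=True and the 2x pixel-shuffle factors, the upsampler doubles latent height and width while keeping the channel and frame counts, mirroring the __main__ block above. A minimal shape check, assuming the package is importable:

import torch

from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler

upsampler = LatentUpsampler(in_channels=128, mid_channels=512, dims=3)
latent = torch.randn(1, 128, 9, 16, 16)  # (batch, channels, frames, height, width)
with torch.no_grad():
    print(upsampler(latent).shape)  # torch.Size([1, 128, 9, 32, 32])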
    	
ltx_video/models/autoencoders/pixel_norm.py
ADDED
@@ -0,0 +1,12 @@
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
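A quick illustration (not part of the uploaded file): PixelNorm rescales each position so that the root-mean-square over the chosen dimension (channels by default) is approximately 1.

import torch

from ltx_video.models.autoencoders.pixel_norm import PixelNorm

x = torch.randn(2, 64, 4, 8, 8)
y = PixelNorm(dim=1)(x)
rms = torch.sqrt(torch.mean(y**2, dim=1))
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-4))  # True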
    	
ltx_video/models/autoencoders/pixel_shuffle.py
ADDED
@@ -0,0 +1,33 @@
import torch.nn as nn
from einops import rearrange


class PixelShuffleND(nn.Module):
    def __init__(self, dims, upscale_factors=(2, 2, 2)):
        super().__init__()
        assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
        self.dims = dims
        self.upscale_factors = upscale_factors

    def forward(self, x):
        if self.dims == 3:
            return rearrange(
                x,
                "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
                p3=self.upscale_factors[2],
            )
        elif self.dims == 2:
            return rearrange(
                x,
                "b (c p1 p2) h w -> b c (h p1) (w p2)",
                p1=self.upscale_factors[0],
                p2=self.upscale_factors[1],
            )
        elif self.dims == 1:
            return rearrange(
                x,
                "b (c p1) f h w -> b c (f p1) h w",
                p1=self.upscale_factors[0],
            )
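One more illustrative check (not part of the uploaded file): for dims=2 the einops rearrange above is equivalent to torch.nn.PixelShuffle with an upscale factor of 2, and the input channel count must be divisible by 4.

import torch
import torch.nn as nn

from ltx_video.models.autoencoders.pixel_shuffle import PixelShuffleND

x = torch.randn(1, 16, 8, 8)  # 16 channels -> 4 channels after the 2x2 shuffle
print(torch.equal(PixelShuffleND(2)(x), nn.PixelShuffle(2)(x)))  # True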
    	
ltx_video/models/autoencoders/vae.py
ADDED
@@ -0,0 +1,380 @@
from typing import Optional, Union

import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
    DecoderOutput,
    DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd


class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.

    Args:
        encoder (`nn.Module`):
            Encoder module.
        decoder (`nn.Module`):
            Decoder module.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # pass init params to Decoder
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder
        if use_quant_conv:
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling during VAE decoding.

        When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."

    def disable_z_tiling(self):
        r"""
        Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE decoding along the height and width dimension.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE decoding along the height and width dimension.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        return moments

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile, target_shape=tile_target_shape)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
| 284 | 
            +
                    return z
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
         | 
| 287 | 
            +
                    h = self.encoder(x)
         | 
| 288 | 
            +
                    moments = self.quant_conv(h)
         | 
| 289 | 
            +
                    moments = self._normalize_latent_channels(moments)
         | 
| 290 | 
            +
                    return moments
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                def _decode(
         | 
| 293 | 
            +
                    self,
         | 
| 294 | 
            +
                    z: torch.FloatTensor,
         | 
| 295 | 
            +
                    target_shape=None,
         | 
| 296 | 
            +
                    timestep: Optional[torch.Tensor] = None,
         | 
| 297 | 
            +
                ) -> Union[DecoderOutput, torch.FloatTensor]:
         | 
| 298 | 
            +
                    z = self._unnormalize_latent_channels(z)
         | 
| 299 | 
            +
                    z = self.post_quant_conv(z)
         | 
| 300 | 
            +
                    if "timestep" in self.decoder_params:
         | 
| 301 | 
            +
                        dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
         | 
| 302 | 
            +
                    else:
         | 
| 303 | 
            +
                        dec = self.decoder(z, target_shape=target_shape)
         | 
| 304 | 
            +
                    return dec
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                def decode(
         | 
| 307 | 
            +
                    self,
         | 
| 308 | 
            +
                    z: torch.FloatTensor,
         | 
| 309 | 
            +
                    return_dict: bool = True,
         | 
| 310 | 
            +
                    target_shape=None,
         | 
| 311 | 
            +
                    timestep: Optional[torch.Tensor] = None,
         | 
| 312 | 
            +
                ) -> Union[DecoderOutput, torch.FloatTensor]:
         | 
| 313 | 
            +
                    assert target_shape is not None, "target_shape must be provided for decoding"
         | 
| 314 | 
            +
                    if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
         | 
| 315 | 
            +
                        reduction_factor = int(
         | 
| 316 | 
            +
                            self.encoder.patch_size_t
         | 
| 317 | 
            +
                            * 2
         | 
| 318 | 
            +
                            ** (
         | 
| 319 | 
            +
                                len(self.encoder.down_blocks)
         | 
| 320 | 
            +
                                - 1
         | 
| 321 | 
            +
                                - math.sqrt(self.encoder.patch_size)
         | 
| 322 | 
            +
                            )
         | 
| 323 | 
            +
                        )
         | 
| 324 | 
            +
                        split_size = self.z_sample_size // reduction_factor
         | 
| 325 | 
            +
                        num_splits = z.shape[2] // split_size
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                        # copy target shape, and divide frame dimension (=2) by the context size
         | 
| 328 | 
            +
                        target_shape_split = list(target_shape)
         | 
| 329 | 
            +
                        target_shape_split[2] = target_shape[2] // num_splits
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                        decoded_tiles = [
         | 
| 332 | 
            +
                            (
         | 
| 333 | 
            +
                                self._hw_tiled_decode(z_tile, target_shape_split)
         | 
| 334 | 
            +
                                if self.use_hw_tiling
         | 
| 335 | 
            +
                                else self._decode(z_tile, target_shape=target_shape_split)
         | 
| 336 | 
            +
                            )
         | 
| 337 | 
            +
                            for z_tile in torch.tensor_split(z, num_splits, dim=2)
         | 
| 338 | 
            +
                        ]
         | 
| 339 | 
            +
                        decoded = torch.cat(decoded_tiles, dim=2)
         | 
| 340 | 
            +
                    else:
         | 
| 341 | 
            +
                        decoded = (
         | 
| 342 | 
            +
                            self._hw_tiled_decode(z, target_shape)
         | 
| 343 | 
            +
                            if self.use_hw_tiling
         | 
| 344 | 
            +
                            else self._decode(z, target_shape=target_shape, timestep=timestep)
         | 
| 345 | 
            +
                        )
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    if not return_dict:
         | 
| 348 | 
            +
                        return (decoded,)
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    return DecoderOutput(sample=decoded)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def forward(
         | 
| 353 | 
            +
                    self,
         | 
| 354 | 
            +
                    sample: torch.FloatTensor,
         | 
| 355 | 
            +
                    sample_posterior: bool = False,
         | 
| 356 | 
            +
                    return_dict: bool = True,
         | 
| 357 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 358 | 
            +
                ) -> Union[DecoderOutput, torch.FloatTensor]:
         | 
| 359 | 
            +
                    r"""
         | 
| 360 | 
            +
                    Args:
         | 
| 361 | 
            +
                        sample (`torch.FloatTensor`): Input sample.
         | 
| 362 | 
            +
                        sample_posterior (`bool`, *optional*, defaults to `False`):
         | 
| 363 | 
            +
                            Whether to sample from the posterior.
         | 
| 364 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 365 | 
            +
                            Whether to return a [`DecoderOutput`] instead of a plain tuple.
         | 
| 366 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 367 | 
            +
                            Generator used to sample from the posterior.
         | 
| 368 | 
            +
                    """
         | 
| 369 | 
            +
                    x = sample
         | 
| 370 | 
            +
                    posterior = self.encode(x).latent_dist
         | 
| 371 | 
            +
                    if sample_posterior:
         | 
| 372 | 
            +
                        z = posterior.sample(generator=generator)
         | 
| 373 | 
            +
                    else:
         | 
| 374 | 
            +
                        z = posterior.mode()
         | 
| 375 | 
            +
                    dec = self.decode(z, target_shape=sample.shape).sample
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                    if not return_dict:
         | 
| 378 | 
            +
                        return (dec,)
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    return DecoderOutput(sample=dec)
         | 
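For readers skimming the tiled decode above: `blend_v`/`blend_h` cross-fade the overlapping border of adjacent tiles before the tiles are cropped and concatenated. The following self-contained sketch (not code from this upload; the helper name, shapes, and overlap size are illustrative) shows the linear cross-fade idea for the vertical case:

import torch

def blend_v_sketch(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # a = tile above, b = current tile, both shaped (B, C, F, H, W); illustrative only.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    b = b.clone()
    for y in range(blend_extent):
        w = y / blend_extent  # 0 -> keep the bottom rows of `a`, -> 1 keep the top rows of `b`
        b[:, :, :, y, :] = (
            a[:, :, :, a.shape[3] - blend_extent + y, :] * (1 - w) + b[:, :, :, y, :] * w
        )
    return b

a = torch.zeros(1, 3, 2, 8, 8)
b = torch.ones(1, 3, 2, 8, 8)
print(blend_v_sketch(a, b, blend_extent=4)[0, 0, 0, :4, 0])  # ramps 0.00, 0.25, 0.50, 0.75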
    	
        ltx_video/models/autoencoders/vae_encode.py
    ADDED
    
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor


from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.video_autoencoder import (
    Downsample3D,
    VideoAutoencoder,
)

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
            to match the input shape, scaled by the model's configuration.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(images, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        In case of a video, the function encodes the media item frame by frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(media_items) // split_size
        # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split'"
            )
        encode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(
                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
            )
            for latent_batch in latents.split(encode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )

    if is_video_shaped and not isinstance(
        vae, (VideoAutoencoder, CausalVideoAutoencoder)
    ):
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        vae_decode_kwargs = {}
        if timestep is not None:
            vae_decode_kwargs["timestep"] = timestep
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
            **vae_decode_kwargs,
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
        )[0]
    return image

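As a quick sanity check of the `target_shape` that `_run_decoder` passes to the VAE, here is a worked example; the 8x temporal and 32x spatial factors are illustrative assumptions, not values read from a checkpoint:

# Worked example of the target_shape arithmetic above (assumed scale factors).
fl, hl, wl = 3, 16, 24                     # latent frames / height / width
temporal_scale, spatial_scale = 8, 32      # illustrative downscale factors
target_shape = (1, 3, fl * temporal_scale, hl * spatial_scale, wl * spatial_scale)
print(target_shape)                        # (1, 3, 24, 512, 768)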
| 168 | 
            +
            def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
         | 
| 169 | 
            +
                if isinstance(vae, CausalVideoAutoencoder):
         | 
| 170 | 
            +
                    spatial = vae.spatial_downscale_factor
         | 
| 171 | 
            +
                    temporal = vae.temporal_downscale_factor
         | 
| 172 | 
            +
                else:
         | 
| 173 | 
            +
                    down_blocks = len(
         | 
| 174 | 
            +
                        [
         | 
| 175 | 
            +
                            block
         | 
| 176 | 
            +
                            for block in vae.encoder.down_blocks
         | 
| 177 | 
            +
                            if isinstance(block.downsample, Downsample3D)
         | 
| 178 | 
            +
                        ]
         | 
| 179 | 
            +
                    )
         | 
| 180 | 
            +
                    spatial = vae.config.patch_size * 2**down_blocks
         | 
| 181 | 
            +
                    temporal = (
         | 
| 182 | 
            +
                        vae.config.patch_size_t * 2**down_blocks
         | 
| 183 | 
            +
                        if isinstance(vae, VideoAutoencoder)
         | 
| 184 | 
            +
                        else 1
         | 
| 185 | 
            +
                    )
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                return (temporal, spatial, spatial)
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            def latent_to_pixel_coords(
         | 
| 191 | 
            +
                latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
         | 
| 192 | 
            +
            ) -> Tensor:
         | 
| 193 | 
            +
                """
         | 
| 194 | 
            +
                Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
         | 
| 195 | 
            +
                configuration.
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                Args:
         | 
| 198 | 
            +
                    latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
         | 
| 199 | 
            +
                    containing the latent corner coordinates of each token.
         | 
| 200 | 
            +
                    vae (AutoencoderKL): The VAE model
         | 
| 201 | 
            +
                    causal_fix (bool): Whether to take into account the different temporal scale
         | 
| 202 | 
            +
                        of the first frame. Default = False for backwards compatibility.
         | 
| 203 | 
            +
                Returns:
         | 
| 204 | 
            +
                    Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
         | 
| 205 | 
            +
                """
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                scale_factors = get_vae_size_scale_factor(vae)
         | 
| 208 | 
            +
                causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
         | 
| 209 | 
            +
                pixel_coords = latent_to_pixel_coords_from_factors(
         | 
| 210 | 
            +
                    latent_coords, scale_factors, causal_fix
         | 
| 211 | 
            +
                )
         | 
| 212 | 
            +
                return pixel_coords
         | 
| 213 | 
            +
             | 
| 214 | 
            +
             | 
| 215 | 
            +
            def latent_to_pixel_coords_from_factors(
         | 
| 216 | 
            +
                latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
         | 
| 217 | 
            +
            ) -> Tensor:
         | 
| 218 | 
            +
                pixel_coords = (
         | 
| 219 | 
            +
                    latent_coords
         | 
| 220 | 
            +
                    * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
         | 
| 221 | 
            +
                )
         | 
| 222 | 
            +
                if causal_fix:
         | 
| 223 | 
            +
                    # Fix temporal scale for first frame to 1 due to causality
         | 
| 224 | 
            +
                    pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
         | 
| 225 | 
            +
                return pixel_coords
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def normalize_latents(
         | 
| 229 | 
            +
                latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
         | 
| 230 | 
            +
            ) -> Tensor:
         | 
| 231 | 
            +
                return (
         | 
| 232 | 
            +
                    (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
         | 
| 233 | 
            +
                    / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
         | 
| 234 | 
            +
                    if vae_per_channel_normalize
         | 
| 235 | 
            +
                    else latents * vae.config.scaling_factor
         | 
| 236 | 
            +
                )
         | 
| 237 | 
            +
             | 
| 238 | 
            +
             | 
| 239 | 
            +
            def un_normalize_latents(
         | 
| 240 | 
            +
                latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
         | 
| 241 | 
            +
            ) -> Tensor:
         | 
| 242 | 
            +
                return (
         | 
| 243 | 
            +
                    latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
         | 
| 244 | 
            +
                    + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
         | 
| 245 | 
            +
                    if vae_per_channel_normalize
         | 
| 246 | 
            +
                    else latents / vae.config.scaling_factor
         | 
| 247 | 
            +
                )
         | 
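The `causal_fix` branch in `latent_to_pixel_coords_from_factors` accounts for the first latent frame covering only a single pixel frame in a causal VAE. A small self-contained example, with illustrative scale factors:

import torch

latent_coords = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 0, 0]]])  # [batch, 3, num_latents]
scale_factors = (8, 32, 32)                                        # (temporal, height, width), assumed
pixel = latent_coords * torch.tensor(scale_factors)[None, :, None]
pixel[:, 0] = (pixel[:, 0] + 1 - scale_factors[0]).clamp(min=0)
print(pixel[0, 0])  # tensor([0, 1, 9]): frame starts shift from 0, 8, 16 to 0, 1, 9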
    	
        ltx_video/models/autoencoders/video_autoencoder.py
    ADDED
    
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import nn
from torch.nn import functional

from diffusers.utils import logging

from ltx_video.utils.torch_utils import Identity
from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from ltx_video.models.autoencoders.pixel_norm import PixelNorm
from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper

logger = logging.get_logger(__name__)


class VideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        config_local_path = pretrained_model_name_or_path / "config.json"
        config = cls.load_config(config_local_path, **kwargs)
        video_vae = cls.from_config(config)
        video_vae.to(kwargs["torch_dtype"])

        model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
        ckpt_state_dict = torch.load(model_local_path)
        video_vae.load_state_dict(ckpt_state_dict)

        statistics_local_path = (
            pretrained_model_name_or_path / "per_channel_statistics.json"
        )
        if statistics_local_path.exists():
            with open(statistics_local_path, "r") as file:
                data = json.load(file)
            transposed_data = list(zip(*data["data"]))
            data_dict = {
                col: torch.tensor(vals)
                for col, vals in zip(data["columns"], transposed_data)
            }
            video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
            video_vae.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "VideoAutoencoder"
        ), "config must have _class_name=VideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)

        if use_quant_conv and latent_log_var == "uniform":
            raise ValueError("uniform latent_log_var requires use_quant_conv=False")

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            block_out_channels=config["block_out_channels"],
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
            add_channel_padding=config.get("add_channel_padding", False),
        )

        dims = config["dims"]
        return VideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="VideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels
            // (self.encoder.patch_size_t * self.encoder.patch_size**2),
            out_channels=self.decoder.conv_out.out_channels
            // (self.decoder.patch_size_t * self.decoder.patch_size**2),
            latent_channels=self.decoder.conv_in.in_channels,
            block_out_channels=[
                self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
                for i in range(len(self.encoder.down_blocks))
            ],
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            patch_size_t=self.encoder.patch_size_t,
            add_channel_padding=self.encoder.add_channel_padding,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def downscale_factor(self):
        return self.encoder.downsample_factor

    def to_json_string(self) -> str:
        import json

        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        model_keys = set(name for name, _ in self.named_parameters())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }

        converted_state_dict = {}
        for key, value in state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            if "norm" in key and key not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        if add_channel_padding:
            in_channels = in_channels * self.patch_size**3
        else:
            in_channels = in_channels * self.patch_size_t * self.patch_size**2
        self.in_channels = in_channels
        output_channel = block_out_channels[0]
| 236 | 
            +
             | 
| 237 | 
            +
                    self.conv_in = make_conv_nd(
         | 
| 238 | 
            +
                        dims=dims,
         | 
| 239 | 
            +
                        in_channels=in_channels,
         | 
| 240 | 
            +
                        out_channels=output_channel,
         | 
| 241 | 
            +
                        kernel_size=3,
         | 
| 242 | 
            +
                        stride=1,
         | 
| 243 | 
            +
                        padding=1,
         | 
| 244 | 
            +
                    )
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    self.down_blocks = nn.ModuleList([])
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    for i in range(len(block_out_channels)):
         | 
| 249 | 
            +
                        input_channel = output_channel
         | 
| 250 | 
            +
                        output_channel = block_out_channels[i]
         | 
| 251 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                        down_block = DownEncoderBlock3D(
         | 
| 254 | 
            +
                            dims=dims,
         | 
| 255 | 
            +
                            in_channels=input_channel,
         | 
| 256 | 
            +
                            out_channels=output_channel,
         | 
| 257 | 
            +
                            num_layers=self.layers_per_block,
         | 
| 258 | 
            +
                            add_downsample=not is_final_block and 2**i >= patch_size,
         | 
| 259 | 
            +
                            resnet_eps=1e-6,
         | 
| 260 | 
            +
                            downsample_padding=0,
         | 
| 261 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 262 | 
            +
                            norm_layer=norm_layer,
         | 
| 263 | 
            +
                        )
         | 
| 264 | 
            +
                        self.down_blocks.append(down_block)
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                    self.mid_block = UNetMidBlock3D(
         | 
| 267 | 
            +
                        dims=dims,
         | 
| 268 | 
            +
                        in_channels=block_out_channels[-1],
         | 
| 269 | 
            +
                        num_layers=self.layers_per_block,
         | 
| 270 | 
            +
                        resnet_eps=1e-6,
         | 
| 271 | 
            +
                        resnet_groups=norm_num_groups,
         | 
| 272 | 
            +
                        norm_layer=norm_layer,
         | 
| 273 | 
            +
                    )
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    # out
         | 
| 276 | 
            +
                    if norm_layer == "group_norm":
         | 
| 277 | 
            +
                        self.conv_norm_out = nn.GroupNorm(
         | 
| 278 | 
            +
                            num_channels=block_out_channels[-1],
         | 
| 279 | 
            +
                            num_groups=norm_num_groups,
         | 
| 280 | 
            +
                            eps=1e-6,
         | 
| 281 | 
            +
                        )
         | 
| 282 | 
            +
                    elif norm_layer == "pixel_norm":
         | 
| 283 | 
            +
                        self.conv_norm_out = PixelNorm()
         | 
| 284 | 
            +
                    self.conv_act = nn.SiLU()
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    conv_out_channels = out_channels
         | 
| 287 | 
            +
                    if latent_log_var == "per_channel":
         | 
| 288 | 
            +
                        conv_out_channels *= 2
         | 
| 289 | 
            +
                    elif latent_log_var == "uniform":
         | 
| 290 | 
            +
                        conv_out_channels += 1
         | 
| 291 | 
            +
                    elif latent_log_var != "none":
         | 
| 292 | 
            +
                        raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
         | 
| 293 | 
            +
                    self.conv_out = make_conv_nd(
         | 
| 294 | 
            +
                        dims, block_out_channels[-1], conv_out_channels, 3, padding=1
         | 
| 295 | 
            +
                    )
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                    self.gradient_checkpointing = False
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                @property
         | 
| 300 | 
            +
                def downscale_factor(self):
         | 
| 301 | 
            +
                    return (
         | 
| 302 | 
            +
                        2
         | 
| 303 | 
            +
                        ** len(
         | 
| 304 | 
            +
                            [
         | 
| 305 | 
            +
                                block
         | 
| 306 | 
            +
                                for block in self.down_blocks
         | 
| 307 | 
            +
                                if isinstance(block.downsample, Downsample3D)
         | 
| 308 | 
            +
                            ]
         | 
| 309 | 
            +
                        )
         | 
| 310 | 
            +
                        * self.patch_size
         | 
| 311 | 
            +
                    )
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                def forward(
         | 
| 314 | 
            +
                    self, sample: torch.FloatTensor, return_features=False
         | 
| 315 | 
            +
                ) -> torch.FloatTensor:
         | 
| 316 | 
            +
                    r"""The forward method of the `Encoder` class."""
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                    downsample_in_time = sample.shape[2] != 1
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                    # patchify
         | 
| 321 | 
            +
                    patch_size_t = self.patch_size_t if downsample_in_time else 1
         | 
| 322 | 
            +
                    sample = patchify(
         | 
| 323 | 
            +
                        sample,
         | 
| 324 | 
            +
                        patch_size_hw=self.patch_size,
         | 
| 325 | 
            +
                        patch_size_t=patch_size_t,
         | 
| 326 | 
            +
                        add_channel_padding=self.add_channel_padding,
         | 
| 327 | 
            +
                    )
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    sample = self.conv_in(sample)
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    checkpoint_fn = (
         | 
| 332 | 
            +
                        partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
         | 
| 333 | 
            +
                        if self.gradient_checkpointing and self.training
         | 
| 334 | 
            +
                        else lambda x: x
         | 
| 335 | 
            +
                    )
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    if return_features:
         | 
| 338 | 
            +
                        features = []
         | 
| 339 | 
            +
                    for down_block in self.down_blocks:
         | 
| 340 | 
            +
                        sample = checkpoint_fn(down_block)(
         | 
| 341 | 
            +
                            sample, downsample_in_time=downsample_in_time
         | 
| 342 | 
            +
                        )
         | 
| 343 | 
            +
                        if return_features:
         | 
| 344 | 
            +
                            features.append(sample)
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                    sample = checkpoint_fn(self.mid_block)(sample)
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    # post-process
         | 
| 349 | 
            +
                    sample = self.conv_norm_out(sample)
         | 
| 350 | 
            +
                    sample = self.conv_act(sample)
         | 
| 351 | 
            +
                    sample = self.conv_out(sample)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    if self.latent_log_var == "uniform":
         | 
| 354 | 
            +
                        last_channel = sample[:, -1:, ...]
         | 
| 355 | 
            +
                        num_dims = sample.dim()
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                        if num_dims == 4:
         | 
| 358 | 
            +
                            # For shape (B, C, H, W)
         | 
| 359 | 
            +
                            repeated_last_channel = last_channel.repeat(
         | 
| 360 | 
            +
                                1, sample.shape[1] - 2, 1, 1
         | 
| 361 | 
            +
                            )
         | 
| 362 | 
            +
                            sample = torch.cat([sample, repeated_last_channel], dim=1)
         | 
| 363 | 
            +
                        elif num_dims == 5:
         | 
| 364 | 
            +
                            # For shape (B, C, F, H, W)
         | 
| 365 | 
            +
                            repeated_last_channel = last_channel.repeat(
         | 
| 366 | 
            +
                                1, sample.shape[1] - 2, 1, 1, 1
         | 
| 367 | 
            +
                            )
         | 
| 368 | 
            +
                            sample = torch.cat([sample, repeated_last_channel], dim=1)
         | 
| 369 | 
            +
                        else:
         | 
| 370 | 
            +
                            raise ValueError(f"Invalid input shape: {sample.shape}")
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    if return_features:
         | 
| 373 | 
            +
                        features.append(sample[:, : self.latent_channels, ...])
         | 
| 374 | 
            +
                        return sample, features
         | 
| 375 | 
            +
                    return sample
         | 
| 376 | 
            +
             | 
| 377 | 
            +
             | 
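# Illustrative sketch only: a toy Encoder with hypothetical sizes (the shipped
# checkpoints use the config helpers at the bottom of this file). It shows how
# the last block never downsamples, how `downscale_factor` is derived, and how
# "per_channel" doubles the output channels into mean + log-variance.
def _example_encoder_shapes():
    import torch

    encoder = Encoder(
        dims=(2, 1),
        in_channels=3,
        out_channels=4,
        block_out_channels=(32, 64),
        layers_per_block=1,
        norm_num_groups=8,
        patch_size=1,
        latent_log_var="per_channel",
    )
    video = torch.randn(1, 3, 8, 32, 32)  # (B, C, F, H, W)
    latent = encoder(video)
    assert encoder.downscale_factor == 2  # one Downsample3D block, patch_size 1
    # Expected roughly (1, 8, 4, 16, 16): halved in time and space,
    # 4 mean + 4 log-variance channels stacked on the channel axis.
    return latent.shape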
class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
    """

    def __init__(
        self,
        dims,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        patch_size: int = 1,
        norm_layer: str = "group_norm",
        patch_size_t: Optional[int] = None,
        add_channel_padding: Optional[bool] = False,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
        self.add_channel_padding = add_channel_padding
        self.layers_per_block = layers_per_block
        if add_channel_padding:
            out_channels = out_channels * self.patch_size**3
        else:
            out_channels = out_channels * self.patch_size_t * self.patch_size**2
        self.out_channels = out_channels

        self.conv_in = make_conv_nd(
            dims,
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        self.mid_block = UNetMidBlock3D(
            dims=dims,
            in_channels=block_out_channels[-1],
            num_layers=self.layers_per_block,
            resnet_eps=1e-6,
            resnet_groups=norm_num_groups,
            norm_layer=norm_layer,
        )

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = UpDecoderBlock3D(
                dims=dims,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block
                and 2 ** (len(block_out_channels) - i - 1) > patch_size,
                resnet_eps=1e-6,
                resnet_groups=norm_num_groups,
                norm_layer=norm_layer,
            )
            self.up_blocks.append(up_block)

        if norm_layer == "group_norm":
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        elif norm_layer == "pixel_norm":
            self.conv_norm_out = PixelNorm()

        self.conv_act = nn.SiLU()
        self.conv_out = make_conv_nd(
            dims, block_out_channels[0], out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""
        assert target_shape is not None, "target_shape must be provided"
        upsample_in_time = sample.shape[2] < target_shape[2]

        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        checkpoint_fn = (
            partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
            if self.gradient_checkpointing and self.training
            else lambda x: x
        )

        sample = checkpoint_fn(self.mid_block)(sample)
        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # un-patchify
        patch_size_t = self.patch_size_t if upsample_in_time else 1
        sample = unpatchify(
            sample,
            patch_size_hw=self.patch_size,
            patch_size_t=patch_size_t,
            add_channel_padding=self.add_channel_padding,
        )

        return sample
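# Illustrative sketch only, mirroring the toy Encoder example above with
# hypothetical sizes. It shows that `target_shape` is used solely to decide
# whether the decoder should also upsample along the frame axis.
def _example_decoder_shapes():
    import torch

    decoder = Decoder(
        dims=(2, 1),
        in_channels=4,
        out_channels=3,
        block_out_channels=(32, 64),
        layers_per_block=1,
        norm_num_groups=8,
        patch_size=1,
    )
    latent = torch.randn(1, 4, 4, 16, 16)  # (B, latent_C, F, H, W)
    # 8 target frames > 4 latent frames, so temporal upsampling is enabled.
    video = decoder(latent, target_shape=(1, 3, 8, 32, 32))
    return video.shape  # expected roughly (1, 3, 8, 32, 32)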
class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_downsample: bool = True,
        downsample_padding: int = 1,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_downsample:
            self.downsample = Downsample3D(
                dims,
                out_channels,
                out_channels=out_channels,
                padding=downsample_padding,
            )
        else:
            self.downsample = Identity()

    def forward(
        self, hidden_states: torch.FloatTensor, downsample_in_time
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.downsample(
            hidden_states, downsample_in_time=downsample_in_time
        )

        return hidden_states
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        return hidden_states
class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_upsample: bool = True,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_upsample:
            self.upsample = Upsample3D(
                dims=dims, channels=out_channels, out_channels=out_channels
            )
        else:
            self.upsample = Identity()

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, upsample_in_time=True
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)

        return hidden_states
class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, defaults to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_layer == "group_norm":
            self.norm1 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if norm_layer == "group_norm":
            self.norm2 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.non_linearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
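# Illustrative sketch only (hypothetical toy sizes): when in_channels differs
# from out_channels, the residual path is projected by a pointwise linear
# convolution so the skip connection can be added to the block output.
def _example_resnet_block_shortcut():
    import torch

    block = ResnetBlock3D(dims=(2, 1), in_channels=32, out_channels=64, groups=8)
    x = torch.randn(1, 32, 4, 16, 16)
    y = block(x)
    return y.shape  # expected (1, 64, 4, 16, 16)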
class Downsample3D(nn.Module):
    def __init__(
        self,
        dims,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        conv = self.conv
        if self.padding == 0:
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)

            x = functional.pad(x, padding, mode="constant", value=0)

            if self.dims == (2, 1) and not downsample_in_time:
                return conv(x, skip_time_conv=True)

        return conv(x)
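# Illustrative sketch only: with padding=0 (the value the Encoder passes in),
# the input is padded asymmetrically on the right/bottom/end before the
# stride-2 convolution, so spatial and temporal sizes halve cleanly.
def _example_downsample_padding():
    import torch

    down = Downsample3D(dims=(2, 1), in_channels=16, out_channels=16, padding=0)
    x = torch.randn(1, 16, 8, 32, 32)
    y = down(x, downsample_in_time=True)
    return y.shape  # expected roughly (1, 16, 4, 16, 16)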
class Upsample3D(nn.Module):
    """
    An upsampling layer for 3D tensors of shape (B, C, D, H, W).

    :param channels: channels in the inputs and outputs.
    """

    def __init__(self, dims, channels, out_channels=None):
        super().__init__()
        self.dims = dims
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = make_conv_nd(
            dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x, upsample_in_time):
        if self.dims == 2:
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
        else:
            time_scale_factor = 2 if upsample_in_time else 1
            # print("before:", x.shape)
            b, c, d, h, w = x.shape
            x = rearrange(x, "b c d h w -> (b d) c h w")
            # height and width interpolate
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
            _, _, h, w = x.shape

            if not upsample_in_time and self.dims == (2, 1):
                x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
                return self.conv(x, skip_time_conv=True)

            # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
            x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)

            # (b h w) c 1 d
            new_d = x.shape[-1] * time_scale_factor
            x = functional.interpolate(x, (1, new_d), mode="nearest")
            # (b h w) c 1 new_d
            x = rearrange(
                x, "(b h w) c 1 new_d  -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
            )
            # b c d h w

            # x = functional.interpolate(
            #     x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            # )
            # print("after:", x.shape)

        return self.conv(x)
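# Illustrative sketch only: spatial upsampling is done per-frame with 2x
# nearest-neighbor interpolation, then (optionally) the frame axis is
# interpolated in a separate 1D pass before the final convolution.
def _example_upsample_shapes():
    import torch

    up = Upsample3D(dims=(2, 1), channels=16, out_channels=16)
    x = torch.randn(1, 16, 4, 16, 16)
    y = up(x, upsample_in_time=True)
    return y.shape  # expected roughly (1, 16, 8, 32, 32)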
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    if (
        (x.dim() == 5)
        and (patch_size_hw > patch_size_t)
        and (patch_size_t > 1 or add_channel_padding)
    ):
        channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
        padding_zeros = torch.zeros(
            x.shape[0],
            channels_to_pad,
            x.shape[2],
            x.shape[3],
            x.shape[4],
            device=x.device,
            dtype=x.dtype,
        )
        x = torch.cat([padding_zeros, x], dim=1)

    return x
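# Illustrative sketch only: a 2x spatial / 2x temporal patch folds pixels into
# channels, so (B, C, F, H, W) becomes (B, C*p*q*r, F/p, H/q, W/r).
def _example_patchify_shape():
    import torch

    x = torch.randn(1, 3, 8, 32, 32)
    y = patchify(x, patch_size_hw=2, patch_size_t=2)
    return y.shape  # (1, 24, 4, 16, 16)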
| 906 | 
            +
            def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
         | 
| 907 | 
            +
                if patch_size_hw == 1 and patch_size_t == 1:
         | 
| 908 | 
            +
                    return x
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                if (
         | 
| 911 | 
            +
                    (x.dim() == 5)
         | 
| 912 | 
            +
                    and (patch_size_hw > patch_size_t)
         | 
| 913 | 
            +
                    and (patch_size_t > 1 or add_channel_padding)
         | 
| 914 | 
            +
                ):
         | 
| 915 | 
            +
                    channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
         | 
| 916 | 
            +
                    x = x[:, :channels_to_keep, :, :, :]
         | 
| 917 | 
            +
             | 
| 918 | 
            +
                if x.dim() == 4:
         | 
| 919 | 
            +
                    x = rearrange(
         | 
| 920 | 
            +
                        x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
         | 
| 921 | 
            +
                    )
         | 
| 922 | 
            +
                elif x.dim() == 5:
         | 
| 923 | 
            +
                    x = rearrange(
         | 
| 924 | 
            +
                        x,
         | 
| 925 | 
            +
                        "b (c p r q) f h w -> b c (f p) (h q) (w r)",
         | 
| 926 | 
            +
                        p=patch_size_t,
         | 
| 927 | 
            +
                        q=patch_size_hw,
         | 
| 928 | 
            +
                        r=patch_size_hw,
         | 
| 929 | 
            +
                    )
         | 
| 930 | 
            +
             | 
| 931 | 
            +
                return x
         | 
| 932 | 
            +
             | 
| 933 | 
            +
             | 
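
# Shape sketch for the patchify/unpatchify pair above (illustrative only; assumes torch
# and einops are available, as imported at the top of this file). With patch_size_hw=4
# and patch_size_t=4, a (B, C, F, H, W) video is folded into (B, C*4*4*4, F/4, H/4, W/4),
# and unpatchify inverts the folding exactly:
def _patchify_shape_sketch():
    x = torch.randn(2, 3, 8, 64, 64)  # (batch, channels, frames, height, width)
    y = patchify(x, patch_size_hw=4, patch_size_t=4)
    assert y.shape == (2, 3 * 4 * 4 * 4, 2, 16, 16)  # channels absorb the 4x4x4 patch
    assert torch.allclose(unpatchify(y, patch_size_hw=4, patch_size_t=4), x)
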

def create_video_autoencoder_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [
            128,
            256,
            512,
            512,
        ],  # Number of output channels of each encoder / decoder inner block
        "patch_size": 1,
    }

    return config


def create_video_autoencoder_pathify4x4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "latent_log_var": "uniform",
    }

    return config


def create_video_autoencoder_pathify4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": 2,  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "norm_layer": "pixel_norm",
    }

    return config

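
# A minimal sketch showing how the three factory functions above differ (they vary only in
# convolution dims, patching, and normalization-related keys); purely illustrative.
def _print_autoencoder_config_variants():
    for make_config in (
        create_video_autoencoder_config,
        create_video_autoencoder_pathify4x4x4_config,
        create_video_autoencoder_pathify4x4_config,
    ):
        cfg = make_config()
        print(
            make_config.__name__,
            {k: cfg[k] for k in ("dims", "patch_size", "block_out_channels")},
        )
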

def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_pathify4x4x4_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = VideoAutoencoder.from_config(config)

    print(video_autoencoder)

    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 8, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Ensure the demo function is called to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
    	
ltx_video/models/transformers/__init__.py
ADDED
File without changes

ltx_video/models/transformers/attention.py
ADDED
@@ -0,0 +1,1265 @@
import inspect
from importlib import import_module
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.attention import _chunked_feed_forward
from diffusers.models.attention_processor import (
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    SpatialNorm,
)
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.normalization import RMSNorm
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

try:
    from torch_xla.experimental.custom_kernel import flash_attention
except ImportError:
    # workaround for automatic tests. Currently this function is manually patched
    # to the torch_xla lib on setup of container
    pass

# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
            The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
        standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
            The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,  # pylint: disable=unused-argument
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift', 'single_scale' or 'none'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_eps: float = 1e-5,
        qk_norm: Optional[str] = None,
        final_dropout: bool = False,
        attention_type: str = "default",  # pylint: disable=unused-argument
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.adaptive_norm = adaptive_norm

        assert standardization_norm in ["layer_norm", "rms_norm"]
        assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]

        make_norm_layer = (
            nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
        )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = make_norm_layer(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
            use_tpu_flash_attention=use_tpu_flash_attention,
            qk_norm=qk_norm,
            use_rope=use_rope,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
                use_tpu_flash_attention=use_tpu_flash_attention,
                qk_norm=qk_norm,
                use_rope=use_rope,
            )  # is self-attn if encoder_hidden_states is none

            if adaptive_norm == "none":
                self.attn2_norm = make_norm_layer(
                    dim, norm_eps, norm_elementwise_affine
                )
        else:
            self.attn2 = None
            self.attn2_norm = None

        self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 5. Scale-shift for PixArt-Alpha.
        if adaptive_norm != "none":
            num_ada_params = 4 if adaptive_norm == "single_scale" else 6
            self.scale_shift_table = nn.Parameter(
                torch.randn(num_ada_params, dim) / dim**0.5
            )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
        attention kernel.
        """
        self.use_tpu_flash_attention = True
        self.attn1.set_use_tpu_flash_attention()
        self.attn2.set_use_tpu_flash_attention()

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
    ) -> torch.FloatTensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
                )

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        original_hidden_states = hidden_states

        norm_hidden_states = self.norm1(hidden_states)

        # Apply ada_norm_single
        if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
            assert timestep.ndim == 3  # [batch, 1 or num_tokens, embedding_dim]
            num_ada_params = self.scale_shift_table.shape[0]
            ada_values = self.scale_shift_table[None, None] + timestep.reshape(
                batch_size, timestep.shape[1], num_ada_params, -1
            )
            if self.adaptive_norm == "single_scale_shift":
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    ada_values.unbind(dim=2)
                )
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            else:
                scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa)
        elif self.adaptive_norm == "none":
            scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        norm_hidden_states = norm_hidden_states.squeeze(
            1
        )  # TODO: Check if this is needed

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )

        attn_output = self.attn1(
            norm_hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )
        if gate_msa is not None:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.adaptive_norm == "none":
                attn_input = self.attn2_norm(hidden_states)
            else:
                attn_input = hidden_states
            attn_output = self.attn2(
                attn_input,
                freqs_cis=freqs_cis,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        if self.adaptive_norm == "single_scale_shift":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        elif self.adaptive_norm == "single_scale":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
        elif self.adaptive_norm == "none":
            pass
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if gate_mlp is not None:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
        ):
            skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
            hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
                1.0 - skip_layer_mask
            )

        return hidden_states

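
# A minimal standalone sketch (illustrative only, plain torch) of the adaptive-norm math used in
# BasicTransformerBlock.forward above: the learned (6, dim) scale_shift_table is added to a
# per-token timestep embedding and unbound into shift/scale/gate terms for attention and MLP.
def _ada_norm_sketch():
    batch, tokens, dim = 2, 16, 32
    scale_shift_table = torch.randn(6, dim) / dim**0.5
    timestep_emb = torch.randn(batch, 1, 6 * dim)  # stand-in for the projected timestep embedding
    ada_values = scale_shift_table[None, None] + timestep_emb.reshape(batch, 1, 6, dim)
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    x = torch.randn(batch, tokens, dim)
    modulated = x * (1 + scale_msa) + shift_msa  # same modulation as the self-attention branch
    assert modulated.shape == (batch, tokens, dim)
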

@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        qk_norm: Optional[str] = None,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = (
            cross_attention_dim if cross_attention_dim is not None else query_dim
        )
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.use_rope = use_rope

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        if qk_norm is None:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        elif qk_norm == "rms_norm":
            self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
            self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
        elif qk_norm == "layer_norm":
            self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
            self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
        else:
            raise ValueError(f"Unsupported qk_norm method: {qk_norm}")

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
         | 
| 459 | 
            +
                        self.group_norm = nn.GroupNorm(
         | 
| 460 | 
            +
                            num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
         | 
| 461 | 
            +
                        )
         | 
| 462 | 
            +
                    else:
         | 
| 463 | 
            +
                        self.group_norm = None
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                    if spatial_norm_dim is not None:
         | 
| 466 | 
            +
                        self.spatial_norm = SpatialNorm(
         | 
| 467 | 
            +
                            f_channels=query_dim, zq_channels=spatial_norm_dim
         | 
| 468 | 
            +
                        )
         | 
| 469 | 
            +
                    else:
         | 
| 470 | 
            +
                        self.spatial_norm = None
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                    if cross_attention_norm is None:
         | 
| 473 | 
            +
                        self.norm_cross = None
         | 
| 474 | 
            +
                    elif cross_attention_norm == "layer_norm":
         | 
| 475 | 
            +
                        self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
         | 
| 476 | 
            +
                    elif cross_attention_norm == "group_norm":
         | 
| 477 | 
            +
                        if self.added_kv_proj_dim is not None:
         | 
| 478 | 
            +
                            # The given `encoder_hidden_states` are initially of shape
         | 
| 479 | 
            +
                            # (batch_size, seq_len, added_kv_proj_dim) before being projected
         | 
| 480 | 
            +
                            # to (batch_size, seq_len, cross_attention_dim). The norm is applied
         | 
| 481 | 
            +
                            # before the projection, so we need to use `added_kv_proj_dim` as
         | 
| 482 | 
            +
                            # the number of channels for the group norm.
         | 
| 483 | 
            +
                            norm_cross_num_channels = added_kv_proj_dim
         | 
| 484 | 
            +
                        else:
         | 
| 485 | 
            +
                            norm_cross_num_channels = self.cross_attention_dim
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                        self.norm_cross = nn.GroupNorm(
         | 
| 488 | 
            +
                            num_channels=norm_cross_num_channels,
         | 
| 489 | 
            +
                            num_groups=cross_attention_norm_num_groups,
         | 
| 490 | 
            +
                            eps=1e-5,
         | 
| 491 | 
            +
                            affine=True,
         | 
| 492 | 
            +
                        )
         | 
| 493 | 
            +
                    else:
         | 
| 494 | 
            +
                        raise ValueError(
         | 
| 495 | 
            +
                            f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
         | 
| 496 | 
            +
                        )
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                    linear_cls = nn.Linear
         | 
| 499 | 
            +
             | 
| 500 | 
            +
                    self.linear_cls = linear_cls
         | 
| 501 | 
            +
                    self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                    if not self.only_cross_attention:
         | 
| 504 | 
            +
                        # only relevant for the `AddedKVProcessor` classes
         | 
| 505 | 
            +
                        self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         | 
| 506 | 
            +
                        self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         | 
| 507 | 
            +
                    else:
         | 
| 508 | 
            +
                        self.to_k = None
         | 
| 509 | 
            +
                        self.to_v = None
         | 
| 510 | 
            +
             | 
| 511 | 
            +
                    if self.added_kv_proj_dim is not None:
         | 
| 512 | 
            +
                        self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
         | 
| 513 | 
            +
                        self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                    self.to_out = nn.ModuleList([])
         | 
| 516 | 
            +
                    self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
         | 
| 517 | 
            +
                    self.to_out.append(nn.Dropout(dropout))
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                    # set attention processor
         | 
| 520 | 
            +
                    # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         | 
| 521 | 
            +
                    # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         | 
| 522 | 
            +
                    # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
         | 
| 523 | 
            +
                    if processor is None:
         | 
| 524 | 
            +
                        processor = AttnProcessor2_0()
         | 
| 525 | 
            +
                    self.set_processor(processor)
         | 
| 526 | 
            +
             | 
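A minimal usage sketch of the constructor above (not part of the original file); the layer sizes are illustrative assumptions rather than values taken from this repository's configs:

# Hypothetical sketch: build a self-attention layer with 32 heads of width 64.
import torch
from ltx_video.models.transformers.attention import Attention

attn = Attention(query_dim=2048, heads=32, dim_head=64, bias=True)
x = torch.randn(1, 256, 2048)  # (batch, tokens, query_dim)
# `attn(x)` routes through the default AttnProcessor2_0 defined further below.
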
    def set_use_tpu_flash_attention(self):
        r"""
        Sets the flag on this object that enforces the use of the TPU flash-attention kernel.
        """
        self.use_tpu_flash_attention = True

    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(
                f"You are removing possibly trained weights of {self.processor} with {processor}"
            )
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(
        self, return_deprecated_lora: bool = False
    ) -> "AttentionProcessor":  # noqa: F821
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure a backwards-compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # LoRA may not be applied to `add_k_proj` or `add_v_proj`, so don't require it there
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else it is not possible that only some layers have LoRA activated
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls = getattr(
            import_module(__name__), "LoRA" + non_lora_processor_cls_name
        )

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [
            LoRAAttnProcessor,
            LoRAAttnProcessor2_0,
            LoRAXFormersAttnProcessor,
        ]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(
                    self.add_k_proj.lora_layer.state_dict()
                )
                lora_processor.add_v_proj_lora.load_state_dict(
                    self.add_v_proj.lora_layer.state_dict()
                )
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            skip_layer_mask (`torch.Tensor`, *optional*):
                The skip layer mask to use. If `None`, no mask is applied.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers to skip for spatiotemporal guidance.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions;
        # here we simply pass along all tensors to the selected processor class.
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by"
                f" {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }

        return self.processor(
            self,
            hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )

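As a hedged illustration of the dispatch above (not from the original file), reusing the `attn` and `x` from the earlier sketch: a cross-attention call only needs the context tensor, since unrecognized keyword arguments are filtered against the processor's signature.

# Hypothetical sketch: cross-attention against a 2048-dim context sequence
# (in a real pipeline, cross_attention_dim would match the text encoder width).
ctx = torch.randn(1, 77, 2048)
out = attn(x, encoder_hidden_states=ctx)  # -> shape (1, 256, 2048)
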
    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, heads, seq_len, dim // heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """

        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(
            batch_size, seq_len * extra_dim, head_size, dim // head_size
        )
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(
                batch_size * head_size, seq_len * extra_dim, dim // head_size
            )

        return tensor

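A small shape check for the two reshaping helpers above (illustrative, assuming the 32-head layer from the earlier sketch):

# Hypothetical sketch: head_to_batch_dim folds heads into the batch axis and
# batch_to_head_dim undoes it.
t = torch.randn(2, 16, 2048)
split = attn.head_to_batch_dim(t)       # -> (2 * 32, 16, 64)
merged = attn.batch_to_head_dim(split)  # -> (2, 16, 2048)
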
    def get_attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0],
                query.shape[1],
                key.shape[1],
                dtype=query.dtype,
                device=query.device,
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(
        self,
        attention_mask: torch.Tensor,
        target_length: int,
        batch_size: int,
        out_dim: int = 3,
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (
                    attention_mask.shape[0],
                    attention_mask.shape[1],
                    target_length,
                )
                padding = torch.zeros(
                    padding_shape,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(
        self, encoder_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
        `Attention` class.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert (
            self.norm_cross is not None
        ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states

    @staticmethod
    def apply_rotary_emb(
        input_tensor: torch.Tensor,
        freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        cos_freqs = freqs_cis[0]
        sin_freqs = freqs_cis[1]

        t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
        t1, t2 = t_dup.unbind(dim=-1)
        t_dup = torch.stack((-t2, t1), dim=-1)
        input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

        out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

        return out

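A toy check of the rotary embedding above (not part of the original file): each channel pair (t1, t2) is rotated, since input * cos + rotate_half(input) * sin equals (t1*cos - t2*sin, t1*sin + t2*cos).

# Hypothetical sketch: rotating a single (t1, t2) pair by 90 degrees.
import math
import torch

theta = math.pi / 2
pair = torch.tensor([1.0, 0.0])
cos = torch.full_like(pair, math.cos(theta))
sin = torch.full_like(pair, math.sin(theta))
rotated = Attention.apply_rotary_emb(pair, (cos, sin))
# rotated is approximately tensor([0., 1.])
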
            +
            class AttnProcessor2_0:
         | 
| 937 | 
            +
                r"""
         | 
| 938 | 
            +
                Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
         | 
| 939 | 
            +
                """
         | 
| 940 | 
            +
             | 
| 941 | 
            +
                def __init__(self):
         | 
| 942 | 
            +
                    pass
         | 
| 943 | 
            +
             | 
| 944 | 
            +
                def __call__(
         | 
| 945 | 
            +
                    self,
         | 
| 946 | 
            +
                    attn: Attention,
         | 
| 947 | 
            +
                    hidden_states: torch.FloatTensor,
         | 
| 948 | 
            +
                    freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
         | 
| 949 | 
            +
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 950 | 
            +
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 951 | 
            +
                    temb: Optional[torch.FloatTensor] = None,
         | 
| 952 | 
            +
                    skip_layer_mask: Optional[torch.FloatTensor] = None,
         | 
| 953 | 
            +
                    skip_layer_strategy: Optional[SkipLayerStrategy] = None,
         | 
| 954 | 
            +
                    *args,
         | 
| 955 | 
            +
                    **kwargs,
         | 
| 956 | 
            +
                ) -> torch.FloatTensor:
         | 
| 957 | 
            +
                    if len(args) > 0 or kwargs.get("scale", None) is not None:
         | 
| 958 | 
            +
                        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
         | 
| 959 | 
            +
                        deprecate("scale", "1.0.0", deprecation_message)
         | 
| 960 | 
            +
             | 
| 961 | 
            +
                    residual = hidden_states
         | 
| 962 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 963 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 964 | 
            +
             | 
| 965 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 966 | 
            +
             | 
| 967 | 
            +
                    if input_ndim == 4:
         | 
| 968 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 969 | 
            +
                        hidden_states = hidden_states.view(
         | 
| 970 | 
            +
                            batch_size, channel, height * width
         | 
| 971 | 
            +
                        ).transpose(1, 2)
         | 
| 972 | 
            +
             | 
| 973 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 974 | 
            +
                        hidden_states.shape
         | 
| 975 | 
            +
                        if encoder_hidden_states is None
         | 
| 976 | 
            +
                        else encoder_hidden_states.shape
         | 
| 977 | 
            +
                    )
         | 
| 978 | 
            +
             | 
| 979 | 
            +
                    if skip_layer_mask is not None:
         | 
| 980 | 
            +
                        skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)
         | 
| 981 | 
            +
             | 
| 982 | 
            +
                    if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
         | 
| 983 | 
            +
                        attention_mask = attn.prepare_attention_mask(
         | 
| 984 | 
            +
                            attention_mask, sequence_length, batch_size
         | 
| 985 | 
            +
                        )
         | 
| 986 | 
            +
                        # scaled_dot_product_attention expects attention_mask shape to be
         | 
| 987 | 
            +
                        # (batch, heads, source_length, target_length)
         | 
| 988 | 
            +
                        attention_mask = attention_mask.view(
         | 
| 989 | 
            +
                            batch_size, attn.heads, -1, attention_mask.shape[-1]
         | 
| 990 | 
            +
                        )
         | 
| 991 | 
            +
             | 
| 992 | 
            +
                    if attn.group_norm is not None:
         | 
| 993 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
         | 
| 994 | 
            +
                            1, 2
         | 
| 995 | 
            +
                        )
         | 
| 996 | 
            +
             | 
| 997 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 998 | 
            +
                    query = attn.q_norm(query)
         | 
| 999 | 
            +
             | 
| 1000 | 
            +
                    if encoder_hidden_states is not None:
         | 
| 1001 | 
            +
                        if attn.norm_cross:
         | 
| 1002 | 
            +
                            encoder_hidden_states = attn.norm_encoder_hidden_states(
         | 
| 1003 | 
            +
                                encoder_hidden_states
         | 
| 1004 | 
            +
                            )
         | 
| 1005 | 
            +
                        key = attn.to_k(encoder_hidden_states)
         | 
| 1006 | 
            +
                        key = attn.k_norm(key)
         | 
| 1007 | 
            +
                    else:  # if no context provided do self-attention
         | 
| 1008 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1009 | 
            +
                        key = attn.to_k(hidden_states)
         | 
| 1010 | 
            +
                        key = attn.k_norm(key)
         | 
| 1011 | 
            +
                        if attn.use_rope:
         | 
| 1012 | 
            +
                            key = attn.apply_rotary_emb(key, freqs_cis)
         | 
| 1013 | 
            +
                            query = attn.apply_rotary_emb(query, freqs_cis)
         | 
| 1014 | 
            +
             | 
| 1015 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 1016 | 
            +
                    value_for_stg = value
         | 
| 1017 | 
            +
             | 
| 1018 | 
            +
                    inner_dim = key.shape[-1]
         | 
| 1019 | 
            +
                    head_dim = inner_dim // attn.heads
         | 
| 1020 | 
            +
             | 
| 1021 | 
            +
                    query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1022 | 
            +
                    key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1023 | 
            +
                    value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
         | 
| 1024 | 
            +
             | 
| 1025 | 
            +
                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
         | 
| 1026 | 
            +
             | 
| 1027 | 
            +
                    if attn.use_tpu_flash_attention:  # use tpu attention offload 'flash attention'
         | 
| 1028 | 
            +
                        q_segment_indexes = None
         | 
| 1029 | 
            +
                        if (
         | 
| 1030 | 
            +
                            attention_mask is not None
         | 
| 1031 | 
            +
                        ):  # if mask is required need to tune both segmenIds fields
         | 
| 1032 | 
            +
                            # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
         | 
| 1033 | 
            +
                            attention_mask = attention_mask.to(torch.float32)
         | 
| 1034 | 
            +
                            q_segment_indexes = torch.ones(
         | 
| 1035 | 
            +
                                batch_size, query.shape[2], device=query.device, dtype=torch.float32
         | 
| 1036 | 
            +
                            )
         | 
| 1037 | 
            +
                            assert (
         | 
| 1038 | 
            +
                                attention_mask.shape[1] == key.shape[2]
         | 
| 1039 | 
            +
                            ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
         | 
| 1040 | 
            +
             | 
| 1041 | 
            +
                        assert (
         | 
| 1042 | 
            +
                            query.shape[2] % 128 == 0
         | 
| 1043 | 
            +
                        ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
         | 
| 1044 | 
            +
                        assert (
         | 
| 1045 | 
            +
                            key.shape[2] % 128 == 0
         | 
| 1046 | 
            +
                        ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
         | 
| 1047 | 
            +
             | 
| 1048 | 
            +
                        # run the TPU kernel implemented in jax with pallas
         | 
| 1049 | 
            +
                        hidden_states_a = flash_attention(
         | 
| 1050 | 
            +
                            q=query,
         | 
| 1051 | 
            +
                            k=key,
         | 
| 1052 | 
            +
                            v=value,
         | 
| 1053 | 
            +
                            q_segment_ids=q_segment_indexes,
         | 
| 1054 | 
            +
                            kv_segment_ids=attention_mask,
         | 
| 1055 | 
            +
                            sm_scale=attn.scale,
         | 
| 1056 | 
            +
                        )
         | 
| 1057 | 
            +
                    else:
         | 
| 1058 | 
            +
                        hidden_states_a = F.scaled_dot_product_attention(
         | 
| 1059 | 
            +
                            query,
         | 
| 1060 | 
            +
                            key,
         | 
| 1061 | 
            +
                            value,
         | 
| 1062 | 
            +
                            attn_mask=attention_mask,
         | 
| 1063 | 
            +
                            dropout_p=0.0,
         | 
| 1064 | 
            +
                            is_causal=False,
         | 
| 1065 | 
            +
                        )
         | 
| 1066 | 
            +
             | 
| 1067 | 
            +
                    hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
         | 
| 1068 | 
            +
                        batch_size, -1, attn.heads * head_dim
         | 
| 1069 | 
            +
                    )
         | 
| 1070 | 
            +
                    hidden_states_a = hidden_states_a.to(query.dtype)
         | 
| 1071 | 
            +
             | 
| 1072 | 
            +
                    if (
         | 
| 1073 | 
            +
                        skip_layer_mask is not None
         | 
| 1074 | 
            +
                        and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
         | 
| 1075 | 
            +
                    ):
         | 
| 1076 | 
            +
                        hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
         | 
| 1077 | 
            +
                            1.0 - skip_layer_mask
         | 
| 1078 | 
            +
                        )
         | 
| 1079 | 
            +
                    elif (
         | 
| 1080 | 
            +
                        skip_layer_mask is not None
         | 
| 1081 | 
            +
                        and skip_layer_strategy == SkipLayerStrategy.AttentionValues
         | 
| 1082 | 
            +
                    ):
         | 
| 1083 | 
            +
                        hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
         | 
| 1084 | 
            +
                            1.0 - skip_layer_mask
         | 
| 1085 | 
            +
                        )
         | 
| 1086 | 
            +
                    else:
         | 
| 1087 | 
            +
                        hidden_states = hidden_states_a
         | 
| 1088 | 
            +
             | 
| 1089 | 
            +
                    # linear proj
         | 
| 1090 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1091 | 
            +
                    # dropout
         | 
| 1092 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1093 | 
            +
             | 
| 1094 | 
            +
                    if input_ndim == 4:
         | 
| 1095 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(
         | 
| 1096 | 
            +
                            batch_size, channel, height, width
         | 
| 1097 | 
            +
                        )
         | 
| 1098 | 
            +
                        if (
         | 
| 1099 | 
            +
                            skip_layer_mask is not None
         | 
| 1100 | 
            +
                            and skip_layer_strategy == SkipLayerStrategy.Residual
         | 
| 1101 | 
            +
                        ):
         | 
| 1102 | 
            +
                            skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)
         | 
| 1103 | 
            +
             | 
| 1104 | 
            +
                    if attn.residual_connection:
         | 
| 1105 | 
            +
                        if (
         | 
| 1106 | 
            +
                            skip_layer_mask is not None
         | 
| 1107 | 
            +
                            and skip_layer_strategy == SkipLayerStrategy.Residual
         | 
| 1108 | 
            +
                        ):
         | 
| 1109 | 
            +
                            hidden_states = hidden_states + residual * skip_layer_mask
         | 
| 1110 | 
            +
                        else:
         | 
| 1111 | 
            +
                            hidden_states = hidden_states + residual
         | 
| 1112 | 
            +
             | 
| 1113 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1114 | 
            +
             | 
| 1115 | 
            +
                    return hidden_states
         | 
| 1116 | 
            +
             | 
| 1117 | 
            +
             | 
| 1118 | 
            +
            class AttnProcessor:
         | 
| 1119 | 
            +
                r"""
         | 
| 1120 | 
            +
                Default processor for performing attention-related computations.
         | 
| 1121 | 
            +
                """
         | 
| 1122 | 
            +
             | 
| 1123 | 
            +
                def __call__(
         | 
| 1124 | 
            +
                    self,
         | 
| 1125 | 
            +
                    attn: Attention,
         | 
| 1126 | 
            +
                    hidden_states: torch.FloatTensor,
         | 
| 1127 | 
            +
                    encoder_hidden_states: Optional[torch.FloatTensor] = None,
         | 
| 1128 | 
            +
                    attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 1129 | 
            +
                    temb: Optional[torch.FloatTensor] = None,
         | 
| 1130 | 
            +
                    *args,
         | 
| 1131 | 
            +
                    **kwargs,
         | 
| 1132 | 
            +
                ) -> torch.Tensor:
         | 
| 1133 | 
            +
                    if len(args) > 0 or kwargs.get("scale", None) is not None:
         | 
| 1134 | 
            +
                        deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
         | 
| 1135 | 
            +
                        deprecate("scale", "1.0.0", deprecation_message)
         | 
| 1136 | 
            +
             | 
| 1137 | 
            +
                    residual = hidden_states
         | 
| 1138 | 
            +
             | 
| 1139 | 
            +
                    if attn.spatial_norm is not None:
         | 
| 1140 | 
            +
                        hidden_states = attn.spatial_norm(hidden_states, temb)
         | 
| 1141 | 
            +
             | 
| 1142 | 
            +
                    input_ndim = hidden_states.ndim
         | 
| 1143 | 
            +
             | 
| 1144 | 
            +
                    if input_ndim == 4:
         | 
| 1145 | 
            +
                        batch_size, channel, height, width = hidden_states.shape
         | 
| 1146 | 
            +
                        hidden_states = hidden_states.view(
         | 
| 1147 | 
            +
                            batch_size, channel, height * width
         | 
| 1148 | 
            +
                        ).transpose(1, 2)
         | 
| 1149 | 
            +
             | 
| 1150 | 
            +
                    batch_size, sequence_length, _ = (
         | 
| 1151 | 
            +
                        hidden_states.shape
         | 
| 1152 | 
            +
                        if encoder_hidden_states is None
         | 
| 1153 | 
            +
                        else encoder_hidden_states.shape
         | 
| 1154 | 
            +
                    )
         | 
| 1155 | 
            +
                    attention_mask = attn.prepare_attention_mask(
         | 
| 1156 | 
            +
                        attention_mask, sequence_length, batch_size
         | 
| 1157 | 
            +
                    )
         | 
| 1158 | 
            +
             | 
| 1159 | 
            +
                    if attn.group_norm is not None:
         | 
| 1160 | 
            +
                        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
         | 
| 1161 | 
            +
                            1, 2
         | 
| 1162 | 
            +
                        )
         | 
| 1163 | 
            +
             | 
| 1164 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 1165 | 
            +
             | 
| 1166 | 
            +
                    if encoder_hidden_states is None:
         | 
| 1167 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 1168 | 
            +
                    elif attn.norm_cross:
         | 
| 1169 | 
            +
                        encoder_hidden_states = attn.norm_encoder_hidden_states(
         | 
| 1170 | 
            +
                            encoder_hidden_states
         | 
| 1171 | 
            +
                        )
         | 
| 1172 | 
            +
             | 
| 1173 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 1174 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 1175 | 
            +
             | 
| 1176 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 1177 | 
            +
                    key = attn.head_to_batch_dim(key)
         | 
| 1178 | 
            +
                    value = attn.head_to_batch_dim(value)
         | 
| 1179 | 
            +
             | 
| 1180 | 
            +
                    query = attn.q_norm(query)
         | 
| 1181 | 
            +
                    key = attn.k_norm(key)
         | 
| 1182 | 
            +
             | 
| 1183 | 
            +
                    attention_probs = attn.get_attention_scores(query, key, attention_mask)
         | 
| 1184 | 
            +
                    hidden_states = torch.bmm(attention_probs, value)
         | 
| 1185 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 1186 | 
            +
             | 
| 1187 | 
            +
                    # linear proj
         | 
| 1188 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 1189 | 
            +
                    # dropout
         | 
| 1190 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 1191 | 
            +
             | 
| 1192 | 
            +
                    if input_ndim == 4:
         | 
| 1193 | 
            +
                        hidden_states = hidden_states.transpose(-1, -2).reshape(
         | 
| 1194 | 
            +
                            batch_size, channel, height, width
         | 
| 1195 | 
            +
                        )
         | 
| 1196 | 
            +
             | 
| 1197 | 
            +
                    if attn.residual_connection:
         | 
| 1198 | 
            +
                        hidden_states = hidden_states + residual
         | 
| 1199 | 
            +
             | 
| 1200 | 
            +
                    hidden_states = hidden_states / attn.rescale_output_factor
         | 
| 1201 | 
            +
             | 
| 1202 | 
            +
                    return hidden_states
         | 
| 1203 | 
            +
             | 
| 1204 | 
            +
             | 
| 1205 | 
            +
            class FeedForward(nn.Module):
         | 
| 1206 | 
            +
                r"""
         | 
| 1207 | 
            +
                A feed-forward layer.
         | 
| 1208 | 
            +
             | 
| 1209 | 
            +
                Parameters:
         | 
| 1210 | 
            +
                    dim (`int`): The number of channels in the input.
         | 
| 1211 | 
            +
                    dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
         | 
| 1212 | 
            +
                    mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
         | 
| 1213 | 
            +
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         | 
| 1214 | 
            +
                    activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
         | 
| 1215 | 
            +
                    final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
         | 
| 1216 | 
            +
                    bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
         | 
| 1217 | 
            +
                """
         | 
| 1218 | 
            +
             | 
| 1219 | 
            +
                def __init__(
         | 
| 1220 | 
            +
                    self,
         | 
| 1221 | 
            +
                    dim: int,
         | 
| 1222 | 
            +
                    dim_out: Optional[int] = None,
         | 
| 1223 | 
            +
                    mult: int = 4,
         | 
| 1224 | 
            +
                    dropout: float = 0.0,
         | 
| 1225 | 
            +
                    activation_fn: str = "geglu",
         | 
| 1226 | 
            +
                    final_dropout: bool = False,
         | 
| 1227 | 
            +
                    inner_dim=None,
         | 
| 1228 | 
            +
                    bias: bool = True,
         | 
| 1229 | 
            +
                ):
         | 
| 1230 | 
            +
                    super().__init__()
         | 
| 1231 | 
            +
                    if inner_dim is None:
         | 
| 1232 | 
            +
                        inner_dim = int(dim * mult)
         | 
| 1233 | 
            +
                    dim_out = dim_out if dim_out is not None else dim
         | 
| 1234 | 
            +
                    linear_cls = nn.Linear
         | 
| 1235 | 
            +
             | 
| 1236 | 
            +
                    if activation_fn == "gelu":
         | 
| 1237 | 
            +
                        act_fn = GELU(dim, inner_dim, bias=bias)
         | 
| 1238 | 
            +
                    elif activation_fn == "gelu-approximate":
         | 
| 1239 | 
            +
                        act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
         | 
| 1240 | 
            +
                    elif activation_fn == "geglu":
         | 
| 1241 | 
            +
                        act_fn = GEGLU(dim, inner_dim, bias=bias)
         | 
| 1242 | 
            +
                    elif activation_fn == "geglu-approximate":
         | 
| 1243 | 
            +
                        act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
         | 
| 1244 | 
            +
                    else:
         | 
| 1245 | 
            +
                        raise ValueError(f"Unsupported activation function: {activation_fn}")
         | 
| 1246 | 
            +
             | 
| 1247 | 
            +
                    self.net = nn.ModuleList([])
         | 
| 1248 | 
            +
                    # project in
         | 
| 1249 | 
            +
                    self.net.append(act_fn)
         | 
| 1250 | 
            +
                    # project dropout
         | 
| 1251 | 
            +
                    self.net.append(nn.Dropout(dropout))
         | 
| 1252 | 
            +
                    # project out
         | 
| 1253 | 
            +
                    self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
         | 
| 1254 | 
            +
                    # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         | 
| 1255 | 
            +
                    if final_dropout:
         | 
| 1256 | 
            +
                        self.net.append(nn.Dropout(dropout))
         | 
| 1257 | 
            +
             | 
| 1258 | 
            +
                def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
         | 
| 1259 | 
            +
                    compatible_cls = (GEGLU, LoRACompatibleLinear)
         | 
| 1260 | 
            +
                    for module in self.net:
         | 
| 1261 | 
            +
                        if isinstance(module, compatible_cls):
         | 
| 1262 | 
            +
                            hidden_states = module(hidden_states, scale)
         | 
| 1263 | 
            +
                        else:
         | 
| 1264 | 
            +
                            hidden_states = module(hidden_states)
         | 
| 1265 | 
            +
                    return hidden_states
         | 
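Note on the `FeedForward` block above: it follows the diffusers layout of an activation projection (`GELU`, `GEGLU`, or `ApproximateGELU`), a dropout, and a linear projection back to `dim_out`. A minimal usage sketch, with illustrative sizes that are not taken from this upload:

# Sketch only; the dimensions are arbitrary examples, not LTX-Video config values.
import torch

ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
tokens = torch.randn(2, 10, 64)   # (batch, sequence, channels)
out = ff(tokens)                  # GEGLU projects 64 -> 256, the last linear maps back to 64
assert out.shape == tokens.shape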
    	
ltx_video/models/transformers/embeddings.py
ADDED

@@ -0,0 +1,129 @@
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
import math

import numpy as np
import torch
from einops import rearrange
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
    """
    grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
    grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
    grid = grid.reshape([3, 1, w, h, f])
    pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
    pos_embed = pos_embed.transpose(1, 0, 2, 3)
    return rearrange(pos_embed, "h w f c -> (f h w) c")


def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    if embed_dim % 3 != 0:
        raise ValueError("embed_dim must be divisible by 3")

    # use half of dimensions to encode grid_h
    emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0])  # (H*W*T, D/3)
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1])  # (H*W*T, D/3)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2])  # (H*W*T, D/3)

    emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1)  # (H*W*T, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos_shape = pos.shape

    pos = pos.reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
    out = out.reshape([*pos_shape, -1])[0]

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=-1)  # (M, D)
    return emb


class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
    them

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
        )
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x
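A quick usage sketch for the two utilities above (the sizes are illustrative assumptions, not values from the configs): `get_timestep_embedding` maps a 1-D batch of possibly fractional timesteps to an `[N, embedding_dim]` sinusoidal table, and `SinusoidalPositionalEmbedding` adds fixed positional codes to a `(batch, seq_length, embed_dim)` sequence.

# Sketch only; the sizes below are arbitrary examples.
import torch

timesteps = torch.tensor([0.0, 250.5, 999.0])                # one (possibly fractional) timestep per sample
t_emb = get_timestep_embedding(timesteps, embedding_dim=256)
assert t_emb.shape == (3, 256)                               # sine half followed by cosine half

pos = SinusoidalPositionalEmbedding(embed_dim=256, max_seq_length=32)
tokens = torch.randn(2, 16, 256)
assert pos(tokens).shape == (2, 16, 256)                     # embeddings for positions 0..15 are added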
    	
ltx_video/models/transformers/symmetric_patchifier.py
ADDED

@@ -0,0 +1,84 @@
from abc import ABC, abstractmethod
from typing import Tuple

import torch
from diffusers.configuration_utils import ConfigMixin
from einops import rearrange
from torch import Tensor


class Patchifier(ConfigMixin, ABC):
    def __init__(self, patch_size: int):
        super().__init__()
        self._patch_size = (1, patch_size, patch_size)

    @abstractmethod
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        raise NotImplementedError("Patchify method not implemented")

    @abstractmethod
    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        pass

    @property
    def patch_size(self):
        return self._patch_size

    def get_latent_coords(
        self, latent_num_frames, latent_height, latent_width, batch_size, device
    ):
        """
        Return a tensor of shape [batch_size, 3, num_patches] containing the
            top-left corner latent coordinates of each latent patch.
        The tensor is repeated for each batch element.
        """
        latent_sample_coords = torch.meshgrid(
            torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
            torch.arange(0, latent_height, self._patch_size[1], device=device),
            torch.arange(0, latent_width, self._patch_size[2], device=device),
        )
        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
        latent_coords = rearrange(
            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
        )
        return latent_coords


class SymmetricPatchifier(Patchifier):
    def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
        b, _, f, h, w = latents.shape
        latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
        latents = rearrange(
            latents,
            "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
            p1=self._patch_size[0],
            p2=self._patch_size[1],
            p3=self._patch_size[2],
        )
        return latents, latent_coords

    def unpatchify(
        self,
        latents: Tensor,
        output_height: int,
        output_width: int,
        out_channels: int,
    ) -> Tuple[Tensor, Tensor]:
        output_height = output_height // self._patch_size[1]
        output_width = output_width // self._patch_size[2]
        latents = rearrange(
            latents,
            "b (f h w) (c p q) -> b c f (h p) (w q)",
            h=output_height,
            w=output_width,
            p=self._patch_size[1],
            q=self._patch_size[2],
        )
        return latents
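A small round-trip sketch for `SymmetricPatchifier` (toy shapes, and `patch_size=1` is an assumption made for illustration): `patchify` flattens a `(b, c, f, h, w)` latent into one token per latent position plus a `(b, 3, n)` tensor of (frame, row, column) coordinates, and `unpatchify` restores the original layout.

# Sketch only; shapes are toy values and patch_size=1 is assumed for the example.
import torch

patchifier = SymmetricPatchifier(patch_size=1)
latents = torch.randn(1, 128, 2, 4, 4)                      # (b, c, f, h, w)
tokens, coords = patchifier.patchify(latents)
assert tokens.shape == (1, 2 * 4 * 4, 128)                  # one token per latent position
assert coords.shape == (1, 3, 2 * 4 * 4)                    # (frame, row, column) per token

restored = patchifier.unpatchify(
    tokens, output_height=4, output_width=4, out_channels=128
)
assert torch.equal(restored, latents)                       # pure reordering, so the round trip is exact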
    	
ltx_video/models/transformers/transformer3d.py
ADDED

@@ -0,0 +1,507 @@
| 1 | 
            +
            # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            from dataclasses import dataclass
         | 
| 4 | 
            +
            from typing import Any, Dict, List, Optional, Union
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import json
         | 
| 7 | 
            +
            import glob
         | 
| 8 | 
            +
            from pathlib import Path
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 12 | 
            +
            from diffusers.models.embeddings import PixArtAlphaTextProjection
         | 
| 13 | 
            +
            from diffusers.models.modeling_utils import ModelMixin
         | 
| 14 | 
            +
            from diffusers.models.normalization import AdaLayerNormSingle
         | 
| 15 | 
            +
            from diffusers.utils import BaseOutput, is_torch_version
         | 
| 16 | 
            +
            from diffusers.utils import logging
         | 
| 17 | 
            +
            from torch import nn
         | 
| 18 | 
            +
            from safetensors import safe_open
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            from ltx_video.models.transformers.attention import BasicTransformerBlock
         | 
| 22 | 
            +
            from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from ltx_video.utils.diffusers_config_mapping import (
         | 
| 25 | 
            +
                diffusers_and_ours_config_mapping,
         | 
| 26 | 
            +
                make_hashable_key,
         | 
| 27 | 
            +
                TRANSFORMER_KEYS_RENAME_DICT,
         | 
| 28 | 
            +
            )
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            logger = logging.get_logger(__name__)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            @dataclass
         | 
| 35 | 
            +
            class Transformer3DModelOutput(BaseOutput):
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
                The output of [`Transformer2DModel`].
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                Args:
         | 
| 40 | 
            +
                    sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
         | 
| 41 | 
            +
                        The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
         | 
| 42 | 
            +
                        distributions for the unnoised latent pixels.
         | 
| 43 | 
            +
                """
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                sample: torch.FloatTensor
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            class Transformer3DModel(ModelMixin, ConfigMixin):
         | 
| 49 | 
            +
                _supports_gradient_checkpointing = True
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                @register_to_config
         | 
| 52 | 
            +
                def __init__(
         | 
| 53 | 
            +
                    self,
         | 
| 54 | 
            +
                    num_attention_heads: int = 16,
         | 
| 55 | 
            +
                    attention_head_dim: int = 88,
         | 
| 56 | 
            +
                    in_channels: Optional[int] = None,
         | 
| 57 | 
            +
                    out_channels: Optional[int] = None,
         | 
| 58 | 
            +
                    num_layers: int = 1,
         | 
| 59 | 
            +
                    dropout: float = 0.0,
         | 
| 60 | 
            +
                    norm_num_groups: int = 32,
         | 
| 61 | 
            +
                    cross_attention_dim: Optional[int] = None,
         | 
| 62 | 
            +
                    attention_bias: bool = False,
         | 
| 63 | 
            +
                    num_vector_embeds: Optional[int] = None,
         | 
| 64 | 
            +
                    activation_fn: str = "geglu",
         | 
| 65 | 
            +
                    num_embeds_ada_norm: Optional[int] = None,
         | 
| 66 | 
            +
                    use_linear_projection: bool = False,
         | 
| 67 | 
            +
                    only_cross_attention: bool = False,
         | 
| 68 | 
            +
                    double_self_attention: bool = False,
         | 
| 69 | 
            +
                    upcast_attention: bool = False,
         | 
| 70 | 
            +
                    adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift' or 'single_scale'
         | 
| 71 | 
            +
                    standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
         | 
| 72 | 
            +
                    norm_elementwise_affine: bool = True,
         | 
| 73 | 
            +
                    norm_eps: float = 1e-5,
         | 
| 74 | 
            +
                    attention_type: str = "default",
         | 
| 75 | 
            +
                    caption_channels: int = None,
         | 
| 76 | 
            +
                    use_tpu_flash_attention: bool = False,  # if True uses the TPU attention offload ('flash attention')
         | 
| 77 | 
            +
                    qk_norm: Optional[str] = None,
         | 
| 78 | 
            +
                    positional_embedding_type: str = "rope",
         | 
| 79 | 
            +
                    positional_embedding_theta: Optional[float] = None,
         | 
| 80 | 
            +
                    positional_embedding_max_pos: Optional[List[int]] = None,
         | 
| 81 | 
            +
                    timestep_scale_multiplier: Optional[float] = None,
         | 
| 82 | 
            +
                    causal_temporal_positioning: bool = False,  # For backward compatibility, will be deprecated
         | 
| 83 | 
            +
                ):
         | 
| 84 | 
            +
                    super().__init__()
         | 
| 85 | 
            +
                    self.use_tpu_flash_attention = (
         | 
| 86 | 
            +
                        use_tpu_flash_attention  # FIXME: push config down to the attention modules
         | 
| 87 | 
            +
                    )
         | 
| 88 | 
            +
                    self.use_linear_projection = use_linear_projection
         | 
| 89 | 
            +
                    self.num_attention_heads = num_attention_heads
         | 
| 90 | 
            +
                    self.attention_head_dim = attention_head_dim
         | 
| 91 | 
            +
                    inner_dim = num_attention_heads * attention_head_dim
         | 
| 92 | 
            +
                    self.inner_dim = inner_dim
         | 
| 93 | 
            +
                    self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
         | 
| 94 | 
            +
                    self.positional_embedding_type = positional_embedding_type
         | 
| 95 | 
            +
                    self.positional_embedding_theta = positional_embedding_theta
         | 
| 96 | 
            +
                    self.positional_embedding_max_pos = positional_embedding_max_pos
         | 
| 97 | 
            +
                    self.use_rope = self.positional_embedding_type == "rope"
         | 
| 98 | 
            +
                    self.timestep_scale_multiplier = timestep_scale_multiplier
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                    if self.positional_embedding_type == "absolute":
         | 
| 101 | 
            +
                        raise ValueError("Absolute positional embedding is no longer supported")
         | 
| 102 | 
            +
                    elif self.positional_embedding_type == "rope":
         | 
| 103 | 
            +
                        if positional_embedding_theta is None:
         | 
| 104 | 
            +
                            raise ValueError(
         | 
| 105 | 
            +
                                "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
         | 
| 106 | 
            +
                            )
         | 
| 107 | 
            +
                        if positional_embedding_max_pos is None:
         | 
| 108 | 
            +
                            raise ValueError(
         | 
| 109 | 
            +
                                "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
         | 
| 110 | 
            +
                            )
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    # 3. Define transformers blocks
         | 
| 113 | 
            +
                    self.transformer_blocks = nn.ModuleList(
         | 
| 114 | 
            +
                        [
         | 
| 115 | 
            +
                            BasicTransformerBlock(
         | 
| 116 | 
            +
                                inner_dim,
         | 
| 117 | 
            +
                                num_attention_heads,
         | 
| 118 | 
            +
                                attention_head_dim,
         | 
| 119 | 
            +
                                dropout=dropout,
         | 
| 120 | 
            +
                                cross_attention_dim=cross_attention_dim,
         | 
| 121 | 
            +
                                activation_fn=activation_fn,
         | 
| 122 | 
            +
                                num_embeds_ada_norm=num_embeds_ada_norm,
         | 
| 123 | 
            +
                                attention_bias=attention_bias,
         | 
| 124 | 
            +
                                only_cross_attention=only_cross_attention,
         | 
| 125 | 
            +
                                double_self_attention=double_self_attention,
         | 
| 126 | 
            +
                                upcast_attention=upcast_attention,
         | 
| 127 | 
            +
                                adaptive_norm=adaptive_norm,
         | 
| 128 | 
            +
                                standardization_norm=standardization_norm,
         | 
| 129 | 
            +
                                norm_elementwise_affine=norm_elementwise_affine,
         | 
| 130 | 
            +
                                norm_eps=norm_eps,
         | 
| 131 | 
            +
                                attention_type=attention_type,
         | 
| 132 | 
            +
                                use_tpu_flash_attention=use_tpu_flash_attention,
         | 
| 133 | 
            +
                                qk_norm=qk_norm,
         | 
| 134 | 
            +
                                use_rope=self.use_rope,
         | 
| 135 | 
            +
                            )
         | 
| 136 | 
            +
                            for d in range(num_layers)
         | 
| 137 | 
            +
                        ]
         | 
| 138 | 
            +
                    )
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # 4. Define output layers
         | 
| 141 | 
            +
                    self.out_channels = in_channels if out_channels is None else out_channels
         | 
| 142 | 
            +
                    self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         | 
| 143 | 
            +
                    self.scale_shift_table = nn.Parameter(
         | 
| 144 | 
            +
                        torch.randn(2, inner_dim) / inner_dim**0.5
         | 
| 145 | 
            +
                    )
         | 
| 146 | 
            +
                    self.proj_out = nn.Linear(inner_dim, self.out_channels)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    self.adaln_single = AdaLayerNormSingle(
         | 
| 149 | 
            +
                        inner_dim, use_additional_conditions=False
         | 
| 150 | 
            +
                    )
         | 
| 151 | 
            +
                    if adaptive_norm == "single_scale":
         | 
| 152 | 
            +
                        self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    self.caption_projection = None
         | 
| 155 | 
            +
                    if caption_channels is not None:
         | 
| 156 | 
            +
                        self.caption_projection = PixArtAlphaTextProjection(
         | 
| 157 | 
            +
                            in_features=caption_channels, hidden_size=inner_dim
         | 
| 158 | 
            +
                        )
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                    self.gradient_checkpointing = False
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def set_use_tpu_flash_attention(self):
         | 
| 163 | 
            +
                    r"""
         | 
| 164 | 
            +
                    Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
         | 
| 165 | 
            +
                    attention kernel.
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
         | 
| 168 | 
            +
                    self.use_tpu_flash_attention = True
         | 
| 169 | 
            +
                    # push config down to the attention modules
         | 
| 170 | 
            +
                    for block in self.transformer_blocks:
         | 
| 171 | 
            +
                        block.set_use_tpu_flash_attention()
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                def create_skip_layer_mask(
         | 
| 174 | 
            +
                    self,
         | 
| 175 | 
            +
                    batch_size: int,
         | 
| 176 | 
            +
                    num_conds: int,
         | 
| 177 | 
            +
                    ptb_index: int,
         | 
| 178 | 
            +
                    skip_block_list: Optional[List[int]] = None,
         | 
| 179 | 
            +
    ):
        if skip_block_list is None or len(skip_block_list) == 0:
            return None
        num_layers = len(self.transformer_blocks)
        mask = torch.ones(
            (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
        )
        for block_idx in skip_block_list:
            mask[block_idx, ptb_index::num_conds] = 0
        return mask

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def get_fractional_positions(self, indices_grid):
        fractional_positions = torch.stack(
            [
                indices_grid[:, i] / self.positional_embedding_max_pos[i]
                for i in range(3)
            ],
            dim=-1,
        )
        return fractional_positions

    def precompute_freqs_cis(self, indices_grid, spacing="exp"):
        dtype = torch.float32  # We need full precision in the freqs_cis computation.
        dim = self.inner_dim
        theta = self.positional_embedding_theta

        fractional_positions = self.get_fractional_positions(indices_grid)

        start = 1
        end = theta
        device = fractional_positions.device
        if spacing == "exp":
            indices = theta ** (
                torch.linspace(
                    math.log(start, theta),
                    math.log(end, theta),
                    dim // 6,
                    device=device,
                    dtype=dtype,
                )
            )
            indices = indices.to(dtype=dtype)
        elif spacing == "exp_2":
            indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
            indices = indices.to(dtype=dtype)
        elif spacing == "linear":
            indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
        elif spacing == "sqrt":
            indices = torch.linspace(
                start**2, end**2, dim // 6, device=device, dtype=dtype
            ).sqrt()

        indices = indices * math.pi / 2

        if spacing == "exp_2":
            freqs = (
                (indices * fractional_positions.unsqueeze(-1))
                .transpose(-1, -2)
                .flatten(2)
            )
        else:
            freqs = (
                (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
                .transpose(-1, -2)
                .flatten(2)
            )

        cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
        sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
        if dim % 6 != 0:
            cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
            sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
            cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
            sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
        return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
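The cos/sin tables returned here are interleave-repeated in channel pairs and are handed to each transformer block as `freqs_cis` (see `forward` below). For orientation only, a minimal sketch of how such tables are commonly applied to query/key tensors; this is an assumption about typical rotary usage, not the repo's attention code:

import torch

def apply_rotary_emb(x, cos_freq, sin_freq):
    # Pair adjacent channels, rotate each pair (a, b) -> (-b, a), and recombine with
    # the interleave-repeated cos/sin tables. cos_freq/sin_freq must broadcast to x.
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    x_rot = torch.stack([-x_pairs[..., 1], x_pairs[..., 0]], dim=-1).reshape_as(x)
    return x * cos_freq + x_rot * sin_freq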
    def load_state_dict(
        self,
        state_dict: Dict,
        *args,
        **kwargs,
    ):
        if any(key.startswith("model.diffusion_model.") for key in state_dict.keys()):
            state_dict = {
                key.replace("model.diffusion_model.", ""): value
                for key, value in state_dict.items()
                if key.startswith("model.diffusion_model.")
            }
        super().load_state_dict(state_dict, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_path = Path(pretrained_model_path)
        if pretrained_model_path.is_dir():
            config_path = pretrained_model_path / "transformer" / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for transformer is not supported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]
            state_dict = {}
            ckpt_paths = (
                pretrained_model_path
                / "transformer"
                / "diffusion_pytorch_model*.safetensors"
            )
            dict_list = glob.glob(str(ckpt_paths))
            for dict_path in dict_list:
                part_dict = {}
                with safe_open(dict_path, framework="pt", device="cpu") as f:
                    for k in f.keys():
                        part_dict[k] = f.get_tensor(k)
                state_dict.update(part_dict)

            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)
                state_dict[new_key] = state_dict.pop(key)

            with torch.device("meta"):
                transformer = cls.from_config(config)
            transformer.load_state_dict(state_dict, assign=True, strict=True)
        elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
            ".safetensors"
        ):
            comfy_single_file_state_dict = {}
            with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                for k in f.keys():
                    comfy_single_file_state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            transformer_config = configs["transformer"]
            with torch.device("meta"):
                transformer = Transformer3DModel.from_config(transformer_config)
            transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
        return transformer
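A minimal usage sketch of the single-file branch above (the local path is illustrative; the safetensors metadata of the checkpoint must carry the "config" entry read by this method):

from ltx_video.models.transformers.transformer3d import Transformer3DModel

# Hypothetical local checkpoint path; a directory with a diffusers-style
# transformer/config.json would go through the first branch instead.
transformer = Transformer3DModel.from_pretrained("./ltxv-13b-0.9.7-dev.safetensors")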
    def forward(
        self,
        hidden_states: torch.Tensor,
        indices_grid: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        return_dict: bool = True,
    ):
        """
        The [`Transformer3DModel`] forward method.

        Args:
            hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
                Input `hidden_states`.
            indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
                Per-token positional indices used to compute the rotary embeddings.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the cross-attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate the denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Used to indicate class-label conditioning. Optional class labels to be applied as an embedding in
                `AdaLayerZeroNorm`.
            cross_attention_kwargs (`Dict[str, Any]`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            attention_mask (`torch.Tensor`, *optional*):
                An attention mask of shape `(batch, key_tokens)` applied to `encoder_hidden_states`. If `1`, the mask
                is kept; if `0`, it is discarded. The mask will be converted into a bias, which adds large
                negative values to the attention scores corresponding to "discard" tokens.
            encoder_attention_mask (`torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats are supported:

                    * Mask `(batch, sequence_length)`: True = keep, False = discard.
                    * Bias `(batch, 1, sequence_length)`: 0 = keep, -10000 = discard.

                If `ndim == 2`, it is interpreted as a mask and converted into a bias consistent with the format
                above. This bias is added to the cross-attention scores.
            skip_layer_mask (`torch.Tensor`, *optional*):
                A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
                `(layer, batch_idx)` indicates that the layer should be skipped for the corresponding batch index.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`Transformer3DModelOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        # For TPU attention offload, 2D token masks are used; no need to transform.
        if not self.use_tpu_flash_attention:
            # Ensure attention_mask is a bias and give it a singleton query_tokens dimension.
            #   We may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
            #   We can tell by counting dims; if ndim == 2, it's a mask rather than a bias.
            # Expects a mask of shape:
            #   [batch, key_tokens]
            # Adds a singleton query_tokens dimension:
            #   [batch,                    1, key_tokens]
            # This helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
            #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
            #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
            if attention_mask is not None and attention_mask.ndim == 2:
                # Assume the mask is expressed as (1 = keep, 0 = discard) and convert it
                # into a bias that can be added to attention scores (keep = +0, discard = -10000.0).
                attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
                attention_mask = attention_mask.unsqueeze(1)

            # Convert encoder_attention_mask to a bias the same way we do for attention_mask.
            if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
                encoder_attention_mask = (
                    1 - encoder_attention_mask.to(hidden_states.dtype)
                ) * -10000.0
                encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 1. Input
        hidden_states = self.patchify_proj(hidden_states)

        if self.timestep_scale_multiplier:
            timestep = self.timestep_scale_multiplier * timestep

        freqs_cis = self.precompute_freqs_cis(indices_grid)

        batch_size = hidden_states.shape[0]
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_states.dtype,
        )
        # Second dimension is 1 or the number of tokens (if timestep_per_token).
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(
            batch_size, -1, embedded_timestep.shape[-1]
        )

        # 2. Blocks
        if self.caption_projection is not None:
            batch_size = hidden_states.shape[0]
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(
                batch_size, -1, hidden_states.shape[-1]
            )

        for block_idx, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = (
                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                )
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    freqs_cis,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    class_labels,
                    (
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    freqs_cis=freqs_cis,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=class_labels,
                    skip_layer_mask=(
                        skip_layer_mask[block_idx]
                        if skip_layer_mask is not None
                        else None
                    ),
                    skip_layer_strategy=skip_layer_strategy,
                )

        # 3. Output
        scale_shift_values = (
            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.proj_out(hidden_states)
        if not return_dict:
            return (hidden_states,)

        return Transformer3DModelOutput(sample=hidden_states)
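To make the mask-to-bias convention in `forward` concrete, a small standalone check with illustrative shapes (the extra unsqueeze over heads is added here only to show the broadcast):

import torch

# A (batch, key_tokens) keep/discard mask: 1 = keep, 0 = discard.
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]], dtype=torch.float32)

# Same conversion as in forward(): keep -> 0.0, discard -> -10000.0,
# plus a singleton query_tokens dimension so it broadcasts over scores.
bias = ((1 - mask) * -10000.0).unsqueeze(1)        # (batch, 1, key_tokens)

scores = torch.zeros(2, 8, 3, 4)                   # (batch, heads, query_tokens, key_tokens)
masked = scores + bias.unsqueeze(1)                # broadcast over heads and query tokens
print(masked.softmax(dim=-1)[0, 0, 0])             # discarded keys get ~0 attention weight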
    	
        ltx_video/pipelines/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        ltx_video/pipelines/crf_compressor.py
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
import av
import torch
import io
import numpy as np


def _encode_single_frame(output_file, image_array: np.ndarray, crf):
    # Encode one RGB frame as a single-frame H.264 (libx264) MP4 at the requested CRF.
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream(
            "libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"}
        )
        stream.height = image_array.shape[0]
        stream.width = image_array.shape[1]
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(
            format="yuv420p"
        )
        container.mux(stream.encode(av_frame))
        container.mux(stream.encode())  # flush the encoder
    finally:
        container.close()


def _decode_single_frame(video_file):
    # Decode the first video frame back to an RGB ndarray.
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")


def compress(image: torch.Tensor, crf=29):
    # Round-trip an (H, W, C) float frame in [0, 1] through H.264 at the given CRF and
    # return the decoded frame on the same dtype/device. Dimensions are cropped to even
    # values because yuv420p requires them.
    if crf == 0:
        return image

    image_array = (
        (image[: (image.shape[0] // 2) * 2, : (image.shape[1] // 2) * 2] * 255.0)
        .byte()
        .cpu()
        .numpy()
    )
    with io.BytesIO() as output_file:
        _encode_single_frame(output_file, image_array, crf)
        video_bytes = output_file.getvalue()
    with io.BytesIO(video_bytes) as video_file:
        image_array = _decode_single_frame(video_file)
    tensor = torch.tensor(image_array, dtype=image.dtype, device=image.device) / 255.0
    return tensor
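A quick usage sketch with a synthetic frame; `compress()` expects an (H, W, 3) float tensor in [0, 1], as implied by the rgb24 conversion above:

import torch
from ltx_video.pipelines.crf_compressor import compress

frame = torch.rand(480, 640, 3)               # synthetic RGB frame in [0, 1]
degraded = compress(frame, crf=29)            # round-trips through H.264 at CRF 29
print(degraded.shape, (frame - degraded).abs().mean())  # same shape, small codec error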
    	
        ltx_video/pipelines/pipeline_ltx_video.py
    ADDED
    
    | @@ -0,0 +1,1845 @@ | |
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
import copy
import inspect
import math
import re
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from transformers import (
    T5EncoderModel,
    T5Tokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
)

from ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from ltx_video.models.autoencoders.vae_encode import (
    get_vae_size_scale_factor,
    latent_to_pixel_coords,
    vae_decode,
    vae_encode,
)
from ltx_video.models.transformers.symmetric_patchifier import Patchifier
from ltx_video.models.transformers.transformer3d import Transformer3DModel
from ltx_video.schedulers.rf import TimestepShifter
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.utils.prompt_enhance_utils import generate_cinematic_prompt
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler
from ltx_video.models.autoencoders.vae_encode import (
    un_normalize_latents,
    normalize_latents,
)


try:
    import torch_xla.distributed.spmd as xs
except ImportError:
    xs = None

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


ASPECT_RATIO_1024_BIN = {
    "0.25": [512.0, 2048.0],
    "0.28": [512.0, 1856.0],
    "0.32": [576.0, 1792.0],
    "0.33": [576.0, 1728.0],
    "0.35": [576.0, 1664.0],
    "0.4": [640.0, 1600.0],
    "0.42": [640.0, 1536.0],
    "0.48": [704.0, 1472.0],
    "0.5": [704.0, 1408.0],
    "0.52": [704.0, 1344.0],
    "0.57": [768.0, 1344.0],
    "0.6": [768.0, 1280.0],
    "0.68": [832.0, 1216.0],
    "0.72": [832.0, 1152.0],
    "0.78": [896.0, 1152.0],
    "0.82": [896.0, 1088.0],
    "0.88": [960.0, 1088.0],
    "0.94": [960.0, 1024.0],
    "1.0": [1024.0, 1024.0],
    "1.07": [1024.0, 960.0],
    "1.13": [1088.0, 960.0],
    "1.21": [1088.0, 896.0],
    "1.29": [1152.0, 896.0],
    "1.38": [1152.0, 832.0],
    "1.46": [1216.0, 832.0],
    "1.67": [1280.0, 768.0],
    "1.75": [1344.0, 768.0],
    "2.0": [1408.0, 704.0],
    "2.09": [1472.0, 704.0],
    "2.4": [1536.0, 640.0],
    "2.5": [1600.0, 640.0],
    "3.0": [1728.0, 576.0],
    "4.0": [2048.0, 512.0],
}

ASPECT_RATIO_512_BIN = {
    "0.25": [256.0, 1024.0],
    "0.28": [256.0, 928.0],
    "0.32": [288.0, 896.0],
    "0.33": [288.0, 864.0],
    "0.35": [288.0, 832.0],
    "0.4": [320.0, 800.0],
    "0.42": [320.0, 768.0],
    "0.48": [352.0, 736.0],
    "0.5": [352.0, 704.0],
    "0.52": [352.0, 672.0],
    "0.57": [384.0, 672.0],
    "0.6": [384.0, 640.0],
    "0.68": [416.0, 608.0],
    "0.72": [416.0, 576.0],
    "0.78": [448.0, 576.0],
    "0.82": [448.0, 544.0],
    "0.88": [480.0, 544.0],
    "0.94": [480.0, 512.0],
    "1.0": [512.0, 512.0],
    "1.07": [512.0, 480.0],
    "1.13": [544.0, 480.0],
    "1.21": [544.0, 448.0],
    "1.29": [576.0, 448.0],
    "1.38": [576.0, 416.0],
    "1.46": [608.0, 416.0],
    "1.67": [640.0, 384.0],
    "1.75": [672.0, 384.0],
    "2.0": [704.0, 352.0],
    "2.09": [736.0, 352.0],
    "2.4": [768.0, 320.0],
    "2.5": [800.0, 320.0],
    "3.0": [864.0, 288.0],
    "4.0": [1024.0, 256.0],
}
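Each bin maps an aspect ratio (height/width) to a target [height, width]. A hypothetical helper, not part of this file, showing how such PixArt-alpha-style bins are typically used to snap an arbitrary resolution to the nearest bin:

def closest_bin(height: int, width: int, ratios: dict) -> tuple:
    # Pick the bin whose key (h/w ratio) is closest to the input's ratio.
    ar = height / width
    best_key = min(ratios.keys(), key=lambda k: abs(float(k) - ar))
    bin_h, bin_w = ratios[best_key]
    return int(bin_h), int(bin_w)

print(closest_bin(720, 1280, ASPECT_RATIO_512_BIN))  # -> (384, 672), ratio key "0.57"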

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    skip_initial_inference_steps: int = 0,
    skip_final_inference_steps: int = 0,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.
        skip_initial_inference_steps (`int`, *optional*, defaults to 0):
            Number of steps to drop from the start of the schedule, so that denoising starts at a noising level
            suited to a partially noised input (image-to-image/video-to-video).
        skip_final_inference_steps (`int`, *optional*, defaults to 0):
            Number of steps to drop from the end of the schedule.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

        if (
            skip_initial_inference_steps < 0
            or skip_final_inference_steps < 0
            or skip_initial_inference_steps + skip_final_inference_steps
            >= num_inference_steps
        ):
            raise ValueError(
                "invalid skip inference step values: must be non-negative and the sum of skip_initial_inference_steps and skip_final_inference_steps must be less than the number of inference steps"
            )

        timesteps = timesteps[
            skip_initial_inference_steps : len(timesteps) - skip_final_inference_steps
        ]
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(timesteps)

    return timesteps, num_inference_steps
| 196 | 
            +
             | 
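# Illustrative sketch (editor's addition, not part of the upstream file): how the skip
# arguments above can be combined. It assumes a diffusers-style scheduler whose
# `set_timesteps` accepts a custom `timesteps` list; the helper name is hypothetical.
def _example_retrieve_timesteps_usage(scheduler, device=None):
    # Build a 30-step schedule, then drop the 5 noisiest and 3 cleanest steps,
    # leaving (at most) 22 timesteps to actually run.
    timesteps, num_steps = retrieve_timesteps(
        scheduler,
        num_inference_steps=30,
        device=device,
        skip_initial_inference_steps=5,
        skip_final_inference_steps=3,
    )
    return timesteps, num_steps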
| 197 | 
            +
             | 
| 198 | 
            +
            @dataclass
         | 
| 199 | 
            +
            class ConditioningItem:
         | 
| 200 | 
            +
                """
         | 
| 201 | 
            +
                Defines a single frame-conditioning item - a single frame or a sequence of frames.
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                Attributes:
         | 
| 204 | 
            +
                    media_item (torch.Tensor): shape=(b, 3, f, h, w). The media item to condition on.
         | 
| 205 | 
            +
                    media_frame_number (int): The start-frame number of the media item in the generated video.
         | 
| 206 | 
            +
                    conditioning_strength (float): The strength of the conditioning (1.0 = full conditioning).
         | 
| 207 | 
            +
                    media_x (Optional[int]): Optional left x coordinate of the media item in the generated frame.
         | 
| 208 | 
            +
                    media_y (Optional[int]): Optional top y coordinate of the media item in the generated frame.
         | 
| 209 | 
            +
                """
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                media_item: torch.Tensor
         | 
| 212 | 
            +
                media_frame_number: int
         | 
| 213 | 
            +
                conditioning_strength: float
         | 
| 214 | 
            +
                media_x: Optional[int] = None
         | 
| 215 | 
            +
                media_y: Optional[int] = None
         | 
| 216 | 
            +
             | 
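# Illustrative sketch (editor's addition): building a first-frame conditioning item
# from a batched image tensor of shape (b, 3, h, w). Shapes follow the (b, 3, f, h, w)
# convention documented in the dataclass above; the helper name and values are made up.
def _example_first_frame_conditioning(image_bchw: torch.Tensor) -> ConditioningItem:
    media = image_bchw.unsqueeze(2)  # add a singleton frame axis -> (b, 3, 1, h, w)
    return ConditioningItem(
        media_item=media,
        media_frame_number=0,       # condition the very first generated frame
        conditioning_strength=1.0,  # 1.0 = full ("hard") conditioning
    )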
| 217 | 
            +
             | 
| 218 | 
            +
            class LTXVideoPipeline(DiffusionPipeline):
         | 
| 219 | 
            +
                r"""
         | 
| 220 | 
            +
                Pipeline for text-to-video generation using LTX-Video.
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
         | 
| 223 | 
            +
                library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                Args:
         | 
| 226 | 
            +
                    vae ([`AutoencoderKL`]):
         | 
| 227 | 
            +
                        Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         | 
| 228 | 
            +
                    text_encoder ([`T5EncoderModel`]):
         | 
| 229 | 
            +
                        Frozen text-encoder. This uses
         | 
| 230 | 
            +
                        [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
         | 
| 231 | 
            +
                        [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
         | 
| 232 | 
            +
                    tokenizer (`T5Tokenizer`):
         | 
| 233 | 
            +
                        Tokenizer of class
         | 
| 234 | 
            +
                        [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
         | 
| 235 | 
            +
                    transformer ([`Transformer3DModel`]):
         | 
| 236 | 
            +
                        A text-conditioned `Transformer3DModel` to denoise the encoded video latents.
         | 
| 237 | 
            +
                    scheduler ([`SchedulerMixin`]):
         | 
| 238 | 
            +
                        A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         | 
| 239 | 
            +
                """
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                bad_punct_regex = re.compile(
         | 
| 242 | 
            +
                    r"["
         | 
| 243 | 
            +
                    + "#®•©™&@·º½¾¿¡§~"
         | 
| 244 | 
            +
                    + r"\)"
         | 
| 245 | 
            +
                    + r"\("
         | 
| 246 | 
            +
                    + r"\]"
         | 
| 247 | 
            +
                    + r"\["
         | 
| 248 | 
            +
                    + r"\}"
         | 
| 249 | 
            +
                    + r"\{"
         | 
| 250 | 
            +
                    + r"\|"
         | 
| 251 | 
            +
                    + "\\"
         | 
| 252 | 
            +
                    + r"\/"
         | 
| 253 | 
            +
                    + r"\*"
         | 
| 254 | 
            +
                    + r"]{1,}"
         | 
| 255 | 
            +
                )  # noqa
         | 
| 256 | 
            +
             | 
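    # Editor's note: `bad_punct_regex` matches runs of one or more of the symbol and
    # bracket characters listed above (a run such as "))((" would match). It appears
    # to be carried over from PixArt-style caption clean-up code.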
| 257 | 
            +
                _optional_components = [
         | 
| 258 | 
            +
                    "tokenizer",
         | 
| 259 | 
            +
                    "text_encoder",
         | 
| 260 | 
            +
                    "prompt_enhancer_image_caption_model",
         | 
| 261 | 
            +
                    "prompt_enhancer_image_caption_processor",
         | 
| 262 | 
            +
                    "prompt_enhancer_llm_model",
         | 
| 263 | 
            +
                    "prompt_enhancer_llm_tokenizer",
         | 
| 264 | 
            +
                ]
         | 
| 265 | 
            +
                model_cpu_offload_seq = "prompt_enhancer_image_caption_model->prompt_enhancer_llm_model->text_encoder->transformer->vae"
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def __init__(
         | 
| 268 | 
            +
                    self,
         | 
| 269 | 
            +
                    tokenizer: T5Tokenizer,
         | 
| 270 | 
            +
                    text_encoder: T5EncoderModel,
         | 
| 271 | 
            +
                    vae: AutoencoderKL,
         | 
| 272 | 
            +
                    transformer: Transformer3DModel,
         | 
| 273 | 
            +
                    scheduler: DPMSolverMultistepScheduler,
         | 
| 274 | 
            +
                    patchifier: Patchifier,
         | 
| 275 | 
            +
                    prompt_enhancer_image_caption_model: AutoModelForCausalLM,
         | 
| 276 | 
            +
                    prompt_enhancer_image_caption_processor: AutoProcessor,
         | 
| 277 | 
            +
                    prompt_enhancer_llm_model: AutoModelForCausalLM,
         | 
| 278 | 
            +
                    prompt_enhancer_llm_tokenizer: AutoTokenizer,
         | 
| 279 | 
            +
                    allowed_inference_steps: Optional[List[float]] = None,
         | 
| 280 | 
            +
                ):
         | 
| 281 | 
            +
                    super().__init__()
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    self.register_modules(
         | 
| 284 | 
            +
                        tokenizer=tokenizer,
         | 
| 285 | 
            +
                        text_encoder=text_encoder,
         | 
| 286 | 
            +
                        vae=vae,
         | 
| 287 | 
            +
                        transformer=transformer,
         | 
| 288 | 
            +
                        scheduler=scheduler,
         | 
| 289 | 
            +
                        patchifier=patchifier,
         | 
| 290 | 
            +
                        prompt_enhancer_image_caption_model=prompt_enhancer_image_caption_model,
         | 
| 291 | 
            +
                        prompt_enhancer_image_caption_processor=prompt_enhancer_image_caption_processor,
         | 
| 292 | 
            +
                        prompt_enhancer_llm_model=prompt_enhancer_llm_model,
         | 
| 293 | 
            +
                        prompt_enhancer_llm_tokenizer=prompt_enhancer_llm_tokenizer,
         | 
| 294 | 
            +
                    )
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                    self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor(
         | 
| 297 | 
            +
                        self.vae
         | 
| 298 | 
            +
                    )
         | 
| 299 | 
            +
                    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    self.allowed_inference_steps = allowed_inference_steps
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                def mask_text_embeddings(self, emb, mask):
         | 
| 304 | 
            +
                    if emb.shape[0] == 1:
         | 
| 305 | 
            +
                        keep_index = mask.sum().item()
         | 
| 306 | 
            +
                        return emb[:, :, :keep_index, :], keep_index
         | 
| 307 | 
            +
                    else:
         | 
| 308 | 
            +
                        masked_feature = emb * mask[:, None, :, None]
         | 
| 309 | 
            +
                        return masked_feature, emb.shape[2]
         | 
| 310 | 
            +
             | 
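    # Editor's note on `mask_text_embeddings` above: for a single prompt (batch size 1)
    # the embeddings are truncated to the number of unmasked tokens, while for batched
    # prompts the masked positions are zeroed out and the full sequence length is kept.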
| 311 | 
            +
                # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
         | 
| 312 | 
            +
                def encode_prompt(
         | 
| 313 | 
            +
                    self,
         | 
| 314 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 315 | 
            +
                    do_classifier_free_guidance: bool = True,
         | 
| 316 | 
            +
                    negative_prompt: str = "",
         | 
| 317 | 
            +
                    num_images_per_prompt: int = 1,
         | 
| 318 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 319 | 
            +
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 320 | 
            +
                    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 321 | 
            +
                    prompt_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 322 | 
            +
                    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 323 | 
            +
                    text_encoder_max_tokens: int = 256,
         | 
| 324 | 
            +
                    **kwargs,
         | 
| 325 | 
            +
                ):
         | 
| 326 | 
            +
                    r"""
         | 
| 327 | 
            +
                    Encodes the prompt into text encoder hidden states.
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    Args:
         | 
| 330 | 
            +
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 331 | 
            +
                            prompt to be encoded
         | 
| 332 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 333 | 
            +
                            The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
         | 
| 334 | 
            +
                            instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
         | 
| 335 | 
            +
                            this pipeline, it should be the empty string "".
         | 
| 336 | 
            +
                        do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
         | 
| 337 | 
            +
                            whether to use classifier free guidance or not
         | 
| 338 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 339 | 
            +
                            number of images that should be generated per prompt
         | 
| 340 | 
            +
                        device: (`torch.device`, *optional*):
         | 
| 341 | 
            +
                            torch device to place the resulting embeddings on
         | 
| 342 | 
            +
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 343 | 
            +
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 344 | 
            +
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 345 | 
            +
                        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 346 | 
            +
                            Pre-generated negative text embeddings.
         | 
| 347 | 
            +
                    """
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    if "mask_feature" in kwargs:
         | 
| 350 | 
            +
                        deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
         | 
| 351 | 
            +
                        deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    if device is None:
         | 
| 354 | 
            +
                        device = self._execution_device
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    if prompt is not None and isinstance(prompt, str):
         | 
| 357 | 
            +
                        batch_size = 1
         | 
| 358 | 
            +
                    elif prompt is not None and isinstance(prompt, list):
         | 
| 359 | 
            +
                        batch_size = len(prompt)
         | 
| 360 | 
            +
                    else:
         | 
| 361 | 
            +
                        batch_size = prompt_embeds.shape[0]
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    # See Section 3.1. of the paper.
         | 
| 364 | 
            +
                    max_length = (
         | 
| 365 | 
            +
                text_encoder_max_tokens  # TPU supports only sequence lengths that are multiples of 128
         | 
| 366 | 
            +
                    )
         | 
| 367 | 
            +
                    if prompt_embeds is None:
         | 
| 368 | 
            +
                        assert (
         | 
| 369 | 
            +
                            self.text_encoder is not None
         | 
| 370 | 
            +
                        ), "You should provide either prompt_embeds or self.text_encoder should not be None,"
         | 
| 371 | 
            +
                        text_enc_device = next(self.text_encoder.parameters()).device
         | 
| 372 | 
            +
                        prompt = self._text_preprocessing(prompt)
         | 
| 373 | 
            +
                        text_inputs = self.tokenizer(
         | 
| 374 | 
            +
                            prompt,
         | 
| 375 | 
            +
                            padding="max_length",
         | 
| 376 | 
            +
                            max_length=max_length,
         | 
| 377 | 
            +
                            truncation=True,
         | 
| 378 | 
            +
                            add_special_tokens=True,
         | 
| 379 | 
            +
                            return_tensors="pt",
         | 
| 380 | 
            +
                        )
         | 
| 381 | 
            +
                        text_input_ids = text_inputs.input_ids
         | 
| 382 | 
            +
                        untruncated_ids = self.tokenizer(
         | 
| 383 | 
            +
                            prompt, padding="longest", return_tensors="pt"
         | 
| 384 | 
            +
                        ).input_ids
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                        if untruncated_ids.shape[-1] >= text_input_ids.shape[
         | 
| 387 | 
            +
                            -1
         | 
| 388 | 
            +
                        ] and not torch.equal(text_input_ids, untruncated_ids):
         | 
| 389 | 
            +
                            removed_text = self.tokenizer.batch_decode(
         | 
| 390 | 
            +
                                untruncated_ids[:, max_length - 1 : -1]
         | 
| 391 | 
            +
                            )
         | 
| 392 | 
            +
                            logger.warning(
         | 
| 393 | 
            +
                                "The following part of your input was truncated because CLIP can only handle sequences up to"
         | 
| 394 | 
            +
                                f" {max_length} tokens: {removed_text}"
         | 
| 395 | 
            +
                            )
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                        prompt_attention_mask = text_inputs.attention_mask
         | 
| 398 | 
            +
                        prompt_attention_mask = prompt_attention_mask.to(text_enc_device)
         | 
| 399 | 
            +
                        prompt_attention_mask = prompt_attention_mask.to(device)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                        prompt_embeds = self.text_encoder(
         | 
| 402 | 
            +
                            text_input_ids.to(text_enc_device), attention_mask=prompt_attention_mask
         | 
| 403 | 
            +
                        )
         | 
| 404 | 
            +
                        prompt_embeds = prompt_embeds[0]
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                    if self.text_encoder is not None:
         | 
| 407 | 
            +
                        dtype = self.text_encoder.dtype
         | 
| 408 | 
            +
                    elif self.transformer is not None:
         | 
| 409 | 
            +
                        dtype = self.transformer.dtype
         | 
| 410 | 
            +
                    else:
         | 
| 411 | 
            +
                        dtype = None
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    bs_embed, seq_len, _ = prompt_embeds.shape
         | 
| 416 | 
            +
                    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         | 
| 417 | 
            +
                    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 418 | 
            +
                    prompt_embeds = prompt_embeds.view(
         | 
| 419 | 
            +
                        bs_embed * num_images_per_prompt, seq_len, -1
         | 
| 420 | 
            +
                    )
         | 
| 421 | 
            +
                    prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
         | 
| 422 | 
            +
                    prompt_attention_mask = prompt_attention_mask.view(
         | 
| 423 | 
            +
                        bs_embed * num_images_per_prompt, -1
         | 
| 424 | 
            +
                    )
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                    # get unconditional embeddings for classifier free guidance
         | 
| 427 | 
            +
                    if do_classifier_free_guidance and negative_prompt_embeds is None:
         | 
| 428 | 
            +
                        uncond_tokens = self._text_preprocessing(negative_prompt)
         | 
| 429 | 
            +
                        uncond_tokens = uncond_tokens * batch_size
         | 
| 430 | 
            +
                        max_length = prompt_embeds.shape[1]
         | 
| 431 | 
            +
                        uncond_input = self.tokenizer(
         | 
| 432 | 
            +
                            uncond_tokens,
         | 
| 433 | 
            +
                            padding="max_length",
         | 
| 434 | 
            +
                            max_length=max_length,
         | 
| 435 | 
            +
                            truncation=True,
         | 
| 436 | 
            +
                            return_attention_mask=True,
         | 
| 437 | 
            +
                            add_special_tokens=True,
         | 
| 438 | 
            +
                            return_tensors="pt",
         | 
| 439 | 
            +
                        )
         | 
| 440 | 
            +
                        negative_prompt_attention_mask = uncond_input.attention_mask
         | 
| 441 | 
            +
                        negative_prompt_attention_mask = negative_prompt_attention_mask.to(
         | 
| 442 | 
            +
                            text_enc_device
         | 
| 443 | 
            +
                        )
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                        negative_prompt_embeds = self.text_encoder(
         | 
| 446 | 
            +
                            uncond_input.input_ids.to(text_enc_device),
         | 
| 447 | 
            +
                            attention_mask=negative_prompt_attention_mask,
         | 
| 448 | 
            +
                        )
         | 
| 449 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds[0]
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    if do_classifier_free_guidance:
         | 
| 452 | 
            +
                        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         | 
| 453 | 
            +
                        seq_len = negative_prompt_embeds.shape[1]
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.to(
         | 
| 456 | 
            +
                            dtype=dtype, device=device
         | 
| 457 | 
            +
                        )
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.repeat(
         | 
| 460 | 
            +
                            1, num_images_per_prompt, 1
         | 
| 461 | 
            +
                        )
         | 
| 462 | 
            +
                        negative_prompt_embeds = negative_prompt_embeds.view(
         | 
| 463 | 
            +
                            batch_size * num_images_per_prompt, seq_len, -1
         | 
| 464 | 
            +
                        )
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                        negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
         | 
| 467 | 
            +
                            1, num_images_per_prompt
         | 
| 468 | 
            +
                        )
         | 
| 469 | 
            +
                        negative_prompt_attention_mask = negative_prompt_attention_mask.view(
         | 
| 470 | 
            +
                            bs_embed * num_images_per_prompt, -1
         | 
| 471 | 
            +
                        )
         | 
| 472 | 
            +
                    else:
         | 
| 473 | 
            +
                        negative_prompt_embeds = None
         | 
| 474 | 
            +
                        negative_prompt_attention_mask = None
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    return (
         | 
| 477 | 
            +
                        prompt_embeds,
         | 
| 478 | 
            +
                        prompt_attention_mask,
         | 
| 479 | 
            +
                        negative_prompt_embeds,
         | 
| 480 | 
            +
                        negative_prompt_attention_mask,
         | 
| 481 | 
            +
                    )
         | 
| 482 | 
            +
             | 
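    # Illustrative call sketch (editor's addition): with classifier-free guidance
    # enabled, `encode_prompt` returns conditional and unconditional embeddings plus
    # their attention masks. `pipe` stands for an already-instantiated LTXVideoPipeline.
    #
    #   (prompt_embeds,
    #    prompt_attention_mask,
    #    negative_prompt_embeds,
    #    negative_prompt_attention_mask) = pipe.encode_prompt(
    #       "a red fox running through deep snow",
    #       do_classifier_free_guidance=True,
    #       negative_prompt="",
    #   )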
| 483 | 
            +
                # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
         | 
| 484 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 485 | 
            +
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 486 | 
            +
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 487 | 
            +
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 488 | 
            +
                    # and should be between [0, 1]
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    accepts_eta = "eta" in set(
         | 
| 491 | 
            +
                        inspect.signature(self.scheduler.step).parameters.keys()
         | 
| 492 | 
            +
                    )
         | 
| 493 | 
            +
                    extra_step_kwargs = {}
         | 
| 494 | 
            +
                    if accepts_eta:
         | 
| 495 | 
            +
                        extra_step_kwargs["eta"] = eta
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    # check if the scheduler accepts generator
         | 
| 498 | 
            +
                    accepts_generator = "generator" in set(
         | 
| 499 | 
            +
                        inspect.signature(self.scheduler.step).parameters.keys()
         | 
| 500 | 
            +
                    )
         | 
| 501 | 
            +
                    if accepts_generator:
         | 
| 502 | 
            +
                        extra_step_kwargs["generator"] = generator
         | 
| 503 | 
            +
                    return extra_step_kwargs
         | 
| 504 | 
            +
             | 
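    # Editor's note: the signature inspection above keeps the denoising loop
    # scheduler-agnostic; `eta` and `generator` are only forwarded to
    # `scheduler.step()` when that scheduler actually accepts them (e.g. `eta` is
    # meaningful for DDIM and is ignored by schedulers without that parameter).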
| 505 | 
            +
                def check_inputs(
         | 
| 506 | 
            +
                    self,
         | 
| 507 | 
            +
                    prompt,
         | 
| 508 | 
            +
                    height,
         | 
| 509 | 
            +
                    width,
         | 
| 510 | 
            +
                    negative_prompt,
         | 
| 511 | 
            +
                    prompt_embeds=None,
         | 
| 512 | 
            +
                    negative_prompt_embeds=None,
         | 
| 513 | 
            +
                    prompt_attention_mask=None,
         | 
| 514 | 
            +
                    negative_prompt_attention_mask=None,
         | 
| 515 | 
            +
                    enhance_prompt=False,
         | 
| 516 | 
            +
                ):
         | 
| 517 | 
            +
                    if height % 8 != 0 or width % 8 != 0:
         | 
| 518 | 
            +
                        raise ValueError(
         | 
| 519 | 
            +
                            f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
         | 
| 520 | 
            +
                        )
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    if prompt is not None and prompt_embeds is not None:
         | 
| 523 | 
            +
                        raise ValueError(
         | 
| 524 | 
            +
                            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
         | 
| 525 | 
            +
                            " only forward one of the two."
         | 
| 526 | 
            +
                        )
         | 
| 527 | 
            +
                    elif prompt is None and prompt_embeds is None:
         | 
| 528 | 
            +
                        raise ValueError(
         | 
| 529 | 
            +
                            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
         | 
| 530 | 
            +
                        )
         | 
| 531 | 
            +
                    elif prompt is not None and (
         | 
| 532 | 
            +
                        not isinstance(prompt, str) and not isinstance(prompt, list)
         | 
| 533 | 
            +
                    ):
         | 
| 534 | 
            +
                        raise ValueError(
         | 
| 535 | 
            +
                            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
         | 
| 536 | 
            +
                        )
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                    if prompt is not None and negative_prompt_embeds is not None:
         | 
| 539 | 
            +
                        raise ValueError(
         | 
| 540 | 
            +
                            f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
         | 
| 541 | 
            +
                            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
         | 
| 542 | 
            +
                        )
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                    if negative_prompt is not None and negative_prompt_embeds is not None:
         | 
| 545 | 
            +
                        raise ValueError(
         | 
| 546 | 
            +
                            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
         | 
| 547 | 
            +
                            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
         | 
| 548 | 
            +
                        )
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    if prompt_embeds is not None and prompt_attention_mask is None:
         | 
| 551 | 
            +
                        raise ValueError(
         | 
| 552 | 
            +
                            "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
         | 
| 553 | 
            +
                        )
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                    if (
         | 
| 556 | 
            +
                        negative_prompt_embeds is not None
         | 
| 557 | 
            +
                        and negative_prompt_attention_mask is None
         | 
| 558 | 
            +
                    ):
         | 
| 559 | 
            +
                        raise ValueError(
         | 
| 560 | 
            +
                            "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
         | 
| 561 | 
            +
                        )
         | 
| 562 | 
            +
             | 
| 563 | 
            +
                    if prompt_embeds is not None and negative_prompt_embeds is not None:
         | 
| 564 | 
            +
                        if prompt_embeds.shape != negative_prompt_embeds.shape:
         | 
| 565 | 
            +
                            raise ValueError(
         | 
| 566 | 
            +
                                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
         | 
| 567 | 
            +
                                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
         | 
| 568 | 
            +
                                f" {negative_prompt_embeds.shape}."
         | 
| 569 | 
            +
                            )
         | 
| 570 | 
            +
                        if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
         | 
| 571 | 
            +
                            raise ValueError(
         | 
| 572 | 
            +
                                "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
         | 
| 573 | 
            +
                                f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
         | 
| 574 | 
            +
                                f" {negative_prompt_attention_mask.shape}."
         | 
| 575 | 
            +
                            )
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    if enhance_prompt:
         | 
| 578 | 
            +
                        assert (
         | 
| 579 | 
            +
                            self.prompt_enhancer_image_caption_model is not None
         | 
| 580 | 
            +
                        ), "Image caption model must be initialized if enhance_prompt is True"
         | 
| 581 | 
            +
                        assert (
         | 
| 582 | 
            +
                            self.prompt_enhancer_image_caption_processor is not None
         | 
| 583 | 
            +
                        ), "Image caption processor must be initialized if enhance_prompt is True"
         | 
| 584 | 
            +
                        assert (
         | 
| 585 | 
            +
                            self.prompt_enhancer_llm_model is not None
         | 
| 586 | 
            +
                        ), "Text prompt enhancer model must be initialized if enhance_prompt is True"
         | 
| 587 | 
            +
                        assert (
         | 
| 588 | 
            +
                            self.prompt_enhancer_llm_tokenizer is not None
         | 
| 589 | 
            +
                        ), "Text prompt enhancer tokenizer must be initialized if enhance_prompt is True"
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                def _text_preprocessing(self, text):
         | 
| 592 | 
            +
                    if not isinstance(text, (tuple, list)):
         | 
| 593 | 
            +
                        text = [text]
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    def process(text: str):
         | 
| 596 | 
            +
                        text = text.strip()
         | 
| 597 | 
            +
                        return text
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                    return [process(t) for t in text]
         | 
| 600 | 
            +
             | 
| 601 | 
            +
                @staticmethod
         | 
| 602 | 
            +
                def add_noise_to_image_conditioning_latents(
         | 
| 603 | 
            +
                    t: float,
         | 
| 604 | 
            +
                    init_latents: torch.Tensor,
         | 
| 605 | 
            +
                    latents: torch.Tensor,
         | 
| 606 | 
            +
                    noise_scale: float,
         | 
| 607 | 
            +
                    conditioning_mask: torch.Tensor,
         | 
| 608 | 
            +
                    generator,
         | 
| 609 | 
            +
                    eps=1e-6,
         | 
| 610 | 
            +
                ):
         | 
| 611 | 
            +
                    """
         | 
| 612 | 
            +
                    Add timestep-dependent noise to the hard-conditioning latents.
         | 
| 613 | 
            +
                    This helps with motion continuity, especially when conditioned on a single frame.
         | 
| 614 | 
            +
                    """
         | 
| 615 | 
            +
                    noise = randn_tensor(
         | 
| 616 | 
            +
                        latents.shape,
         | 
| 617 | 
            +
                        generator=generator,
         | 
| 618 | 
            +
                        device=latents.device,
         | 
| 619 | 
            +
                        dtype=latents.dtype,
         | 
| 620 | 
            +
                    )
         | 
| 621 | 
            +
                    # Add noise only to hard-conditioning latents (conditioning_mask = 1.0)
         | 
| 622 | 
            +
                    need_to_noise = (conditioning_mask > 1.0 - eps).unsqueeze(-1)
         | 
| 623 | 
            +
                    noised_latents = init_latents + noise_scale * noise * (t**2)
         | 
| 624 | 
            +
                    latents = torch.where(need_to_noise, noised_latents, latents)
         | 
| 625 | 
            +
                    return latents
         | 
| 626 | 
            +
             | 
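    # Editor's note: in `add_noise_to_image_conditioning_latents` above, only positions
    # whose conditioning mask is (numerically) 1.0 are perturbed, and the perturbation
    #
    #     init_latents + noise_scale * noise * t**2
    #
    # shrinks quadratically as t approaches 0, so early high-noise steps see a jittered
    # conditioning frame while late steps see it nearly unchanged.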
| 627 | 
            +
                # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
         | 
| 628 | 
            +
                def prepare_latents(
         | 
| 629 | 
            +
                    self,
         | 
| 630 | 
            +
                    latents: torch.Tensor | None,
         | 
| 631 | 
            +
                    media_items: torch.Tensor | None,
         | 
| 632 | 
            +
                    timestep: float,
         | 
| 633 | 
            +
                    latent_shape: torch.Size | Tuple[Any, ...],
         | 
| 634 | 
            +
                    dtype: torch.dtype,
         | 
| 635 | 
            +
                    device: torch.device,
         | 
| 636 | 
            +
                    generator: torch.Generator | List[torch.Generator],
         | 
| 637 | 
            +
                    vae_per_channel_normalize: bool = True,
         | 
| 638 | 
            +
                ):
         | 
| 639 | 
            +
                    """
         | 
| 640 | 
            +
                    Prepare the initial latent tensor to be denoised.
         | 
| 641 | 
            +
                    The latents are either pure noise or a noised version of the encoded media items.
         | 
| 642 | 
            +
                    Args:
         | 
| 643 | 
            +
                        latents (`torch.FloatTensor` or `None`):
         | 
| 644 | 
            +
                            The latents to use (provided by the user) or `None` to create new latents.
         | 
| 645 | 
            +
                        media_items (`torch.FloatTensor` or `None`):
         | 
| 646 | 
            +
                            An image or video to be updated using img2img or vid2vid. The media item is encoded and noised.
         | 
| 647 | 
            +
                        timestep (`float`):
         | 
| 648 | 
            +
                            The timestep to noise the encoded media_items to.
         | 
| 649 | 
            +
                        latent_shape (`torch.Size`):
         | 
| 650 | 
            +
                            The target latent shape.
         | 
| 651 | 
            +
                        dtype (`torch.dtype`):
         | 
| 652 | 
            +
                            The target dtype.
         | 
| 653 | 
            +
                        device (`torch.device`):
         | 
| 654 | 
            +
                            The target device.
         | 
| 655 | 
            +
                        generator (`torch.Generator` or `List[torch.Generator]`):
         | 
| 656 | 
            +
                            Generator(s) to be used for the noising process.
         | 
| 657 | 
            +
                        vae_per_channel_normalize (`bool`):
         | 
| 658 | 
            +
                            When encoding the media_items, whether to normalize the latents per-channel.
         | 
| 659 | 
            +
                    Returns:
         | 
| 660 | 
            +
                        `torch.FloatTensor`: The latents to be used for the denoising process. This is a tensor of shape
         | 
| 661 | 
            +
                        (batch_size, num_channels, num_frames, height, width).
         | 
| 662 | 
            +
                    """
         | 
| 663 | 
            +
                    if isinstance(generator, list) and len(generator) != latent_shape[0]:
         | 
| 664 | 
            +
                        raise ValueError(
         | 
| 665 | 
            +
                            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
         | 
| 666 | 
            +
                            f" size of {latent_shape[0]}. Make sure the batch size matches the length of the generators."
         | 
| 667 | 
            +
                        )
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    # Initialize the latents with the given latents or encoded media item, if provided
         | 
| 670 | 
            +
                    assert (
         | 
| 671 | 
            +
                        latents is None or media_items is None
         | 
| 672 | 
            +
                    ), "Cannot provide both latents and media_items. Please provide only one of the two."
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                    assert (
         | 
| 675 | 
            +
                        (latents is None and media_items is None) or timestep < 1.0
         | 
| 676 | 
            +
                    ), "Input media_item or latents are provided, but they will be replaced with noise."
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                    if media_items is not None:
         | 
| 679 | 
            +
                        latents = vae_encode(
         | 
| 680 | 
            +
                            media_items.to(dtype=self.vae.dtype, device=self.vae.device),
         | 
| 681 | 
            +
                            self.vae,
         | 
| 682 | 
            +
                            vae_per_channel_normalize=vae_per_channel_normalize,
         | 
| 683 | 
            +
                        )
         | 
| 684 | 
            +
                    if latents is not None:
         | 
| 685 | 
            +
                        assert (
         | 
| 686 | 
            +
                            latents.shape == latent_shape
         | 
| 687 | 
            +
                        ), f"Latents have to be of shape {latent_shape} but are {latents.shape}."
         | 
| 688 | 
            +
                        latents = latents.to(device=device, dtype=dtype)
         | 
| 689 | 
            +
             | 
| 690 | 
            +
                    # For backward compatibility, generate in the "patchified" shape and rearrange
         | 
| 691 | 
            +
                    b, c, f, h, w = latent_shape
         | 
| 692 | 
            +
                    noise = randn_tensor(
         | 
| 693 | 
            +
                        (b, f * h * w, c), generator=generator, device=device, dtype=dtype
         | 
| 694 | 
            +
                    )
         | 
| 695 | 
            +
                    noise = rearrange(noise, "b (f h w) c -> b c f h w", f=f, h=h, w=w)
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                    # scale the initial noise by the standard deviation required by the scheduler
         | 
| 698 | 
            +
                    noise = noise * self.scheduler.init_noise_sigma
         | 
| 699 | 
            +
             | 
| 700 | 
            +
                    if latents is None:
         | 
| 701 | 
            +
                        latents = noise
         | 
| 702 | 
            +
                    else:
         | 
| 703 | 
            +
                        # Noise the latents to the required (first) timestep
         | 
| 704 | 
            +
                        latents = timestep * noise + (1 - timestep) * latents
         | 
| 705 | 
            +
             | 
| 706 | 
            +
                    return latents
         | 
| 707 | 
            +
             | 
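    # Editor's note: the blend `timestep * noise + (1 - timestep) * latents` above is
    # consistent with a rectified-flow formulation (cf. `ltx_video/schedulers/rf.py`),
    # where timestep = 1.0 corresponds to pure noise and timestep = 0.0 to the clean
    # latent; this is also why providing media requires `timestep < 1.0`.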
| 708 | 
            +
                @staticmethod
         | 
| 709 | 
            +
                def classify_height_width_bin(
         | 
| 710 | 
            +
                    height: int, width: int, ratios: dict
         | 
| 711 | 
            +
                ) -> Tuple[int, int]:
         | 
| 712 | 
            +
                    """Returns binned height and width."""
         | 
| 713 | 
            +
                    ar = float(height / width)
         | 
| 714 | 
            +
                    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
         | 
| 715 | 
            +
                    default_hw = ratios[closest_ratio]
         | 
| 716 | 
            +
                    return int(default_hw[0]), int(default_hw[1])
         | 
| 717 | 
            +
             | 
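    # Illustrative example (editor's addition), assuming a ratios dict shaped like the
    # aspect-ratio bins at the top of this file (the exact constant name is not shown
    # in this excerpt):
    #
    #   bins = {"2.5": [800.0, 320.0], "3.0": [864.0, 288.0]}
    #   LTXVideoPipeline.classify_height_width_bin(750, 300, ratios=bins)
    #   # height / width = 2.5 -> nearest key "2.5" -> returns (800, 320)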
| 718 | 
            +
                @staticmethod
         | 
| 719 | 
            +
                def resize_and_crop_tensor(
         | 
| 720 | 
            +
                    samples: torch.Tensor, new_width: int, new_height: int
         | 
| 721 | 
            +
                ) -> torch.Tensor:
         | 
| 722 | 
            +
                    n_frames, orig_height, orig_width = samples.shape[-3:]
         | 
| 723 | 
            +
             | 
| 724 | 
            +
                    # Check if resizing is needed
         | 
| 725 | 
            +
                    if orig_height != new_height or orig_width != new_width:
         | 
| 726 | 
            +
                        ratio = max(new_height / orig_height, new_width / orig_width)
         | 
| 727 | 
            +
                        resized_width = int(orig_width * ratio)
         | 
| 728 | 
            +
                        resized_height = int(orig_height * ratio)
         | 
| 729 | 
            +
             | 
| 730 | 
            +
                        # Resize
         | 
| 731 | 
            +
                        samples = LTXVideoPipeline.resize_tensor(
         | 
| 732 | 
            +
                            samples, resized_height, resized_width
         | 
| 733 | 
            +
                        )
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                        # Center Crop
         | 
| 736 | 
            +
                        start_x = (resized_width - new_width) // 2
         | 
| 737 | 
            +
                        end_x = start_x + new_width
         | 
| 738 | 
            +
                        start_y = (resized_height - new_height) // 2
         | 
| 739 | 
            +
                        end_y = start_y + new_height
         | 
| 740 | 
            +
                        samples = samples[..., start_y:end_y, start_x:end_x]
         | 
| 741 | 
            +
             | 
| 742 | 
            +
                    return samples
         | 
| 743 | 
            +
             | 
| 744 | 
            +
                @staticmethod
         | 
| 745 | 
            +
                def resize_tensor(media_items, height, width):
         | 
| 746 | 
            +
                    n_frames = media_items.shape[2]
         | 
| 747 | 
            +
                    if media_items.shape[-2:] != (height, width):
         | 
| 748 | 
            +
                        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
         | 
| 749 | 
            +
                        media_items = F.interpolate(
         | 
| 750 | 
            +
                            media_items,
         | 
| 751 | 
            +
                            size=(height, width),
         | 
| 752 | 
            +
                            mode="bilinear",
         | 
| 753 | 
            +
                            align_corners=False,
         | 
| 754 | 
            +
                        )
         | 
| 755 | 
            +
                        media_items = rearrange(media_items, "(b n) c h w -> b c n h w", n=n_frames)
         | 
| 756 | 
            +
                    return media_items
         | 
| 757 | 
            +
             | 
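    # Editor's note: `resize_and_crop_tensor` above performs an "aspect-fill" resize:
    # it scales by the larger of the two ratios so both target dimensions are covered,
    # then center-crops. For example, a 480x640 (h x w) clip resized to 512x512 is
    # first interpolated to roughly 512x682 and then cropped horizontally to 512x512.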
| 758 | 
            +
                @torch.no_grad()
         | 
| 759 | 
            +
                def __call__(
         | 
| 760 | 
            +
                    self,
         | 
| 761 | 
            +
                    height: int,
         | 
| 762 | 
            +
                    width: int,
         | 
| 763 | 
            +
                    num_frames: int,
         | 
| 764 | 
            +
                    frame_rate: float,
         | 
| 765 | 
            +
                    prompt: Union[str, List[str]] = None,
         | 
| 766 | 
            +
                    negative_prompt: str = "",
         | 
| 767 | 
            +
                    num_inference_steps: int = 20,
         | 
| 768 | 
            +
                    skip_initial_inference_steps: int = 0,
         | 
| 769 | 
            +
                    skip_final_inference_steps: int = 0,
         | 
| 770 | 
            +
                    timesteps: List[int] = None,
         | 
| 771 | 
            +
                    guidance_scale: Union[float, List[float]] = 4.5,
         | 
| 772 | 
            +
                    cfg_star_rescale: bool = False,
         | 
| 773 | 
            +
                    skip_layer_strategy: Optional[SkipLayerStrategy] = None,
         | 
| 774 | 
            +
                    skip_block_list: Optional[Union[List[List[int]], List[int]]] = None,
         | 
| 775 | 
            +
                    stg_scale: Union[float, List[float]] = 1.0,
         | 
| 776 | 
            +
                    rescaling_scale: Union[float, List[float]] = 0.7,
         | 
| 777 | 
            +
                    guidance_timesteps: Optional[List[int]] = None,
         | 
| 778 | 
            +
                    num_images_per_prompt: Optional[int] = 1,
         | 
| 779 | 
            +
                    eta: float = 0.0,
         | 
| 780 | 
            +
                    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         | 
| 781 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 782 | 
            +
                    prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 783 | 
            +
                    prompt_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 784 | 
            +
                    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         | 
| 785 | 
            +
                    negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
         | 
| 786 | 
            +
                    output_type: Optional[str] = "pil",
         | 
| 787 | 
            +
                    return_dict: bool = True,
         | 
| 788 | 
            +
                    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         | 
| 789 | 
            +
                    conditioning_items: Optional[List[ConditioningItem]] = None,
         | 
| 790 | 
            +
                    decode_timestep: Union[List[float], float] = 0.0,
         | 
| 791 | 
            +
                    decode_noise_scale: Optional[List[float]] = None,
         | 
| 792 | 
            +
                    mixed_precision: bool = False,
         | 
| 793 | 
            +
                    offload_to_cpu: bool = False,
         | 
| 794 | 
            +
                    enhance_prompt: bool = False,
         | 
| 795 | 
            +
                    text_encoder_max_tokens: int = 256,
         | 
| 796 | 
            +
                    stochastic_sampling: bool = False,
         | 
| 797 | 
            +
                    media_items: Optional[torch.Tensor] = None,
         | 
| 798 | 
            +
                    **kwargs,
         | 
| 799 | 
            +
                ) -> Union[ImagePipelineOutput, Tuple]:
         | 
| 800 | 
            +
                    """
         | 
| 801 | 
            +
                    Function invoked when calling the pipeline for generation.
         | 
| 802 | 
            +
             | 
| 803 | 
            +
                    Args:
         | 
| 804 | 
            +
                        prompt (`str` or `List[str]`, *optional*):
         | 
| 805 | 
            +
                            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
         | 
| 806 | 
            +
                            instead.
         | 
| 807 | 
            +
                        negative_prompt (`str` or `List[str]`, *optional*):
         | 
| 808 | 
            +
                            The prompt or prompts not to guide the image generation. If not defined, one has to pass
         | 
| 809 | 
            +
                            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
         | 
| 810 | 
            +
                            less than `1`).
         | 
| 811 | 
            +
                        num_inference_steps (`int`, *optional*, defaults to 100):
         | 
| 812 | 
            +
                            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
         | 
| 813 | 
            +
                            expense of slower inference. If `timesteps` is provided, this parameter is ignored.
         | 
| 814 | 
            +
                        skip_initial_inference_steps (`int`, *optional*, defaults to 0):
         | 
| 815 | 
            +
                            The number of initial timesteps to skip. After calculating the timesteps, this number of timesteps will
         | 
| 816 | 
            +
                            be removed from the beginning of the timesteps list. Meaning the highest-timesteps values will not run.
         | 
| 817 | 
            +
                        skip_final_inference_steps (`int`, *optional*, defaults to 0):
         | 
| 818 | 
            +
                            The number of final timesteps to skip. After calculating the timesteps, this number of timesteps will
         | 
| 819 | 
            +
                            be removed from the end of the timesteps list. Meaning the lowest-timesteps values will not run.
         | 
| 820 | 
            +
                        timesteps (`List[int]`, *optional*):
         | 
| 821 | 
            +
                            Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
         | 
| 822 | 
            +
                            timesteps are used. Must be in descending order.
         | 
| 823 | 
            +
                        guidance_scale (`float`, *optional*, defaults to 4.5):
         | 
| 824 | 
            +
                            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         | 
| 825 | 
            +
                            `guidance_scale` is defined as `w` of equation 2. of [Imagen
         | 
| 826 | 
            +
                            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
         | 
| 827 | 
            +
                            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
         | 
| 828 | 
            +
                            usually at the expense of lower image quality.
         | 
| 829 | 
            +
                        cfg_star_rescale (`bool`, *optional*, defaults to `False`):
         | 
| 830 | 
            +
                            If set to `True`, applies the CFG star rescale. Scales the negative prediction according to dot
         | 
| 831 | 
            +
                            product between positive and negative.
         | 
| 832 | 
            +
                        num_images_per_prompt (`int`, *optional*, defaults to 1):
         | 
| 833 | 
            +
                            The number of images to generate per prompt.
         | 
| 834 | 
            +
                        height (`int`, *optional*, defaults to self.unet.config.sample_size):
         | 
| 835 | 
            +
                            The height in pixels of the generated image.
         | 
| 836 | 
            +
                        width (`int`, *optional*, defaults to self.unet.config.sample_size):
         | 
| 837 | 
            +
                            The width in pixels of the generated image.
         | 
| 838 | 
            +
                        eta (`float`, *optional*, defaults to 0.0):
         | 
| 839 | 
            +
                            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
         | 
| 840 | 
            +
                            [`schedulers.DDIMScheduler`], will be ignored for others.
         | 
| 841 | 
            +
                        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
         | 
| 842 | 
            +
                            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
         | 
| 843 | 
            +
                            to make generation deterministic.
         | 
| 844 | 
            +
                        latents (`torch.FloatTensor`, *optional*):
         | 
| 845 | 
            +
                            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
         | 
| 846 | 
            +
                            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
         | 
| 847 | 
            +
                            tensor will ge generated by sampling using the supplied random `generator`.
         | 
| 848 | 
            +
                        prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 849 | 
            +
                            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
         | 
| 850 | 
            +
                            provided, text embeddings will be generated from `prompt` input argument.
         | 
| 851 | 
            +
                        prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings.
         | 
| 852 | 
            +
                        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
         | 
| 853 | 
            +
                            Pre-generated negative text embeddings. This negative prompt should be "". If not
         | 
| 854 | 
            +
                            provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
         | 
| 855 | 
            +
                        negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
         | 
| 856 | 
            +
                            Pre-generated attention mask for negative text embeddings.
         | 
| 857 | 
            +
                        output_type (`str`, *optional*, defaults to `"pil"`):
         | 
| 858 | 
            +
                            The output format of the generate image. Choose between
         | 
| 859 | 
            +
                            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
         | 
| 860 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 861 | 
            +
                            Whether to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
         | 
| 862 | 
            +
                        callback_on_step_end (`Callable`, *optional*):
         | 
| 863 | 
            +
                            A function that calls at the end of each denoising steps during the inference. The function is called
         | 
| 864 | 
            +
                            with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
         | 
| 865 | 
            +
                            callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
         | 
| 866 | 
            +
                            `callback_on_step_end_tensor_inputs`.
         | 
| 867 | 
            +
                        use_resolution_binning (`bool` defaults to `True`):
         | 
| 868 | 
            +
                            If set to `True`, the requested height and width are first mapped to the closest resolutions using
         | 
| 869 | 
            +
                            `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
         | 
| 870 | 
            +
                            the requested resolution. Useful for generating non-square images.
         | 
| 871 | 
            +
                        enhance_prompt (`bool`, *optional*, defaults to `False`):
         | 
| 872 | 
            +
                            If set to `True`, the prompt is enhanced using a LLM model.
         | 
| 873 | 
            +
                        text_encoder_max_tokens (`int`, *optional*, defaults to `256`):
         | 
| 874 | 
            +
                            The maximum number of tokens to use for the text encoder.
         | 
| 875 | 
            +
                        stochastic_sampling (`bool`, *optional*, defaults to `False`):
         | 
| 876 | 
            +
                            If set to `True`, the sampling is stochastic. If set to `False`, the sampling is deterministic.
         | 
| 877 | 
            +
                        media_items ('torch.Tensor', *optional*):
         | 
| 878 | 
            +
                            The input media item used for image-to-image / video-to-video.
         | 
| 879 | 
            +
                    Examples:
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                    Returns:
         | 
| 882 | 
            +
                        [`~pipelines.ImagePipelineOutput`] or `tuple`:
         | 
| 883 | 
            +
                            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
         | 
| 884 | 
            +
                            returned where the first element is a list with the generated images
         | 
| 885 | 
            +
                    """
        if "mask_feature" in kwargs:
            deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and does not affect the end result. It will be removed in a future version."
            deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)

        is_video = kwargs.get("is_video", False)
        self.check_inputs(
            prompt,
            height,
            width,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            prompt_attention_mask,
            negative_prompt_attention_mask,
        )

        # 2. Determine the batch size
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        self.video_scale_factor = self.video_scale_factor if is_video else 1
        vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", True)
        image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0)

        latent_height = height // self.vae_scale_factor
        latent_width = width // self.vae_scale_factor
        latent_num_frames = num_frames // self.video_scale_factor
        if isinstance(self.vae, CausalVideoAutoencoder) and is_video:
            latent_num_frames += 1
        latent_shape = (
            batch_size * num_images_per_prompt,
            self.transformer.config.in_channels,
            latent_num_frames,
            latent_height,
            latent_width,
        )

        # Prepare the list of denoising timesteps

        retrieve_timesteps_kwargs = {}
        if isinstance(self.scheduler, TimestepShifter):
            retrieve_timesteps_kwargs["samples_shape"] = latent_shape

        assert (
            skip_initial_inference_steps == 0
            or latents is not None
            or media_items is not None
        ), (
            f"skip_initial_inference_steps ({skip_initial_inference_steps}) is meant for image-to-image/video-to-video - "
            "media_items or latents should be provided."
        )

        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            timesteps,
            skip_initial_inference_steps=skip_initial_inference_steps,
            skip_final_inference_steps=skip_final_inference_steps,
            **retrieve_timesteps_kwargs,
        )

        if self.allowed_inference_steps is not None:
            for timestep in [round(x, 4) for x in timesteps.tolist()]:
                assert (
                    timestep in self.allowed_inference_steps
                ), f"Invalid inference timestep {timestep}. Allowed timesteps are {self.allowed_inference_steps}."

        if guidance_timesteps:
            guidance_mapping = []
            for timestep in timesteps:
                indices = [
                    i for i, val in enumerate(guidance_timesteps) if val <= timestep
                ]
                # assert len(indices) > 0, f"No guidance timestep found for {timestep}"
                guidance_mapping.append(
                    indices[0] if len(indices) > 0 else (len(guidance_timesteps) - 1)
                )

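        # Illustrative mapping (hypothetical values): with guidance_timesteps=[0.9, 0.5] and
        # guidance_scale=[3.0, 1.0], every denoising timestep t >= 0.9 is mapped to index 0
        # (scale 3.0) and every lower timestep to index 1 (scale 1.0). The same guidance_mapping
        # is reused below for stg_scale, rescaling_scale and skip_block_list.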
        # `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        if not isinstance(guidance_scale, List):
            guidance_scale = [guidance_scale] * len(timesteps)
        else:
            guidance_scale = [
                guidance_scale[guidance_mapping[i]] for i in range(len(timesteps))
            ]

        # For simplicity, we are using a constant num_conds for all timesteps, so we need to zero
        # out cases where the guidance scale should not be applied.
        guidance_scale = [x if x > 1.0 else 0.0 for x in guidance_scale]

        if not isinstance(stg_scale, List):
            stg_scale = [stg_scale] * len(timesteps)
        else:
            stg_scale = [stg_scale[guidance_mapping[i]] for i in range(len(timesteps))]

        if not isinstance(rescaling_scale, List):
            rescaling_scale = [rescaling_scale] * len(timesteps)
        else:
            rescaling_scale = [
                rescaling_scale[guidance_mapping[i]] for i in range(len(timesteps))
            ]

        do_classifier_free_guidance = any(x > 1.0 for x in guidance_scale)
        do_spatio_temporal_guidance = any(x > 0.0 for x in stg_scale)
        do_rescaling = any(x != 1.0 for x in rescaling_scale)

        num_conds = 1
        if do_classifier_free_guidance:
            num_conds += 1
        if do_spatio_temporal_guidance:
            num_conds += 1

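        # num_conds counts how many copies of the latents are batched per denoising step:
        # [unconditional, conditional, conditional-with-skipped-layers], keeping only the enabled
        # ones. The prompt-embedding concatenation below and the chunk()[:2] / chunk()[-2:]
        # splits in the denoising loop rely on this ordering.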
        # Normalize skip_block_list to always be None or a list of lists matching timesteps
        if skip_block_list is not None:
            # Convert single list to list of lists if needed
            if len(skip_block_list) == 0 or not isinstance(skip_block_list[0], list):
                skip_block_list = [skip_block_list] * len(timesteps)
            else:
                new_skip_block_list = []
                for i, timestep in enumerate(timesteps):
                    new_skip_block_list.append(skip_block_list[guidance_mapping[i]])
                skip_block_list = new_skip_block_list

        # Prepare skip layer masks
        skip_layer_masks: Optional[List[torch.Tensor]] = None
        if do_spatio_temporal_guidance:
            if skip_block_list is not None:
                skip_layer_masks = [
                    self.transformer.create_skip_layer_mask(
                        batch_size, num_conds, num_conds - 1, skip_blocks
                    )
                    for skip_blocks in skip_block_list
                ]

        if enhance_prompt:
            self.prompt_enhancer_image_caption_model = (
                self.prompt_enhancer_image_caption_model.to(self._execution_device)
            )
            self.prompt_enhancer_llm_model = self.prompt_enhancer_llm_model.to(
                self._execution_device
            )

            prompt = generate_cinematic_prompt(
                self.prompt_enhancer_image_caption_model,
                self.prompt_enhancer_image_caption_processor,
                self.prompt_enhancer_llm_model,
                self.prompt_enhancer_llm_tokenizer,
                prompt,
                conditioning_items,
                max_new_tokens=text_encoder_max_tokens,
            )

        # 3. Encode input prompt
        if self.text_encoder is not None:
            self.text_encoder = self.text_encoder.to(self._execution_device)

        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            text_encoder_max_tokens=text_encoder_max_tokens,
        )

        if offload_to_cpu and self.text_encoder is not None:
            self.text_encoder = self.text_encoder.cpu()

        self.transformer = self.transformer.to(self._execution_device)

        prompt_embeds_batch = prompt_embeds
        prompt_attention_mask_batch = prompt_attention_mask
        if do_classifier_free_guidance:
            prompt_embeds_batch = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0
            )
            prompt_attention_mask_batch = torch.cat(
                [negative_prompt_attention_mask, prompt_attention_mask], dim=0
            )
        if do_spatio_temporal_guidance:
            prompt_embeds_batch = torch.cat([prompt_embeds_batch, prompt_embeds], dim=0)
            prompt_attention_mask_batch = torch.cat(
                [
                    prompt_attention_mask_batch,
                    prompt_attention_mask,
                ],
                dim=0,
            )

        # 4. Prepare the initial latents using the provided media and conditioning items

        # Prepare the initial latents tensor, shape = (b, c, f, h, w)
        latents = self.prepare_latents(
            latents=latents,
            media_items=media_items,
            timestep=timesteps[0],
            latent_shape=latent_shape,
            dtype=prompt_embeds_batch.dtype,
            device=device,
            generator=generator,
            vae_per_channel_normalize=vae_per_channel_normalize,
        )

        # Update the latents with the conditioning items and patchify them into (b, n, c)
        latents, pixel_coords, conditioning_mask, num_cond_latents = (
            self.prepare_conditioning(
                conditioning_items=conditioning_items,
                init_latents=latents,
                num_frames=num_frames,
                height=height,
                width=width,
                vae_per_channel_normalize=vae_per_channel_normalize,
                generator=generator,
            )
        )
        init_latents = latents.clone()  # Used for image_cond_noise_update

        pixel_coords = torch.cat([pixel_coords] * num_conds)
        orig_conditioning_mask = conditioning_mask
        if conditioning_mask is not None and is_video:
            assert num_images_per_prompt == 1
            conditioning_mask = torch.cat([conditioning_mask] * num_conds)
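        # The first coordinate row of `fractional_coords` holds frame indices; dividing by
        # `frame_rate` converts it to time (in seconds) before it is passed to the transformer
        # as `indices_grid`.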
        fractional_coords = pixel_coords.to(torch.float32)
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if conditioning_mask is not None and image_cond_noise_scale > 0.0:
                    latents = self.add_noise_to_image_conditioning_latents(
                        t,
                        init_latents,
                        latents,
                        image_cond_noise_scale,
                        orig_conditioning_mask,
                        generator,
                    )

                latent_model_input = (
                    torch.cat([latents] * num_conds) if num_conds > 1 else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                current_timestep = t
                if not torch.is_tensor(current_timestep):
                    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                    # This would be a good case for the `match` statement (Python 3.10+)
                    is_mps = latent_model_input.device.type == "mps"
                    if isinstance(current_timestep, float):
                        dtype = torch.float32 if is_mps else torch.float64
                    else:
                        dtype = torch.int32 if is_mps else torch.int64
                    current_timestep = torch.tensor(
                        [current_timestep],
                        dtype=dtype,
                        device=latent_model_input.device,
                    )
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(
                        latent_model_input.device
                    )
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                current_timestep = current_timestep.expand(
                    latent_model_input.shape[0]
                ).unsqueeze(-1)

                if conditioning_mask is not None:
                    # Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask)
                    # and will start to be denoised when the current timestep is lower than their conditioning timestep.
                    current_timestep = torch.min(
                        current_timestep, 1.0 - conditioning_mask
                    )

                # Choose the appropriate context manager based on `mixed_precision`
                if mixed_precision:
                    context_manager = torch.autocast(device.type, dtype=torch.bfloat16)
                else:
                    context_manager = nullcontext()  # Dummy context manager

                # predict noise model_output
                with context_manager:
                    noise_pred = self.transformer(
                        latent_model_input.to(self.transformer.dtype),
                        indices_grid=fractional_coords,
                        encoder_hidden_states=prompt_embeds_batch.to(
                            self.transformer.dtype
                        ),
                        encoder_attention_mask=prompt_attention_mask_batch,
                        timestep=current_timestep,
                        skip_layer_mask=(
                            skip_layer_masks[i]
                            if skip_layer_masks is not None
                            else None
                        ),
                        skip_layer_strategy=skip_layer_strategy,
                        return_dict=False,
                    )[0]

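                # Combined guidance (when enabled): start from the unconditional prediction, add
                # guidance_scale[i] * (text - uncond) for classifier-free guidance, then add
                # stg_scale[i] * (text - text_perturb) for spatio-temporal (skip-layer) guidance,
                # optionally rescaling the result toward the std of the text prediction
                # (controlled by rescaling_scale[i]).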
                # perform guidance
                if do_spatio_temporal_guidance:
                    noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(
                        num_conds
                    )[-2:]
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_conds)[:2]

                    if cfg_star_rescale:
                        # Rescales the unconditional noise prediction using the projection of the conditional prediction onto it:
                        # α = (⟨ε_text, ε_uncond⟩ / ||ε_uncond||²), then ε_uncond ← α * ε_uncond
                        # where ε_text is the conditional noise prediction and ε_uncond is the unconditional one.
                        positive_flat = noise_pred_text.view(batch_size, -1)
                        negative_flat = noise_pred_uncond.view(batch_size, -1)
                        dot_product = torch.sum(
                            positive_flat * negative_flat, dim=1, keepdim=True
                        )
                        squared_norm = (
                            torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
                        )
                        alpha = dot_product / squared_norm
                        noise_pred_uncond = alpha * noise_pred_uncond

                    noise_pred = noise_pred_uncond + guidance_scale[i] * (
                        noise_pred_text - noise_pred_uncond
                    )
                elif do_spatio_temporal_guidance:
                    noise_pred = noise_pred_text
                if do_spatio_temporal_guidance:
                    noise_pred = noise_pred + stg_scale[i] * (
                        noise_pred_text - noise_pred_text_perturb
                    )
                    if do_rescaling and stg_scale[i] > 0.0:
                        noise_pred_text_std = noise_pred_text.view(batch_size, -1).std(
                            dim=1, keepdim=True
                        )
                        noise_pred_std = noise_pred.view(batch_size, -1).std(
                            dim=1, keepdim=True
                        )

                        factor = noise_pred_text_std / noise_pred_std
                        factor = rescaling_scale[i] * factor + (1 - rescaling_scale[i])

                        noise_pred = noise_pred * factor.view(batch_size, 1, 1)

                current_timestep = current_timestep[:1]
                # learned sigma
                if (
                    self.transformer.config.out_channels // 2
                    == self.transformer.config.in_channels
                ):
                    noise_pred = noise_pred.chunk(2, dim=1)[0]

                # compute previous image: x_t -> x_t-1
                latents = self.denoising_step(
                    latents,
                    noise_pred,
                    current_timestep,
                    orig_conditioning_mask,
                    t,
                    extra_step_kwargs,
                    stochastic_sampling=stochastic_sampling,
                )

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()

                if callback_on_step_end is not None:
                    callback_on_step_end(self, i, t, {})

        if offload_to_cpu:
            self.transformer = self.transformer.cpu()
            if self._execution_device == "cuda":
                torch.cuda.empty_cache()

        # Remove the added conditioning latents
        latents = latents[:, num_cond_latents:]

        latents = self.patchifier.unpatchify(
            latents=latents,
            output_height=latent_height,
            output_width=latent_width,
            out_channels=self.transformer.in_channels
            // math.prod(self.patchifier.patch_size),
        )
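        # Decode the latents to pixels unless latent output was requested. When the VAE decoder is
        # timestep-conditioned, a small amount of fresh noise (decode_noise_scale) is blended into
        # the latents and the matching decode_timestep is passed to the decoder.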
        if output_type != "latent":
            if self.vae.decoder.timestep_conditioning:
                noise = torch.randn_like(latents)
                if not isinstance(decode_timestep, list):
                    decode_timestep = [decode_timestep] * latents.shape[0]
                if decode_noise_scale is None:
                    decode_noise_scale = decode_timestep
                elif not isinstance(decode_noise_scale, list):
                    decode_noise_scale = [decode_noise_scale] * latents.shape[0]

                decode_timestep = torch.tensor(decode_timestep).to(latents.device)
                decode_noise_scale = torch.tensor(decode_noise_scale).to(
                    latents.device
                )[:, None, None, None, None]
                latents = (
                    latents * (1 - decode_noise_scale) + noise * decode_noise_scale
                )
            else:
                decode_timestep = None
            image = vae_decode(
                latents,
                self.vae,
                is_video,
                vae_per_channel_normalize=kwargs["vae_per_channel_normalize"],
                timestep=decode_timestep,
            )

            image = self.image_processor.postprocess(image, output_type=output_type)

        else:
            image = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)

    def denoising_step(
        self,
        latents: torch.Tensor,
        noise_pred: torch.Tensor,
        current_timestep: torch.Tensor,
        conditioning_mask: torch.Tensor,
        t: float,
        extra_step_kwargs,
        t_eps=1e-6,
        stochastic_sampling=False,
    ):
        """
        Perform the denoising step for the required tokens, based on the current timestep and
        the conditioning mask:
        Conditioning latents have an initial timestep and noising level of (1.0 - conditioning_mask)
        and will start to be denoised when the current timestep is equal to or lower than their
        conditioning timestep.
        (Hard-conditioning latents with conditioning_mask = 1.0 are never denoised.)
        """
        # Denoise the latents using the scheduler
        denoised_latents = self.scheduler.step(
            noise_pred,
            t if current_timestep is None else current_timestep,
            latents,
            **extra_step_kwargs,
            return_dict=False,
            stochastic_sampling=stochastic_sampling,
        )[0]

        if conditioning_mask is None:
            return denoised_latents

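        # A token is denoised only while the current timestep is at or below its conditioning
        # timestep (1.0 - conditioning_mask); e.g. a token with conditioning_mask = 0.3 starts
        # updating once t <= 0.7, while hard-conditioned tokens (conditioning_mask = 1.0) keep
        # their original values.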
                    tokens_to_denoise_mask = (t - t_eps < (1.0 - conditioning_mask)).unsqueeze(-1)
         | 
| 1372 | 
            +
                    return torch.where(tokens_to_denoise_mask, denoised_latents, latents)
         | 
| 1373 | 
            +
             | 
| 1374 | 
            +
    def prepare_conditioning(
        self,
        conditioning_items: Optional[List[ConditioningItem]],
        init_latents: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        vae_per_channel_normalize: bool = False,
        generator=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        """
        Prepare conditioning tokens based on the provided conditioning items.

        This method encodes provided conditioning items (video frames or single frames) into latents
        and integrates them with the initial latent tensor. It also calculates corresponding pixel
        coordinates, a mask indicating the influence of conditioning latents, and the total number of
        conditioning latents.

        Args:
            conditioning_items (Optional[List[ConditioningItem]]): A list of ConditioningItem objects.
            init_latents (torch.Tensor): The initial latent tensor of shape (b, c, f_l, h_l, w_l), where
                `f_l` is the number of latent frames, and `h_l` and `w_l` are latent spatial dimensions.
            num_frames, height, width: The dimensions of the generated video.
            vae_per_channel_normalize (bool, optional): Whether to normalize channels during VAE encoding.
                Defaults to `False`.
            generator: The random generator

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
                - `init_latents` (torch.Tensor): The updated latent tensor including conditioning latents,
                  patchified into (b, n, c) shape.
                - `init_pixel_coords` (torch.Tensor): The pixel coordinates corresponding to the updated
                  latent tensor.
                - `conditioning_mask` (torch.Tensor): A mask indicating the conditioning-strength of each
                  latent token.
                - `num_cond_latents` (int): The total number of latent tokens added from conditioning items.

        Raises:
            AssertionError: If input shapes, dimensions, or conditions for applying conditioning are invalid.
        """
        assert isinstance(self.vae, CausalVideoAutoencoder)

        if conditioning_items:
            batch_size, _, num_latent_frames = init_latents.shape[:3]

            init_conditioning_mask = torch.zeros(
                init_latents[:, 0, :, :, :].shape,
                dtype=torch.float32,
                device=init_latents.device,
            )

            extra_conditioning_latents = []
            extra_conditioning_pixel_coords = []
            extra_conditioning_mask = []
            extra_conditioning_num_latents = 0  # Number of extra conditioning latents added (should be removed before decoding)

            # Process each conditioning item
            for conditioning_item in conditioning_items:
                conditioning_item = self._resize_conditioning_item(
                    conditioning_item, height, width
                )
                media_item = conditioning_item.media_item
                media_frame_number = conditioning_item.media_frame_number
                strength = conditioning_item.conditioning_strength
                assert media_item.ndim == 5  # (b, c, f, h, w)
                b, c, n_frames, h, w = media_item.shape
                assert (
                    height == h and width == w
                ) or media_frame_number == 0, f"Dimensions do not match: {height}x{width} != {h}x{w} - allowed only when media_frame_number == 0"
                assert n_frames % 8 == 1
                assert (
                    media_frame_number >= 0
                    and media_frame_number + n_frames <= num_frames
                )

                # Encode the provided conditioning media item
                media_item_latents = vae_encode(
                    media_item.to(dtype=self.vae.dtype, device=self.vae.device),
                    self.vae,
                    vae_per_channel_normalize=vae_per_channel_normalize,
                ).to(dtype=init_latents.dtype)

                # Handle the different conditioning cases
                if media_frame_number == 0:
                    # Get the target spatial position of the latent conditioning item
                    media_item_latents, l_x, l_y = self._get_latent_spatial_position(
                        media_item_latents,
                        conditioning_item,
                        height,
                        width,
                        strip_latent_border=True,
                    )
                    b, c_l, f_l, h_l, w_l = media_item_latents.shape

                    # First frame or sequence - just update the initial noise latents and the mask
                    init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l] = (
                        torch.lerp(
                            init_latents[:, :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l],
                            media_item_latents,
                            strength,
                        )
                    )
                    init_conditioning_mask[
                        :, :f_l, l_y : l_y + h_l, l_x : l_x + w_l
                    ] = strength
                else:
                    # Non-first frame or sequence
                    if n_frames > 1:
                        # Handle non-first sequence.
                        # Encoded latents are either fully consumed, or the prefix is handled separately below.
                        (
                            init_latents,
                            init_conditioning_mask,
                            media_item_latents,
                        ) = self._handle_non_first_conditioning_sequence(
                            init_latents,
                            init_conditioning_mask,
                            media_item_latents,
                            media_frame_number,
                            strength,
                        )

                    # Single frame or sequence-prefix latents
                    if media_item_latents is not None:
                        noise = randn_tensor(
                            media_item_latents.shape,
                            generator=generator,
                            device=media_item_latents.device,
                            dtype=media_item_latents.dtype,
                        )

                        media_item_latents = torch.lerp(
                            noise, media_item_latents, strength
                        )

                        # Patchify the extra conditioning latents and calculate their pixel coordinates
                        media_item_latents, latent_coords = self.patchifier.patchify(
                            latents=media_item_latents
                        )
                        pixel_coords = latent_to_pixel_coords(
                            latent_coords,
                            self.vae,
                            causal_fix=self.transformer.config.causal_temporal_positioning,
                        )

                        # Update the frame numbers to match the target frame number
                        pixel_coords[:, 0] += media_frame_number
                        extra_conditioning_num_latents += media_item_latents.shape[1]

                        conditioning_mask = torch.full(
                            media_item_latents.shape[:2],
                            strength,
                            dtype=torch.float32,
                            device=init_latents.device,
                        )

                        extra_conditioning_latents.append(media_item_latents)
                        extra_conditioning_pixel_coords.append(pixel_coords)
                        extra_conditioning_mask.append(conditioning_mask)

        # Patchify the updated latents and calculate their pixel coordinates
        init_latents, init_latent_coords = self.patchifier.patchify(
            latents=init_latents
        )
        init_pixel_coords = latent_to_pixel_coords(
            init_latent_coords,
            self.vae,
            causal_fix=self.transformer.config.causal_temporal_positioning,
        )

        if not conditioning_items:
            return init_latents, init_pixel_coords, None, 0

        init_conditioning_mask, _ = self.patchifier.patchify(
            latents=init_conditioning_mask.unsqueeze(1)
        )
        init_conditioning_mask = init_conditioning_mask.squeeze(-1)

        if extra_conditioning_latents:
            # Stack the extra conditioning latents, pixel coordinates and mask
            init_latents = torch.cat([*extra_conditioning_latents, init_latents], dim=1)
            init_pixel_coords = torch.cat(
                [*extra_conditioning_pixel_coords, init_pixel_coords], dim=2
            )
            init_conditioning_mask = torch.cat(
                [*extra_conditioning_mask, init_conditioning_mask], dim=1
            )

            if self.transformer.use_tpu_flash_attention:
                # When flash attention is used, keep the original number of tokens by removing
                #   tokens from the end.
                init_latents = init_latents[:, :-extra_conditioning_num_latents]
                init_pixel_coords = init_pixel_coords[
                    :, :, :-extra_conditioning_num_latents
                ]
                init_conditioning_mask = init_conditioning_mask[
                    :, :-extra_conditioning_num_latents
                ]

        return (
            init_latents,
            init_pixel_coords,
            init_conditioning_mask,
            extra_conditioning_num_latents,
        )

         | 
| 1581 | 
            +
                def _resize_conditioning_item(
         | 
| 1582 | 
            +
                    conditioning_item: ConditioningItem,
         | 
| 1583 | 
            +
                    height: int,
         | 
| 1584 | 
            +
                    width: int,
         | 
| 1585 | 
            +
                ):
         | 
| 1586 | 
            +
                    if conditioning_item.media_x or conditioning_item.media_y:
         | 
| 1587 | 
            +
                        raise ValueError(
         | 
| 1588 | 
            +
                            "Provide media_item in the target size for spatial conditioning."
         | 
| 1589 | 
            +
                        )
         | 
| 1590 | 
            +
                    new_conditioning_item = copy.copy(conditioning_item)
         | 
| 1591 | 
            +
                    new_conditioning_item.media_item = LTXVideoPipeline.resize_tensor(
         | 
| 1592 | 
            +
                        conditioning_item.media_item, height, width
         | 
| 1593 | 
            +
                    )
         | 
| 1594 | 
            +
                    return new_conditioning_item
         | 
| 1595 | 
            +
             | 
| 1596 | 
            +
                def _get_latent_spatial_position(
         | 
| 1597 | 
            +
                    self,
         | 
| 1598 | 
            +
                    latents: torch.Tensor,
         | 
| 1599 | 
            +
                    conditioning_item: ConditioningItem,
         | 
| 1600 | 
            +
                    height: int,
         | 
| 1601 | 
            +
                    width: int,
         | 
| 1602 | 
            +
                    strip_latent_border,
         | 
| 1603 | 
            +
                ):
         | 
| 1604 | 
            +
                    """
         | 
| 1605 | 
            +
                    Get the spatial position of the conditioning item in the latent space.
         | 
| 1606 | 
            +
                    If requested, strip the conditioning latent borders that do not align with target borders.
         | 
| 1607 | 
            +
                    (border latents look different then other latents and might confuse the model)
         | 
| 1608 | 
            +
                    """
         | 
| 1609 | 
            +
                    scale = self.vae_scale_factor
         | 
| 1610 | 
            +
                    h, w = conditioning_item.media_item.shape[-2:]
         | 
| 1611 | 
            +
                    assert (
         | 
| 1612 | 
            +
                        h <= height and w <= width
         | 
| 1613 | 
            +
                    ), f"Conditioning item size {h}x{w} is larger than target size {height}x{width}"
         | 
| 1614 | 
            +
                    assert h % scale == 0 and w % scale == 0
         | 
| 1615 | 
            +
             | 
| 1616 | 
            +
                    # Compute the start and end spatial positions of the media item
         | 
| 1617 | 
            +
                    x_start, y_start = conditioning_item.media_x, conditioning_item.media_y
         | 
| 1618 | 
            +
                    x_start = (width - w) // 2 if x_start is None else x_start
         | 
| 1619 | 
            +
                    y_start = (height - h) // 2 if y_start is None else y_start
         | 
| 1620 | 
            +
                    x_end, y_end = x_start + w, y_start + h
         | 
| 1621 | 
            +
                    assert (
         | 
| 1622 | 
            +
                        x_end <= width and y_end <= height
         | 
| 1623 | 
            +
                    ), f"Conditioning item {x_start}:{x_end}x{y_start}:{y_end} is out of bounds for target size {width}x{height}"
         | 
| 1624 | 
            +
             | 
| 1625 | 
            +
                    if strip_latent_border:
         | 
| 1626 | 
            +
                        # Strip one latent from left/right and/or top/bottom, update x, y accordingly
         | 
| 1627 | 
            +
                        if x_start > 0:
         | 
| 1628 | 
            +
                            x_start += scale
         | 
| 1629 | 
            +
                            latents = latents[:, :, :, :, 1:]
         | 
| 1630 | 
            +
             | 
| 1631 | 
            +
                        if y_start > 0:
         | 
| 1632 | 
            +
                            y_start += scale
         | 
| 1633 | 
            +
                            latents = latents[:, :, :, 1:, :]
         | 
| 1634 | 
            +
             | 
| 1635 | 
            +
                        if x_end < width:
         | 
| 1636 | 
            +
                            latents = latents[:, :, :, :, :-1]
         | 
| 1637 | 
            +
             | 
| 1638 | 
            +
                        if y_end < height:
         | 
| 1639 | 
            +
                            latents = latents[:, :, :, :-1, :]
         | 
| 1640 | 
            +
             | 
| 1641 | 
            +
                    return latents, x_start // scale, y_start // scale
         | 
| 1642 | 
            +
             | 
| 1643 | 
            +
                @staticmethod
         | 
| 1644 | 
            +
                def _handle_non_first_conditioning_sequence(
         | 
| 1645 | 
            +
                    init_latents: torch.Tensor,
         | 
| 1646 | 
            +
                    init_conditioning_mask: torch.Tensor,
         | 
| 1647 | 
            +
                    latents: torch.Tensor,
         | 
| 1648 | 
            +
                    media_frame_number: int,
         | 
| 1649 | 
            +
                    strength: float,
         | 
| 1650 | 
            +
                    num_prefix_latent_frames: int = 2,
         | 
| 1651 | 
            +
                    prefix_latents_mode: str = "concat",
         | 
| 1652 | 
            +
                    prefix_soft_conditioning_strength: float = 0.15,
         | 
| 1653 | 
            +
                ):
         | 
| 1654 | 
            +
                    """
         | 
| 1655 | 
            +
                    Special handling for a conditioning sequence that does not start on the first frame.
         | 
| 1656 | 
            +
                    The special handling is required to allow a short encoded video to be used as middle
         | 
| 1657 | 
            +
                    (or last) sequence in a longer video.
         | 
| 1658 | 
            +
                    Args:
         | 
| 1659 | 
            +
                        init_latents (torch.Tensor): The initial noise latents to be updated.
         | 
| 1660 | 
            +
                        init_conditioning_mask (torch.Tensor): The initial conditioning mask to be updated.
         | 
| 1661 | 
            +
                        latents (torch.Tensor): The encoded conditioning item.
         | 
| 1662 | 
            +
                        media_frame_number (int): The target frame number of the first frame in the conditioning sequence.
         | 
| 1663 | 
            +
                        strength (float): The conditioning strength for the conditioning latents.
         | 
| 1664 | 
            +
                        num_prefix_latent_frames (int, optional): The length of the sequence prefix, to be handled
         | 
| 1665 | 
            +
                            separately. Defaults to 2.
         | 
| 1666 | 
            +
                        prefix_latents_mode (str, optional): Special treatment for prefix (boundary) latents.
         | 
| 1667 | 
            +
                            - "drop": Drop the prefix latents.
         | 
| 1668 | 
            +
                            - "soft": Use the prefix latents, but with soft-conditioning
         | 
| 1669 | 
            +
                            - "concat": Add the prefix latents as extra tokens (like single frames)
         | 
| 1670 | 
            +
                        prefix_soft_conditioning_strength (float, optional): The strength of the soft-conditioning for
         | 
| 1671 | 
            +
                            the prefix latents, relevant if `prefix_latents_mode` is "soft". Defaults to 0.1.
         | 
| 1672 | 
            +
             | 
| 1673 | 
            +
                    """
         | 
| 1674 | 
            +
                    f_l = latents.shape[2]
         | 
| 1675 | 
            +
                    f_l_p = num_prefix_latent_frames
         | 
| 1676 | 
            +
                    assert f_l >= f_l_p
         | 
| 1677 | 
            +
                    assert media_frame_number % 8 == 0
         | 
| 1678 | 
            +
                    if f_l > f_l_p:
         | 
| 1679 | 
            +
                        # Insert the conditioning latents **excluding the prefix** into the sequence
         | 
| 1680 | 
            +
                        f_l_start = media_frame_number // 8 + f_l_p
         | 
| 1681 | 
            +
                        f_l_end = f_l_start + f_l - f_l_p
         | 
| 1682 | 
            +
                        init_latents[:, :, f_l_start:f_l_end] = torch.lerp(
         | 
| 1683 | 
            +
                            init_latents[:, :, f_l_start:f_l_end],
         | 
| 1684 | 
            +
                            latents[:, :, f_l_p:],
         | 
| 1685 | 
            +
                            strength,
         | 
| 1686 | 
            +
                        )
         | 
| 1687 | 
            +
                        # Mark these latent frames as conditioning latents
         | 
| 1688 | 
            +
                        init_conditioning_mask[:, f_l_start:f_l_end] = strength
         | 
| 1689 | 
            +
             | 
| 1690 | 
            +
                    # Handle the prefix-latents
         | 
| 1691 | 
            +
                    if prefix_latents_mode == "soft":
         | 
| 1692 | 
            +
                        if f_l_p > 1:
         | 
| 1693 | 
            +
                            # Drop the first (single-frame) latent and soft-condition the remaining prefix
         | 
| 1694 | 
            +
                            f_l_start = media_frame_number // 8 + 1
         | 
| 1695 | 
            +
                            f_l_end = f_l_start + f_l_p - 1
         | 
| 1696 | 
            +
                            strength = min(prefix_soft_conditioning_strength, strength)
         | 
| 1697 | 
            +
                            init_latents[:, :, f_l_start:f_l_end] = torch.lerp(
         | 
| 1698 | 
            +
                                init_latents[:, :, f_l_start:f_l_end],
         | 
| 1699 | 
            +
                                latents[:, :, 1:f_l_p],
         | 
| 1700 | 
            +
                                strength,
         | 
| 1701 | 
            +
                            )
         | 
| 1702 | 
            +
                            # Mark these latent frames as conditioning latents
         | 
| 1703 | 
            +
                            init_conditioning_mask[:, f_l_start:f_l_end] = strength
         | 
| 1704 | 
            +
                        latents = None  # No more latents to handle
         | 
| 1705 | 
            +
                    elif prefix_latents_mode == "drop":
         | 
| 1706 | 
            +
                        # Drop the prefix latents
         | 
| 1707 | 
            +
                        latents = None
         | 
| 1708 | 
            +
                    elif prefix_latents_mode == "concat":
         | 
| 1709 | 
            +
                        # Pass-on the prefix latents to be handled as extra conditioning frames
         | 
| 1710 | 
            +
                        latents = latents[:, :, :f_l_p]
         | 
| 1711 | 
            +
                    else:
         | 
| 1712 | 
            +
                        raise ValueError(f"Invalid prefix_latents_mode: {prefix_latents_mode}")
         | 
| 1713 | 
            +
                    return (
         | 
| 1714 | 
            +
                        init_latents,
         | 
| 1715 | 
            +
                        init_conditioning_mask,
         | 
| 1716 | 
            +
                        latents,
         | 
| 1717 | 
            +
                    )
         | 
| 1718 | 
            +
             | 
| 1719 | 
            +
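    # Illustrative note (added commentary, not original code): for a hypothetical conditioning
    # clip encoded into f_l = 5 latent frames and placed at media_frame_number = 32, the body
    # above writes latent frames 6 through 8 (the three non-prefix latents, since 32 // 8 + 2 = 6),
    # marks them in the mask, and, in the default "concat" mode, returns the two prefix latents
    # so the caller can append them as extra tokens.
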
    def trim_conditioning_sequence(
        self, start_frame: int, sequence_num_frames: int, target_num_frames: int
    ):
        """
        Trim a conditioning sequence to the allowed number of frames.

        Args:
            start_frame (int): The target frame number of the first frame in the sequence.
            sequence_num_frames (int): The number of frames in the sequence.
            target_num_frames (int): The target number of frames in the generated video.

        Returns:
            int: updated sequence length
        """
        scale_factor = self.video_scale_factor
        num_frames = min(sequence_num_frames, target_num_frames - start_frame)
        # Trim down to a multiple of temporal_scale_factor frames plus 1
        num_frames = (num_frames - 1) // scale_factor * scale_factor + 1
        return num_frames


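# Illustrative note (added commentary, not original code): with a hypothetical temporal
# scale factor of 8, a 60-frame conditioning clip starting at frame 0 of a 121-frame target
# is trimmed to min(60, 121) = 60 and then to (60 - 1) // 8 * 8 + 1 = 57 frames, the nearest
# valid "multiple of 8 plus 1" length.
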
def adain_filter_latent(
    latents: torch.Tensor, reference_latents: torch.Tensor, factor=1.0
):
    """
    Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on
    statistics from a reference latent tensor.

    Args:
        latents (torch.Tensor): Input latents to normalize.
        reference_latents (torch.Tensor): The reference latents providing style statistics.
        factor (float): Blending factor between original and transformed latent.
                        Range: -10.0 to 10.0, Default: 1.0

    Returns:
        torch.Tensor: The transformed latent tensor
    """
    result = latents.clone()

    for i in range(latents.size(0)):
        for c in range(latents.size(1)):
            r_sd, r_mean = torch.std_mean(
                reference_latents[i, c], dim=None
            )  # index by original dim order
            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)

            result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean

    result = torch.lerp(latents, result, factor)
    return result


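# Illustrative note (added commentary, not original code): per batch item and channel, the
# loop above re-standardizes the values and rescales them to the reference statistics. For a
# hypothetical channel with mean 0.0 / std 2.0 whose reference channel has mean 1.0 / std 0.5,
# a value of 4.0 maps to ((4.0 - 0.0) / 2.0) * 0.5 + 1.0 = 2.0 before the final lerp with `factor`.
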
class LTXMultiScalePipeline:
    def _upsample_latents(
        self, latest_upsampler: LatentUpsampler, latents: torch.Tensor
    ):
        assert latents.device == latest_upsampler.device

        latents = un_normalize_latents(
            latents, self.vae, vae_per_channel_normalize=True
        )
        upsampled_latents = latest_upsampler(latents)
        upsampled_latents = normalize_latents(
            upsampled_latents, self.vae, vae_per_channel_normalize=True
        )
        return upsampled_latents

    def __init__(
        self, video_pipeline: LTXVideoPipeline, latent_upsampler: LatentUpsampler
    ):
        self.video_pipeline = video_pipeline
        self.vae = video_pipeline.vae
        self.latent_upsampler = latent_upsampler

    def __call__(
        self,
        downscale_factor: float,
        first_pass: dict,
        second_pass: dict,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        original_kwargs = kwargs.copy()
        original_output_type = kwargs["output_type"]
        original_width = kwargs["width"]
        original_height = kwargs["height"]

        x_width = int(kwargs["width"] * downscale_factor)
        downscaled_width = x_width - (x_width % self.video_pipeline.vae_scale_factor)
        x_height = int(kwargs["height"] * downscale_factor)
        downscaled_height = x_height - (x_height % self.video_pipeline.vae_scale_factor)

        kwargs["output_type"] = "latent"
        kwargs["width"] = downscaled_width
        kwargs["height"] = downscaled_height
        kwargs.update(**first_pass)
        result = self.video_pipeline(*args, **kwargs)
        latents = result.images

        upsampled_latents = self._upsample_latents(self.latent_upsampler, latents)
        upsampled_latents = adain_filter_latent(
            latents=upsampled_latents, reference_latents=latents
        )

        kwargs = original_kwargs

        kwargs["latents"] = upsampled_latents
        kwargs["output_type"] = original_output_type
        kwargs["width"] = downscaled_width * 2
        kwargs["height"] = downscaled_height * 2
        kwargs.update(**second_pass)

        result = self.video_pipeline(*args, **kwargs)
        if original_output_type != "latent":
            num_frames = result.images.shape[2]
            videos = rearrange(result.images, "b c f h w -> (b f) c h w")

            videos = F.interpolate(
                videos,
                size=(original_height, original_width),
                mode="bilinear",
                align_corners=False,
            )
            videos = rearrange(videos, "(b f) c h w -> b c f h w", f=num_frames)
            result.images = videos

        return result
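A minimal usage sketch of the two-pass flow above, added for illustration only: it assumes a `video_pipeline` (LTXVideoPipeline) and a `latent_upsampler` (LatentUpsampler) have already been constructed elsewhere (e.g. by inference.py), and the per-pass overrides, prompt, and dimensions shown here are placeholder assumptions rather than values taken from this upload.

# Hypothetical example - construction of video_pipeline / latent_upsampler is not shown here.
multi_scale = LTXMultiScalePipeline(video_pipeline, latent_upsampler)
result = multi_scale(
    downscale_factor=2 / 3,                   # first pass renders at roughly 2/3 of the target size
    first_pass={"num_inference_steps": 30},   # assumed per-pass overrides
    second_pass={"num_inference_steps": 15},
    prompt="a sailboat gliding across a lake at sunset",
    height=704,
    width=1216,
    num_frames=121,
    output_type="pt",
)
video_tensor = result.images  # (b, c, f, h, w), interpolated back to the requested height/width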
    	
ltx_video/schedulers/__init__.py
ADDED
File without changes
    	
ltx_video/schedulers/rf.py
ADDED
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union
import json
import os
from pathlib import Path

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from torch import Tensor
from safetensors import safe_open


from ltx_video.utils.torch_utils import append_dims

from ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
)


def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=None):
    if num_steps == 1:
        return torch.tensor([1.0])
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [
        i * threshold_noise / linear_steps for i in range(linear_steps)
    ]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (
        quadratic_steps**2
    )
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const
        for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return torch.tensor(sigma_schedule[:-1])


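# Illustrative note (added commentary, not original code): linear_quadratic_schedule(num_steps=10)
# builds 5 linear steps that ramp toward threshold_noise = 0.025 and 5 quadratic steps that continue
# to 1.0, then returns the reversed values (1 - x) as a descending sigma schedule that starts at 1.0,
# with the terminal 0.0 dropped.
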
def simple_diffusion_resolution_dependent_timestep_shift(
    samples_shape: torch.Size,
    timesteps: Tensor,
    n: int = 32 * 32,
) -> Tensor:
    if len(samples_shape) == 3:
        _, m, _ = samples_shape
    elif len(samples_shape) in [4, 5]:
        m = math.prod(samples_shape[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )
    snr = (timesteps / (1 - timesteps)) ** 2
    shift_snr = torch.log(snr) + 2 * math.log(m / n)
    shifted_timesteps = torch.sigmoid(0.5 * shift_snr)

    return shifted_timesteps


def time_shift(mu: float, sigma: float, t: Tensor):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_normal_shift(
    n_tokens: int,
    min_tokens: int = 1024,
    max_tokens: int = 4096,
    min_shift: float = 0.95,
    max_shift: float = 2.05,
) -> float:
    m = (max_shift - min_shift) / (max_tokens - min_tokens)
    b = min_shift - m * min_tokens
    return m * n_tokens + b


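# Illustrative note (added commentary, not original code): get_normal_shift interpolates the
# shift linearly in the token count. For example, n_tokens = 2560 (midway between 1024 and 4096)
# gives a shift of (0.95 + 2.05) / 2 = 1.5, which time_shift(1.5, 1, t) then applies to the
# sampled timesteps.
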
def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1):
    """
    Stretch a function (given as sampled shifts) so that its final value matches the given terminal value
    using the provided formula.

    Parameters:
    - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor).
    - terminal (float): The desired terminal value (value at the last sample).

    Returns:
    - Tensor: The stretched shifts such that the final value equals `terminal`.
    """
    if shifts.numel() == 0:
        raise ValueError("The 'shifts' tensor must not be empty.")

    # Ensure terminal value is valid
    if terminal <= 0 or terminal >= 1:
        raise ValueError("The terminal value must be between 0 and 1 (exclusive).")

    # Transform the shifts using the given formula
    one_minus_z = 1 - shifts
    scale_factor = one_minus_z[-1] / (1 - terminal)
    stretched_shifts = 1 - (one_minus_z / scale_factor)

    return stretched_shifts


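# Illustrative note (added commentary, not original code): if the last sampled shift is 0.05 and
# terminal = 0.1, the scale factor is (1 - 0.05) / (1 - 0.1) = 1.0556, so every value is pulled
# toward 1 just enough that the final entry becomes exactly 0.1.
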
def sd3_resolution_dependent_timestep_shift(
    samples_shape: torch.Size,
    timesteps: Tensor,
    target_shift_terminal: Optional[float] = None,
) -> Tensor:
    """
    Shifts the timestep schedule as a function of the generated resolution.

    In the SD3 paper, the authors empirically determine how to shift the timesteps based on the
    resolution of the target images.
    For more details: https://arxiv.org/pdf/2403.03206

    In Flux they later propose a more dynamic resolution dependent timestep shift, see:
    https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66

    Args:
        samples_shape (torch.Size): The samples batch shape (batch_size, channels, height, width) or
            (batch_size, channels, frame, height, width).
        timesteps (Tensor): A batch of timesteps with shape (batch_size,).
        target_shift_terminal (float): The target terminal value for the shifted timesteps.

    Returns:
        Tensor: The shifted timesteps.
    """
    if len(samples_shape) == 3:
        _, m, _ = samples_shape
    elif len(samples_shape) in [4, 5]:
        m = math.prod(samples_shape[2:])
    else:
        raise ValueError(
            "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)"
        )

    shift = get_normal_shift(m)
    time_shifts = time_shift(shift, 1, timesteps)
    if target_shift_terminal is not None:  # Stretch the shifts to the target terminal
        time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal)
    return time_shifts


class TimestepShifter(ABC):
    @abstractmethod
    def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
        pass


@dataclass
class RectifiedFlowSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter):
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps=1000,
        shifting: Optional[str] = None,
        base_resolution: int = 32**2,
        target_shift_terminal: Optional[float] = None,
        sampler: Optional[str] = "Uniform",
        shift: Optional[float] = None,
    ):
        super().__init__()
        self.init_noise_sigma = 1.0
        self.num_inference_steps = None
        self.sampler = sampler
        self.shifting = shifting
        self.base_resolution = base_resolution
        self.target_shift_terminal = target_shift_terminal
        self.timesteps = self.sigmas = self.get_initial_timesteps(
            num_train_timesteps, shift=shift
        )
        self.shift = shift

    def get_initial_timesteps(
        self, num_timesteps: int, shift: Optional[float] = None
    ) -> Tensor:
        if self.sampler == "Uniform":
            return torch.linspace(1, 1 / num_timesteps, num_timesteps)
        elif self.sampler == "LinearQuadratic":
            return linear_quadratic_schedule(num_timesteps)
        elif self.sampler == "Constant":
            assert (
                shift is not None
            ), "Shift must be provided for constant time shift sampler."
            return time_shift(
                shift, 1, torch.linspace(1, 1 / num_timesteps, num_timesteps)
            )

    def shift_timesteps(self, samples_shape: torch.Size, timesteps: Tensor) -> Tensor:
        if self.shifting == "SD3":
         | 
| 218 | 
            +
                        return sd3_resolution_dependent_timestep_shift(
         | 
| 219 | 
            +
                            samples_shape, timesteps, self.target_shift_terminal
         | 
| 220 | 
            +
                        )
         | 
| 221 | 
            +
                    elif self.shifting == "SimpleDiffusion":
         | 
| 222 | 
            +
                        return simple_diffusion_resolution_dependent_timestep_shift(
         | 
| 223 | 
            +
                            samples_shape, timesteps, self.base_resolution
         | 
| 224 | 
            +
                        )
         | 
| 225 | 
            +
                    return timesteps
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                def set_timesteps(
         | 
| 228 | 
            +
                    self,
         | 
| 229 | 
            +
                    num_inference_steps: Optional[int] = None,
         | 
| 230 | 
            +
                    samples_shape: Optional[torch.Size] = None,
         | 
| 231 | 
            +
                    timesteps: Optional[Tensor] = None,
         | 
| 232 | 
            +
                    device: Union[str, torch.device] = None,
         | 
| 233 | 
            +
                ):
         | 
| 234 | 
            +
                    """
         | 
| 235 | 
            +
                    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
         | 
| 236 | 
            +
                    If `timesteps` are provided, they will be used instead of the scheduled timesteps.
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                    Args:
         | 
| 239 | 
            +
                        num_inference_steps (`int` *optional*): The number of diffusion steps used when generating samples.
         | 
| 240 | 
            +
                        samples_shape (`torch.Size` *optional*): The samples batch shape, used for shifting.
         | 
| 241 | 
            +
                        timesteps ('torch.Tensor' *optional*): Specific timesteps to use instead of scheduled timesteps.
         | 
| 242 | 
            +
                        device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved.
         | 
| 243 | 
            +
                    """
         | 
| 244 | 
            +
                    if timesteps is not None and num_inference_steps is not None:
         | 
| 245 | 
            +
                        raise ValueError(
         | 
| 246 | 
            +
                            "You cannot provide both `timesteps` and `num_inference_steps`."
         | 
| 247 | 
            +
                        )
         | 
| 248 | 
            +
                    if timesteps is None:
         | 
| 249 | 
            +
                        num_inference_steps = min(
         | 
| 250 | 
            +
                            self.config.num_train_timesteps, num_inference_steps
         | 
| 251 | 
            +
                        )
         | 
| 252 | 
            +
                        timesteps = self.get_initial_timesteps(
         | 
| 253 | 
            +
                            num_inference_steps, shift=self.shift
         | 
| 254 | 
            +
                        ).to(device)
         | 
| 255 | 
            +
                        timesteps = self.shift_timesteps(samples_shape, timesteps)
         | 
| 256 | 
            +
                    else:
         | 
| 257 | 
            +
                        timesteps = torch.Tensor(timesteps).to(device)
         | 
| 258 | 
            +
                        num_inference_steps = len(timesteps)
         | 
| 259 | 
            +
                    self.timesteps = timesteps
         | 
| 260 | 
            +
                    self.num_inference_steps = num_inference_steps
         | 
| 261 | 
            +
                    self.sigmas = self.timesteps
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                @staticmethod
         | 
| 264 | 
            +
                def from_pretrained(pretrained_model_path: Union[str, os.PathLike]):
         | 
| 265 | 
            +
                    pretrained_model_path = Path(pretrained_model_path)
         | 
| 266 | 
            +
                    if pretrained_model_path.is_file():
         | 
| 267 | 
            +
                        comfy_single_file_state_dict = {}
         | 
| 268 | 
            +
                        with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
         | 
| 269 | 
            +
                            metadata = f.metadata()
         | 
| 270 | 
            +
                            for k in f.keys():
         | 
| 271 | 
            +
                                comfy_single_file_state_dict[k] = f.get_tensor(k)
         | 
| 272 | 
            +
                        configs = json.loads(metadata["config"])
         | 
| 273 | 
            +
                        config = configs["scheduler"]
         | 
| 274 | 
            +
                        del comfy_single_file_state_dict
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    elif pretrained_model_path.is_dir():
         | 
| 277 | 
            +
                        diffusers_noise_scheduler_config_path = (
         | 
| 278 | 
            +
                            pretrained_model_path / "scheduler" / "scheduler_config.json"
         | 
| 279 | 
            +
                        )
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                        with open(diffusers_noise_scheduler_config_path, "r") as f:
         | 
| 282 | 
            +
                            scheduler_config = json.load(f)
         | 
| 283 | 
            +
                        hashable_config = make_hashable_key(scheduler_config)
         | 
| 284 | 
            +
                        if hashable_config in diffusers_and_ours_config_mapping:
         | 
| 285 | 
            +
                            config = diffusers_and_ours_config_mapping[hashable_config]
         | 
| 286 | 
            +
                    return RectifiedFlowScheduler.from_config(config)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                def scale_model_input(
         | 
| 289 | 
            +
                    self, sample: torch.FloatTensor, timestep: Optional[int] = None
         | 
| 290 | 
            +
                ) -> torch.FloatTensor:
         | 
| 291 | 
            +
                    # pylint: disable=unused-argument
         | 
| 292 | 
            +
                    """
         | 
| 293 | 
            +
                    Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         | 
| 294 | 
            +
                    current timestep.
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                    Args:
         | 
| 297 | 
            +
                        sample (`torch.FloatTensor`): input sample
         | 
| 298 | 
            +
                        timestep (`int`, optional): current timestep
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                    Returns:
         | 
| 301 | 
            +
                        `torch.FloatTensor`: scaled input sample
         | 
| 302 | 
            +
                    """
         | 
| 303 | 
            +
                    return sample
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                def step(
         | 
| 306 | 
            +
                    self,
         | 
| 307 | 
            +
                    model_output: torch.FloatTensor,
         | 
| 308 | 
            +
                    timestep: torch.FloatTensor,
         | 
| 309 | 
            +
                    sample: torch.FloatTensor,
         | 
| 310 | 
            +
                    return_dict: bool = True,
         | 
| 311 | 
            +
                    stochastic_sampling: Optional[bool] = False,
         | 
| 312 | 
            +
                    **kwargs,
         | 
| 313 | 
            +
                ) -> Union[RectifiedFlowSchedulerOutput, Tuple]:
         | 
| 314 | 
            +
                    """
         | 
| 315 | 
            +
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         | 
| 316 | 
            +
                    process from the learned model outputs (most often the predicted noise).
         | 
| 317 | 
            +
                    z_{t_1} = z_t - \Delta_t * v
         | 
| 318 | 
            +
                    The method finds the next timestep that is lower than the input timestep(s) and denoises the latents
         | 
| 319 | 
            +
                    to that level. The input timestep(s) are not required to be one of the predefined timesteps.
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                    Args:
         | 
| 322 | 
            +
                        model_output (`torch.FloatTensor`):
         | 
| 323 | 
            +
                            The direct output from learned diffusion model - the velocity,
         | 
| 324 | 
            +
                        timestep (`float`):
         | 
| 325 | 
            +
                            The current discrete timestep in the diffusion chain (global or per-token).
         | 
| 326 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 327 | 
            +
                            A current latent tokens to be de-noised.
         | 
| 328 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 329 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
         | 
| 330 | 
            +
                        stochastic_sampling (`bool`, *optional*, defaults to `False`):
         | 
| 331 | 
            +
                            Whether to use stochastic sampling for the sampling process.
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                    Returns:
         | 
| 334 | 
            +
                        [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`:
         | 
| 335 | 
            +
                            If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned,
         | 
| 336 | 
            +
                            otherwise a tuple is returned where the first element is the sample tensor.
         | 
| 337 | 
            +
                    """
         | 
| 338 | 
            +
                    if self.num_inference_steps is None:
         | 
| 339 | 
            +
                        raise ValueError(
         | 
| 340 | 
            +
                            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
         | 
| 341 | 
            +
                        )
         | 
| 342 | 
            +
                    t_eps = 1e-6  # Small epsilon to avoid numerical issues in timestep values
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                    timesteps_padded = torch.cat(
         | 
| 345 | 
            +
                        [self.timesteps, torch.zeros(1, device=self.timesteps.device)]
         | 
| 346 | 
            +
                    )
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    # Find the next lower timestep(s) and compute the dt from the current timestep(s)
         | 
| 349 | 
            +
                    if timestep.ndim == 0:
         | 
| 350 | 
            +
                        # Global timestep case
         | 
| 351 | 
            +
                        lower_mask = timesteps_padded < timestep - t_eps
         | 
| 352 | 
            +
                        lower_timestep = timesteps_padded[lower_mask][0]  # Closest lower timestep
         | 
| 353 | 
            +
                        dt = timestep - lower_timestep
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    else:
         | 
| 356 | 
            +
                        # Per-token case
         | 
| 357 | 
            +
                        assert timestep.ndim == 2
         | 
| 358 | 
            +
                        lower_mask = timesteps_padded[:, None, None] < timestep[None] - t_eps
         | 
| 359 | 
            +
                        lower_timestep = lower_mask * timesteps_padded[:, None, None]
         | 
| 360 | 
            +
                        lower_timestep, _ = lower_timestep.max(dim=0)
         | 
| 361 | 
            +
                        dt = (timestep - lower_timestep)[..., None]
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                    # Compute previous sample
         | 
| 364 | 
            +
                    if stochastic_sampling:
         | 
| 365 | 
            +
                        x0 = sample - timestep[..., None] * model_output
         | 
| 366 | 
            +
                        next_timestep = timestep[..., None] - dt
         | 
| 367 | 
            +
                        prev_sample = self.add_noise(x0, torch.randn_like(sample), next_timestep)
         | 
| 368 | 
            +
                    else:
         | 
| 369 | 
            +
                        prev_sample = sample - dt * model_output
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    if not return_dict:
         | 
| 372 | 
            +
                        return (prev_sample,)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                    return RectifiedFlowSchedulerOutput(prev_sample=prev_sample)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                def add_noise(
         | 
| 377 | 
            +
                    self,
         | 
| 378 | 
            +
                    original_samples: torch.FloatTensor,
         | 
| 379 | 
            +
                    noise: torch.FloatTensor,
         | 
| 380 | 
            +
                    timesteps: torch.FloatTensor,
         | 
| 381 | 
            +
                ) -> torch.FloatTensor:
         | 
| 382 | 
            +
                    sigmas = timesteps
         | 
| 383 | 
            +
                    sigmas = append_dims(sigmas, original_samples.ndim)
         | 
| 384 | 
            +
                    alphas = 1 - sigmas
         | 
| 385 | 
            +
                    noisy_samples = alphas * original_samples + sigmas * noise
         | 
| 386 | 
            +
                    return noisy_samples
         | 
    	
ltx_video/utils/__init__.py
ADDED
File without changes

ltx_video/utils/diffusers_config_mapping.py
ADDED
@@ -0,0 +1,174 @@
def make_hashable_key(dict_key):
    def convert_value(value):
        if isinstance(value, list):
            return tuple(value)
        elif isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        else:
            return value

    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))


DIFFUSERS_SCHEDULER_CONFIG = {
    "_class_name": "FlowMatchEulerDiscreteScheduler",
    "_diffusers_version": "0.32.0.dev0",
    "base_image_seq_len": 1024,
    "base_shift": 0.95,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 2.05,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": 0.1,
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
DIFFUSERS_TRANSFORMER_CONFIG = {
    "_class_name": "LTXVideoTransformer3DModel",
    "_diffusers_version": "0.32.0.dev0",
    "activation_fn": "gelu-approximate",
    "attention_bias": True,
    "attention_head_dim": 64,
    "attention_out_bias": True,
    "caption_channels": 4096,
    "cross_attention_dim": 2048,
    "in_channels": 128,
    "norm_elementwise_affine": False,
    "norm_eps": 1e-06,
    "num_attention_heads": 32,
    "num_layers": 28,
    "out_channels": 128,
    "patch_size": 1,
    "patch_size_t": 1,
    "qk_norm": "rms_norm_across_heads",
}
DIFFUSERS_VAE_CONFIG = {
    "_class_name": "AutoencoderKLLTXVideo",
    "_diffusers_version": "0.32.0.dev0",
    "block_out_channels": [128, 256, 512, 512],
    "decoder_causal": False,
    "encoder_causal": True,
    "in_channels": 3,
    "latent_channels": 128,
    "layers_per_block": [4, 3, 3, 3, 4],
    "out_channels": 3,
    "patch_size": 4,
    "patch_size_t": 1,
    "resnet_norm_eps": 1e-06,
    "scaling_factor": 1.0,
    "spatio_temporal_scaling": [True, True, True, False],
}

OURS_SCHEDULER_CONFIG = {
    "_class_name": "RectifiedFlowScheduler",
    "_diffusers_version": "0.25.1",
    "num_train_timesteps": 1000,
    "shifting": "SD3",
    "base_resolution": None,
    "target_shift_terminal": 0.1,
}

OURS_TRANSFORMER_CONFIG = {
    "_class_name": "Transformer3DModel",
    "_diffusers_version": "0.25.1",
    "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
    "activation_fn": "gelu-approximate",
    "attention_bias": True,
    "attention_head_dim": 64,
    "attention_type": "default",
    "caption_channels": 4096,
    "cross_attention_dim": 2048,
    "double_self_attention": False,
    "dropout": 0.0,
    "in_channels": 128,
    "norm_elementwise_affine": False,
    "norm_eps": 1e-06,
    "norm_num_groups": 32,
    "num_attention_heads": 32,
    "num_embeds_ada_norm": 1000,
    "num_layers": 28,
    "num_vector_embeds": None,
    "only_cross_attention": False,
    "out_channels": 128,
    "project_to_2d_pos": True,
    "upcast_attention": False,
    "use_linear_projection": False,
    "qk_norm": "rms_norm",
    "standardization_norm": "rms_norm",
    "positional_embedding_type": "rope",
    "positional_embedding_theta": 10000.0,
    "positional_embedding_max_pos": [20, 2048, 2048],
    "timestep_scale_multiplier": 1000,
}
OURS_VAE_CONFIG = {
    "_class_name": "CausalVideoAutoencoder",
    "dims": 3,
    "in_channels": 3,
    "out_channels": 3,
    "latent_channels": 128,
    "blocks": [
        ["res_x", 4],
        ["compress_all", 1],
        ["res_x_y", 1],
        ["res_x", 3],
        ["compress_all", 1],
        ["res_x_y", 1],
        ["res_x", 3],
        ["compress_all", 1],
        ["res_x", 3],
        ["res_x", 4],
    ],
    "scaling_factor": 1.0,
    "norm_layer": "pixel_norm",
    "patch_size": 4,
    "latent_log_var": "uniform",
    "use_quant_conv": False,
    "causal_decoder": False,
}


diffusers_and_ours_config_mapping = {
    make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
    make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
    make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
}


TRANSFORMER_KEYS_RENAME_DICT = {
    "proj_in": "patchify_proj",
    "time_embed": "adaln_single",
    "norm_q": "q_norm",
    "norm_k": "k_norm",
}


VAE_KEYS_RENAME_DICT = {
    "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
    "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
    "decoder.up_blocks.3": "decoder.up_blocks.9",
    "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
    "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
    "decoder.up_blocks.2": "decoder.up_blocks.6",
    "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
    "decoder.up_blocks.1": "decoder.up_blocks.3",
    "decoder.up_blocks.0": "decoder.up_blocks.1",
    "decoder.mid_block": "decoder.up_blocks.0",
    "encoder.down_blocks.3": "encoder.down_blocks.8",
    "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
    "encoder.down_blocks.2": "encoder.down_blocks.6",
    "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
    "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
    "encoder.down_blocks.1": "encoder.down_blocks.3",
    "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
    "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
    "encoder.down_blocks.0": "encoder.down_blocks.0",
    "encoder.mid_block": "encoder.down_blocks.9",
    "conv_shortcut.conv": "conv_shortcut",
    "resnets": "res_blocks",
    "norm3": "norm3.norm",
    "latents_mean": "per_channel_statistics.mean-of-means",
    "latents_std": "per_channel_statistics.std-of-means",
}
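A short sketch of how these tables are consumed (the config path is an assumption for illustration): a diffusers `scheduler_config.json` is normalized with `make_hashable_key` and looked up in `diffusers_and_ours_config_mapping` to obtain the in-repo config, mirroring `RectifiedFlowScheduler.from_pretrained` above.

import json

# Assumed location of a diffusers-style scheduler config; adjust to your checkpoint layout.
with open("scheduler/scheduler_config.json", "r") as f:
    diffusers_config = json.load(f)

key = make_hashable_key(diffusers_config)
if key in diffusers_and_ours_config_mapping:
    ours_config = diffusers_and_ours_config_mapping[key]  # e.g. OURS_SCHEDULER_CONFIG
    print(ours_config["_class_name"])                     # "RectifiedFlowScheduler"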
    	
        ltx_video/utils/prompt_enhance_utils.py
    ADDED
    
@@ -0,0 +1,226 @@
import logging
from typing import Union, List, Optional

import torch
from PIL import Image

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Do not change the user input intent, just enhance it.
Keep within 150 words.
For best results, build your prompts using this structure:
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""

I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award-winning movies. When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
Start directly with the action, and keep descriptions literal and precise.
Think like a cinematographer describing a shot list.
Keep within 150 words.
For best results, build your prompts using this structure:
Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
Start with main action in a single sentence
Add specific details about movements and gestures
Describe character/object appearances precisely
Include background and environment details
Specify camera angles and movements
Describe lighting and colors
Note any changes or sudden events
Align to the image caption if it contradicts the user text input.
Do not exceed the 150 word limit!
Output the enhanced prompt only.
"""


def tensor_to_pil(tensor):
    # Ensure tensor is in range [-1, 1]
    assert tensor.min() >= -1 and tensor.max() <= 1

    # Convert from [-1, 1] to [0, 1]
    tensor = (tensor + 1) / 2

    # Rearrange from [C, H, W] to [H, W, C]
    tensor = tensor.permute(1, 2, 0)

    # Convert to numpy array and then to uint8 range [0, 255]
    numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")

    # Convert to PIL Image
    return Image.fromarray(numpy_image)


def generate_cinematic_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompt: Union[str, List[str]],
    conditioning_items: Optional[List] = None,
    max_new_tokens: int = 256,
) -> List[str]:
    prompts = [prompt] if isinstance(prompt, str) else prompt

    if conditioning_items is None:
        prompts = _generate_t2v_prompt(
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            max_new_tokens,
            T2V_CINEMATIC_PROMPT,
        )
    else:
        if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
            logger.warning(
                "Prompt enhancement only supports unconditional generation or a single first-frame conditioning item; returning the original prompts"
            )
            return prompts

        first_frame_conditioning_item = conditioning_items[0]
        first_frames = _get_first_frames_from_conditioning_item(
            first_frame_conditioning_item
        )

        assert len(first_frames) == len(
            prompts
        ), "Number of conditioning frames must match number of prompts"

        prompts = _generate_i2v_prompt(
            image_caption_model,
            image_caption_processor,
            prompt_enhancer_model,
            prompt_enhancer_tokenizer,
            prompts,
            first_frames,
            max_new_tokens,
            I2V_CINEMATIC_PROMPT,
        )

    return prompts


def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
    frames_tensor = conditioning_item.media_item
    return [
        tensor_to_pil(frames_tensor[i, :, 0, :, :])
        for i in range(frames_tensor.shape[0])
    ]


def _generate_t2v_prompt(
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}"},
        ]
        for p in prompts
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
        prompt_enhancer_model.device
    )

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )


def _generate_i2v_prompt(
    image_caption_model,
    image_caption_processor,
    prompt_enhancer_model,
    prompt_enhancer_tokenizer,
    prompts: List[str],
    first_frames: List[Image.Image],
    max_new_tokens: int,
    system_prompt: str,
) -> List[str]:
    image_captions = _generate_image_captions(
        image_caption_model, image_caption_processor, first_frames
    )

    messages = [
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
        ]
        for p, c in zip(prompts, image_captions)
    ]

    texts = [
        prompt_enhancer_tokenizer.apply_chat_template(
            m, tokenize=False, add_generation_prompt=True
        )
        for m in messages
    ]
    model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
        prompt_enhancer_model.device
    )

    return _generate_and_decode_prompts(
        prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
    )


def _generate_image_captions(
    image_caption_model,
    image_caption_processor,
    images: List[Image.Image],
    system_prompt: str = "<DETAILED_CAPTION>",
) -> List[str]:
    image_caption_prompts = [system_prompt] * len(images)
    inputs = image_caption_processor(
        image_caption_prompts, images, return_tensors="pt"
    ).to(image_caption_model.device)

    with torch.inference_mode():
        generated_ids = image_caption_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )

    return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)


def _generate_and_decode_prompts(
    prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
) -> List[str]:
    with torch.inference_mode():
        outputs = prompt_enhancer_model.generate(
            **model_inputs, max_new_tokens=max_new_tokens
        )
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
        ]
        decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

    return decoded_prompts
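A usage sketch for the text-to-video path of `generate_cinematic_prompt` (the checkpoint name is a placeholder, not taken from this repo): since no conditioning items are passed, the caption model and processor are never used and can be None here.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder instruct-LLM checkpoint; substitute the prompt-enhancer model configured for your setup.
llm = AutoModelForCausalLM.from_pretrained("<instruct-llm-checkpoint>")
llm_tokenizer = AutoTokenizer.from_pretrained("<instruct-llm-checkpoint>")

enhanced = generate_cinematic_prompt(
    image_caption_model=None,       # only needed when conditioning on a first frame
    image_caption_processor=None,
    prompt_enhancer_model=llm,
    prompt_enhancer_tokenizer=llm_tokenizer,
    prompt="a red fox running through fresh snow at dawn",
    conditioning_items=None,        # text-to-video path
)
print(enhanced[0])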
    	
        ltx_video/utils/skip_layer_strategy.py
    ADDED
    
@@ -0,0 +1,8 @@
from enum import Enum, auto


class SkipLayerStrategy(Enum):
    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()
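A hypothetical helper (not part of this file) showing how a lowercase, underscore-separated config string could be resolved to a SkipLayerStrategy member:

def skip_layer_strategy_from_string(name: str) -> SkipLayerStrategy:
    # Compare case-insensitively, ignoring underscores, against the enum member names.
    normalized = name.replace("_", "").lower()
    for member in SkipLayerStrategy:
        if member.name.lower() == normalized:
            return member
    raise ValueError(f"Unknown skip-layer strategy: {name}")

assert skip_layer_strategy_from_string("attention_values") is SkipLayerStrategy.AttentionValues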
    	
        ltx_video/utils/torch_utils.py
    ADDED
    
@@ -0,0 +1,25 @@
import torch
from torch import nn


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    elif dims_to_append == 0:
        return x
    return x[(...,) + (None,) * dims_to_append]


class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive."""

    def __init__(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
        super().__init__()

    # pylint: disable=unused-argument
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x
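A small broadcasting sketch for `append_dims`, mirroring its use in the scheduler's `add_noise` above (shapes are illustrative):

import torch

sigmas = torch.tensor([0.2, 0.8])           # per-sample noise levels, shape (b,)
latents = torch.randn(2, 128, 4, 8, 8)      # video latents, shape (b, c, f, h, w)
sigmas = append_dims(sigmas, latents.ndim)  # reshaped to (b, 1, 1, 1, 1) for broadcasting
noisy = (1 - sigmas) * latents + sigmas * torch.randn_like(latents)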
 
			
