diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..c91de452ed26ee69fcc20bac4ba01b7a4e464856 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,32 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/0_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/0.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/1_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/2_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/3_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/3.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/4_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/4.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/5_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/5.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/6_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/6.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/7_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/7.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/8_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/8.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/9_edit.png filter=lfs diff=lfs merge=lfs -text
+assets/9.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_0.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_3.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_5.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_6.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_7.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/outputvideo/output_8.mp4 filter=lfs diff=lfs merge=lfs -text
+control_cogvideox/__pycache__/embeddings.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
diff --git a/app_video_image_guidence.py b/app_video_image_guidence.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef04a1e54a5fdc61a0f7f226e2d48e6f10014a6f
--- /dev/null
+++ b/app_video_image_guidence.py
@@ -0,0 +1,258 @@
+import gradio as gr
+
+import cv2
+import torch
+import numpy as np
+import os
+from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
+from pipeline_cogvideox_controlnet_5b_i2v_instruction2 import ControlCogVideoXPipeline
+from diffusers.utils import export_to_video
+from diffusers import AutoencoderKLCogVideoX
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers.schedulers import CogVideoXDDIMScheduler
+
+from omegaconf import OmegaConf
+from einops import rearrange
+import decord
+from typing import List
+from tqdm import tqdm
+
+import PIL
+import torch.nn.functional as F
+from torchvision import transforms
+
+def get_prompt(file: str) -> List[str]:
+    """Read a prompt file: index 0 is the positive prompt, index 1 the negative prompt."""
+    with open(file, 'r') as f:
+        lines = [line.strip() for line in f.readlines()]
+    return lines
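+# Prompt files are two-line text files: line 0 is the positive prompt and line 1
+# the negative prompt; e.g. assets/0.txt (added below) pairs "Swap dog for fox. ..."
+# with the negative prompt "The fox with three ears."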
+
+def init_pipe():
+    def unwrap_model(state_dict):
+        # Strip the "module." prefix (left by a DataParallel/DistributedDataParallel
+        # wrapper when the checkpoint was saved) so the weights load into the bare model.
+        new_state_dict = {}
+        for key in state_dict:
+            new_state_dict[key.split('module.')[1]] = state_dict[key]
+        return new_state_dict
+
+    i2v = True
+    root_path = "./cogvideox_instructions_lr_1e_5_bs_48_2_epoch_params_controlnet_first_frame_5b_mixed_precision_480_896/cogvideox-2025-01-15T12-42-19/"
+    training_steps = 30001
+
+ if i2v:
+ key = "i2v"
+ else:
+ key = "t2v"
+ noise_scheduler = CogVideoXDDIMScheduler(
+ **OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/scheduler/scheduler_config.json")
+ )
+ )
+
+    text_encoder = T5EncoderModel.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)  # left on CPU; enable_model_cpu_offload moves it when needed
+    vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16).to("cuda:0")
+    tokenizer = T5Tokenizer.from_pretrained(f"./cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)
+
+ config = OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
+ )
+ if i2v:
+ config["in_channels"] = 32
+ else:
+ config["in_channels"] = 16
+ transformer = CogVideoXTransformer3DModel(**config)
+
+ control_config = OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
+ )
+ if i2v:
+ control_config["in_channels"] = 32
+ else:
+ control_config["in_channels"] = 16
+ control_config['num_layers'] = 6
+ control_config['control_in_channels'] = 16
+ controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)
+
+    all_state_dicts = torch.load(f"{root_path}/checkpoints/checkpoint{training_steps}.ckpt", map_location="cpu", weights_only=True)
+    transformer_state_dict = unwrap_model(all_state_dicts["transformer_state_dict"])
+    controlnet_transformer_state_dict = unwrap_model(all_state_dicts["controlnet_transformer_state_dict"])
+
+ transformer.load_state_dict(transformer_state_dict, strict=True)
+ controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
+
+ transformer = transformer.half().to("cuda:0")
+ controlnet_transformer = controlnet_transformer.half().to("cuda:0")
+
+ vae = vae.eval()
+ text_encoder = text_encoder.eval()
+ transformer = transformer.eval()
+ controlnet_transformer = controlnet_transformer.eval()
+
+    pipe = ControlCogVideoXPipeline(
+        tokenizer,
+        text_encoder,
+        vae,
+        transformer,
+        noise_scheduler,
+        controlnet_transformer,
+    )
+
+    # Memory savers: slice/tile the VAE and keep idle submodules on the CPU.
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+    pipe.enable_model_cpu_offload()
+
+ return pipe
+
+def inference(source_images,
+ target_images,
+ text_prompt, negative_prompt,
+ pipe, vae,
+ step, guidance_scale,
+              h, w, random_seed) -> List[PIL.Image.Image]:
+ torch.manual_seed(random_seed)
+
+ source_pixel_values = source_images/127.5 - 1.0
+ source_pixel_values = source_pixel_values.to(torch.float16).to("cuda:0")
+ if target_images is not None:
+ target_pixel_values = target_images/127.5 - 1.0
+ target_pixel_values = target_pixel_values.to(torch.float16).to("cuda:0")
+    bsz, f, h, w, c = source_pixel_values.shape  # note: h and w here overwrite the function arguments
+
+ with torch.no_grad():
+        source_pixel_values = rearrange(source_pixel_values, "b f h w c -> b c f h w")
+ source_latents = vae.encode(source_pixel_values).latent_dist.sample()
+ source_latents = source_latents.to(torch.float16)
+ source_latents = source_latents * vae.config.scaling_factor
+ source_latents = rearrange(source_latents, "b c f h w -> b f c h w")
+
+ if target_images is not None:
+            target_pixel_values = rearrange(target_pixel_values, "b f h w c -> b c f h w")
+ images = target_pixel_values[:,:,:1,...]
+ image_latents = vae.encode(images).latent_dist.sample()
+ image_latents = image_latents.to(torch.float16)
+ image_latents = image_latents * vae.config.scaling_factor
+ image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
+ image_latents = torch.cat([image_latents, torch.zeros_like(source_latents)[:,1:]],dim=1)
+ latents = torch.cat([image_latents, source_latents], dim=2)
+ else:
+ image_latents = None
+ latents = source_latents
+
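+    # Latent layout: the VAE yields 16-channel video latents; in i2v mode the
+    # first-frame latents (zero-padded along the frame axis) are concatenated on
+    # the channel axis, which is what in_channels = 32 in init_pipe refers to.
+    # The local `latents` is not passed to the pipeline itself: image_latents is
+    # handed over separately as video_condition2 and, per the inline comment
+    # below, concatenated with the denoising latents inside the pipeline.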
+    video = pipe(
+        prompt=text_prompt,
+        negative_prompt=negative_prompt,
+        video_condition=source_latents,   # input to controlnet
+        video_condition2=image_latents,   # concat with latents
+        height=h,
+        width=w,
+        num_frames=f,
+        num_inference_steps=50,
+        interval=6,
+        guidance_scale=guidance_scale,
+        generator=torch.Generator(device="cuda:0").manual_seed(random_seed)
+    ).frames[0]
+
+ return video
+
+def process_video(video_file, image_file, positive_prompt, negative_prompt, guidance, random_seed, choice, progress=gr.Progress(track_tqdm=True)) -> str:
+    if choice == 33:
+        video_shard = 1
+    elif choice == 65:
+        video_shard = 2
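+    # Each pass of the generation loop below produces 33 frames. When 65 frames are
+    # requested, a second pass runs whose first-frame guidance is the last frame of
+    # the previous pass (the i > 0 branch), chaining the shards into a longer clip.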
+
+    pipe = PIPE
+
+    h = 448
+    w = 768
+    step = 30001           # training step of the loaded checkpoint
+    frames_per_shard = 33  # each generation pass produces this many frames
+
+    # Load the first-frame guidance image and resize it to the model resolution (BGR -> RGB).
+    image = cv2.imread(image_file)
+    resized_image = cv2.resize(image, (768, 448))
+    resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
+    image = torch.from_numpy(resized_image)
+    # Load the first 33 frames of the source video and resize them the same way.
+    vr = decord.VideoReader(video_file)
+    frames = vr.get_batch(list(range(33))).asnumpy()
+    _, src_h, src_w, _ = frames.shape
+    resized_frames = [cv2.resize(frame, (768, 448)) for frame in frames]
+    images = torch.from_numpy(np.array(resized_frames))
+
+    target_path = "outputvideo"
+    source_images = images[None, ...]
+    target_images = image[None, None, ...]
+
+    video: List[PIL.Image.Image] = []
+
+    for i in progress.tqdm(range(video_shard)):
+        if i > 0:  # first-frame guidance for follow-up shards
+            first_frame = transforms.ToTensor()(video[-1])
+            first_frame = first_frame * 255.0
+            first_frame = rearrange(first_frame, "c h w -> h w c")
+            # source_images is kept as-is: the original video still provides guidance
+            target_images = first_frame[None, None, ...]
+
+        video += inference(source_images,
+                           target_images, positive_prompt,
+                           negative_prompt, pipe, pipe.vae,
+                           step, guidance,
+                           h, w, random_seed)
+
+    video = [image.resize((int(src_w / src_h * 448), 448)) for image in video]  # restore the source aspect ratio
+
+    os.makedirs(f"./{target_path}", exist_ok=True)
+    output_path: str = f"./{target_path}/output_{video_file[-5]}.mp4"  # e.g. "assets/3.mp4" -> "output_3.mp4"
+    export_to_video(video, output_path, fps=8)
+    return output_path
+
+
+PIPE = init_pipe()
+
+with gr.Blocks() as demo:
+ gr.Markdown("""
+ # Señorita-2M: A High-Quality Instruction-based Dataset for General Video Editing by Video Specialists
+
+    [Paper](https://arxiv.org/abs/2502.06734) | [Code](https://127.0.0.1:7860) | [Huggingface](https://127.0.0.1:7860)
+ """)
+ #gr.HTML(open("gradio_title.md",'r').read())
+
+ with gr.Row():
+ video_input = gr.Video(label="Video input")
+        image_input = gr.Image(type="filepath", label="First frame guidance")
+ with gr.Row():
+ with gr.Column():
+            positive_prompt = gr.Textbox(label="Positive prompt", value="")
+            negative_prompt = gr.Textbox(label="Negative prompt", value="")
+            seed = gr.Slider(minimum=0, maximum=2147483647, step=1, value=0, label="Seed")
+            guidance_slider = gr.Slider(minimum=1, maximum=10, value=4, label="Guidance")
+            choice = gr.Radio(choices=[33, 65], label="Frame number", value=33)
+ with gr.Column():
+ video_output = gr.Video(label="Video output")
+
+ with gr.Row():
+ submit_button = gr.Button("Generate")
+ submit_button.click(fn=process_video, inputs=[video_input, image_input, positive_prompt, negative_prompt, guidance_slider, seed, choice], outputs=video_output)
+ with gr.Row():
+ gr.Examples(
+ [
+ ["assets/0.mp4","assets/0_edit.png",get_prompt("assets/0.txt")[0],get_prompt("assets/0.txt")[1],4,0,33],
+ ["assets/1.mp4","assets/1_edit.png",get_prompt("assets/1.txt")[0],get_prompt("assets/1.txt")[1],4,0,33],
+ ["assets/2.mp4","assets/2_edit.png",get_prompt("assets/2.txt")[0],get_prompt("assets/2.txt")[1],4,0,33],
+ ["assets/3.mp4","assets/3_edit.png",get_prompt("assets/3.txt")[0],get_prompt("assets/3.txt")[1],4,0,33],
+ ["assets/4.mp4","assets/4_edit.png",get_prompt("assets/4.txt")[0],get_prompt("assets/4.txt")[1],4,0,33],
+ ["assets/5.mp4","assets/5_edit.png",get_prompt("assets/5.txt")[0],get_prompt("assets/5.txt")[1],4,0,33],
+ ["assets/6.mp4","assets/6_edit.png",get_prompt("assets/6.txt")[0],get_prompt("assets/6.txt")[1],4,0,33],
+ ["assets/7.mp4","assets/7_edit.png",get_prompt("assets/7.txt")[0],get_prompt("assets/7.txt")[1],4,0,33],
+ ["assets/8.mp4","assets/8_edit.png",get_prompt("assets/8.txt")[0],get_prompt("assets/8.txt")[1],4,0,33]
+ ],
+ inputs=[video_input, image_input, positive_prompt, negative_prompt, guidance_slider, seed, choice],
+ outputs=video_output,
+ fn=process_video,
+ cache_examples=False
+ )
+
+if __name__ == "__main__":
+ demo.queue().launch()
diff --git a/assets/0.mp4 b/assets/0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..7c29f675813094fbe637976131f5f7e544e97433
--- /dev/null
+++ b/assets/0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b9b4cbbc26fd2c76e5339fc9868b97cb3a9e6dfd394a99a6217f8f2070ad4af
+size 1234363
diff --git a/assets/0.txt b/assets/0.txt
new file mode 100644
index 0000000000000000000000000000000000000000..39ca707c6acfee822f309ad07b23f099148d7cf5
--- /dev/null
+++ b/assets/0.txt
@@ -0,0 +1,4 @@
+Swap dog for fox. prompt: the fox with two ears. the motion is clear. The background is strictly aligned.
+The fox with three ears.
+
+
diff --git a/assets/0_edit.png b/assets/0_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..937662131dfe6523ccfecff46bf148f88352066c
--- /dev/null
+++ b/assets/0_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64be8927b0ce4b390c9c762fbd2358e39bc9e79ff5ef706f53511c91aadbf89e
+size 1755272
diff --git a/assets/1.mp4 b/assets/1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..78c7dfd277efe558aa3c3dc22d0cef74742c15e4
--- /dev/null
+++ b/assets/1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b5bc1baee8501d33e58ea3a914b90f8737b3dc692cd21c313a790b7c90d2ed1
+size 583649
diff --git a/assets/1.txt b/assets/1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8ad7a3f0d0d1bd43bb5da8f3042f6d3be5978ec6
--- /dev/null
+++ b/assets/1.txt
@@ -0,0 +1,2 @@
+Make it anime style. prompt: the flower is swaying in the wind. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/1_edit.png b/assets/1_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..ce585f9c8405b2fc16c53014b3baa467ba0c237a
--- /dev/null
+++ b/assets/1_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:974df3936f42e277c407a29e7dcb3b7e87b32399857e58e09fec706ffd8f76da
+size 1375616
diff --git a/assets/2.mp4 b/assets/2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0e6b7020d3cb1dfd963472f35f97e495309ee8c6
--- /dev/null
+++ b/assets/2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b0bff917b61848877637c5e9e9c8f8a7ff246ecf02b91b6d5b496737f7ea9ddc
+size 3359278
diff --git a/assets/2.txt b/assets/2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..403f697d7aab6821f2dfcf303d21c4e8d88bb332
--- /dev/null
+++ b/assets/2.txt
@@ -0,0 +1,2 @@
+Add a hat on girl's head.
+Bad quality.
\ No newline at end of file
diff --git a/assets/2_edit.png b/assets/2_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..d3ad431892182e080740e4566303d6634f8c6838
--- /dev/null
+++ b/assets/2_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfe17831a79a5f2311a010f635cf783241b4b8387877b15ef06363de69ae68f5
+size 3181851
diff --git a/assets/3.mp4 b/assets/3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..583e13602b08599a8b35ce9b5637ff7ad8c70be1
--- /dev/null
+++ b/assets/3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ea5c36673ade253be717ee8a7b06599f7108a03cd75ef6a954c32f1a05fde812
+size 1294568
diff --git a/assets/3.txt b/assets/3.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6f96152e9c315a4a59b5efca0db0e60c00d96140
--- /dev/null
+++ b/assets/3.txt
@@ -0,0 +1,2 @@
+Make it oil painting style. The color is bright and beautiful. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/3_edit.png b/assets/3_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d8711f5331721a5c50c8b58de63be5a86287f76
--- /dev/null
+++ b/assets/3_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19945819dec239a4fba99f84d4e167569b4a051697975d383a07d5ccdb603cc8
+size 2583076
diff --git a/assets/4.mp4 b/assets/4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0a886176ff7b249aaa205b61316da0877be6ad5d
--- /dev/null
+++ b/assets/4.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a57e8216cdb4b2bc990a380d72844e2fe6594820ff31993330718c9315c0167a
+size 460228
diff --git a/assets/4.txt b/assets/4.txt
new file mode 100644
index 0000000000000000000000000000000000000000..aa895850fe8a49efa4785d21f83a3c1bf147a204
--- /dev/null
+++ b/assets/4.txt
@@ -0,0 +1,2 @@
+Remove the girl. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/4_edit.png b/assets/4_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..91f31b26b18d8c36b7ef7acdc5e3dbed1e28ddd4
--- /dev/null
+++ b/assets/4_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:75a189056f9add32ccf0191a649c87ef5dbd84eeb4c7f3afc633fd07fcbe07d1
+size 5432019
diff --git a/assets/5.mp4 b/assets/5.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..56554f0686866c3328d5e81df76810b1ffae6c63
--- /dev/null
+++ b/assets/5.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92dbba09c23b2dc594a84505dc7575bf2111b5af3a2636723533282d349e6e86
+size 513130
diff --git a/assets/5.txt b/assets/5.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8079f1e9c61c4fc97132e0696e103631e4ea6dfa
--- /dev/null
+++ b/assets/5.txt
@@ -0,0 +1,2 @@
+Make it water color style. prompt: the flowers with green leaves. The color is bright and beautiful. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/5_edit.png b/assets/5_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..a2d36312ad1697e9ce9563f6b950f6470d359bf1
--- /dev/null
+++ b/assets/5_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0b238cea32445fa45b9d8c7c255cf83b574bf82d93510060c987ca11523c2c95
+size 1443562
diff --git a/assets/6.mp4 b/assets/6.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a5dc31cbd3765dac330b4bde90b3097e8ed28405
--- /dev/null
+++ b/assets/6.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afb4aadab4339e508e66e1630225191f859dd34902c7f165284070241b853eff
+size 119857
diff --git a/assets/6.txt b/assets/6.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9275d725cde722aca24d2c535e8597cb0f5e351
--- /dev/null
+++ b/assets/6.txt
@@ -0,0 +1,2 @@
+Make it anime style. prompt: the butterfly in on the flower. The color is bright and beautiful. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/6_edit.png b/assets/6_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..2c0692dcfc7557cc4b5a79e685e1dfd5abd9677f
--- /dev/null
+++ b/assets/6_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be97371310c83bf4ae64728aa73999d52bc0fe3ce779496fa5938431c86064b8
+size 1445363
diff --git a/assets/7.mp4 b/assets/7.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..af71e91bde97cd632c5ea2ba305beff21ddff9b6
--- /dev/null
+++ b/assets/7.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3ade27a4aed77a41c809a82768c2006179718d2779b8391e5b74d8bf7a44aecf
+size 403370
diff --git a/assets/7.txt b/assets/7.txt
new file mode 100644
index 0000000000000000000000000000000000000000..19274608a1dc37ba716350fa8c3c10aef74b499a
--- /dev/null
+++ b/assets/7.txt
@@ -0,0 +1,2 @@
+Make it anime style. prompt: white swan, autumn. The color is bright and beautiful. the video is captured by professional camera. The motion is stable. Best quality.
+bad quality.
\ No newline at end of file
diff --git a/assets/7_edit.png b/assets/7_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..ea006b7a74764c9b0c6b48d92fcf6c7966fd253a
--- /dev/null
+++ b/assets/7_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d9ec616c189f6484bf1812886565cbcff4d93719339d077211e2633f45d8cf12
+size 1685678
diff --git a/assets/8.mp4 b/assets/8.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d32e0c31ea5a31548bcefaa6aa3d0399567c1ed5
--- /dev/null
+++ b/assets/8.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfe7c713e94c8f85bb174326db147942d906973d49c4dd1a0412a81a7a7a1d93
+size 690881
diff --git a/assets/8.txt b/assets/8.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9d0fbd24bbae52e99ccb4179b10dbebf7e4f678
--- /dev/null
+++ b/assets/8.txt
@@ -0,0 +1,2 @@
+Swap bird for squirrel. prompt: the squirrel is standing on the column. Squirrel are looking around. two ears.
+Static Squirrel.
diff --git a/assets/8_edit.png b/assets/8_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..8923b217cd6ec9def0099facccf0e83b7d3a1c2f
--- /dev/null
+++ b/assets/8_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:108da8c8da6764304ad70266eef4d4faa27288286e9f370c118c1da178cfe60c
+size 649406
diff --git a/assets/9.mp4 b/assets/9.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0875c340284712b5aacdd9a9702a2623464acee3
--- /dev/null
+++ b/assets/9.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0073470e8507542945226a48e3af1e72656cdb5df331449374955be07fef4871
+size 539224
diff --git a/assets/9.txt b/assets/9.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d090d4430bf44806a8698ec056cd9272f65aa648
--- /dev/null
+++ b/assets/9.txt
@@ -0,0 +1,3 @@
+Swap black dog for white pig. prompt: the pig is standing between two trees.
+black dog.
+
diff --git a/assets/9_edit.png b/assets/9_edit.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1a52cde930da4f83f8345280504a56f7c8d6316
--- /dev/null
+++ b/assets/9_edit.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a98bb414828486c103a8f845c420031323fa842633eb37ecd98c5c58ce35d83e
+size 703927
diff --git a/assets/outputvideo/output_0.mp4 b/assets/outputvideo/output_0.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..45813c89d03c95b44afe703555cd9cb8d1d11cca
--- /dev/null
+++ b/assets/outputvideo/output_0.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4ba1cf2d89579e0d9e0d12cf4ebc06c0756725f048c6c2b269d52a99f6185477
+size 1282150
diff --git a/assets/outputvideo/output_1.mp4 b/assets/outputvideo/output_1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..565befc010dd701bc2f3550944a8de7c6238bfed
--- /dev/null
+++ b/assets/outputvideo/output_1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cc843ae155c3c3198e147b882fe2460c18d05e0b387136925df828edd83587fc
+size 266780
diff --git a/assets/outputvideo/output_2.mp4 b/assets/outputvideo/output_2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..98bf1a27395a6c3bfded9d082245dcb597b17bb2
--- /dev/null
+++ b/assets/outputvideo/output_2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ee9743f7d86ff59dc3b19cbfaf2677aad38340cc89c6856f353da1aa69a226fe
+size 912515
diff --git a/assets/outputvideo/output_3.mp4 b/assets/outputvideo/output_3.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8875b013665f7ffb3fcc14168c4ab2161d116a9c
--- /dev/null
+++ b/assets/outputvideo/output_3.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07928cc806c93453a99e6de1ece37d949d71fbf33bbd20ea2c7116ec988f1d17
+size 1058274
diff --git a/assets/outputvideo/output_4.mp4 b/assets/outputvideo/output_4.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a5fafff583cd88e61d64c0b1ce62bafadb18c795
Binary files /dev/null and b/assets/outputvideo/output_4.mp4 differ
diff --git a/assets/outputvideo/output_5.mp4 b/assets/outputvideo/output_5.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..00f7c580c44ee05ca7060c5ca698a0a3fcaddef6
--- /dev/null
+++ b/assets/outputvideo/output_5.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3e32b887b2084b9ce04bc77c94da9c068018432d1b15de7f93435dff9e65ee3
+size 237088
diff --git a/assets/outputvideo/output_6.mp4 b/assets/outputvideo/output_6.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..5996832ebf25ae7f784f7bcb95463aac6650809d
--- /dev/null
+++ b/assets/outputvideo/output_6.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2fff03bc3efbb20752cae96fb8da35b5fee292596ad00d588a09fabb20c9aeef
+size 141423
diff --git a/assets/outputvideo/output_7.mp4 b/assets/outputvideo/output_7.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..496ef199cd3b18adcab23121f720861b69fe181c
--- /dev/null
+++ b/assets/outputvideo/output_7.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0910239dd1f4b40a49602bdc27af42bd394c3d55724a3665d731471fe20989db
+size 294492
diff --git a/assets/outputvideo/output_8.mp4 b/assets/outputvideo/output_8.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0e691698625df3b21dec2defce3e227198fc0669
--- /dev/null
+++ b/assets/outputvideo/output_8.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8abf76c7fffe1f58445a24bfe203073c094dbc44ec9934ccaff71c31e645a7f9
+size 358596
diff --git a/control_cogvideox/__pycache__/attention_processor.cpython-310.pyc b/control_cogvideox/__pycache__/attention_processor.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bddfe66bda35da58dce509b76a0e0c99c927994c
Binary files /dev/null and b/control_cogvideox/__pycache__/attention_processor.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-310.pyc b/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4265d891293d2c30a8ae959dcccc073ae4ad09a8
Binary files /dev/null and b/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-311.pyc b/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3fc9de56b1352cff08913005d6739ec1c6924454
Binary files /dev/null and b/control_cogvideox/__pycache__/cogvideox_transformer_3d.cpython-311.pyc differ
diff --git a/control_cogvideox/__pycache__/cogvideox_transformer_3d_ipadapter.cpython-310.pyc b/control_cogvideox/__pycache__/cogvideox_transformer_3d_ipadapter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..037849fdc730840e1ab3f22e785ab5cc5a76d9a2
Binary files /dev/null and b/control_cogvideox/__pycache__/cogvideox_transformer_3d_ipadapter.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/cogvideox_transformer_3d_new_version.cpython-310.pyc b/control_cogvideox/__pycache__/cogvideox_transformer_3d_new_version.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa6ca53b152b3c07be7fbf490d439bf339d8225a
Binary files /dev/null and b/control_cogvideox/__pycache__/cogvideox_transformer_3d_new_version.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-310.pyc b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f04cc7ba4eba3f8533215a38d2e9fbba441de32
Binary files /dev/null and b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-311.pyc b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f2cb6395c02e6ca66074a61706747236e37cdd20
Binary files /dev/null and b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d.cpython-311.pyc differ
diff --git a/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d_condition.cpython-310.pyc b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d_condition.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6599a77bae60e1ec1ad41f19c66ab2cf5bb30d04
Binary files /dev/null and b/control_cogvideox/__pycache__/controlnet_cogvideox_transformer_3d_condition.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/embeddings.cpython-310.pyc b/control_cogvideox/__pycache__/embeddings.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c59a54655c1b893924bd2b76ec0158741f6c7c8
Binary files /dev/null and b/control_cogvideox/__pycache__/embeddings.cpython-310.pyc differ
diff --git a/control_cogvideox/__pycache__/embeddings.cpython-310.pyc.139873079775504 b/control_cogvideox/__pycache__/embeddings.cpython-310.pyc.139873079775504
new file mode 100644
index 0000000000000000000000000000000000000000..94f8fbf7ac88c48721b338513b45a0aa3fca268c
Binary files /dev/null and b/control_cogvideox/__pycache__/embeddings.cpython-310.pyc.139873079775504 differ
diff --git a/control_cogvideox/__pycache__/embeddings.cpython-310.pyc.140219680461072 b/control_cogvideox/__pycache__/embeddings.cpython-310.pyc.140219680461072
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/control_cogvideox/__pycache__/embeddings.cpython-311.pyc b/control_cogvideox/__pycache__/embeddings.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45e0e8abe8b459d481c9b2725a49319ef047cda6
--- /dev/null
+++ b/control_cogvideox/__pycache__/embeddings.cpython-311.pyc
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90808e21fcd7971587847f6516bcef59f7bc41fe31d356bb4da7ebcf6a4af392
+size 100711
diff --git a/control_cogvideox/attention_processor.py b/control_cogvideox/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc16a260c12f6f5ee16f17802ce60e33d174c493
--- /dev/null
+++ b/control_cogvideox/attention_processor.py
@@ -0,0 +1,4303 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import math
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.image_processor import IPAdapterMaskProcessor
+from diffusers.utils import deprecate, logging
+from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
+from diffusers.utils.torch_utils import is_torch_version, maybe_allow_in_graph
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_torch_npu_available():
+ import torch_npu
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`):
+ The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8):
+ The number of heads to use for multi-head attention.
+ kv_heads (`int`, *optional*, defaults to `None`):
+ The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+            `kv_heads=heads`, the model will use Multi Head Attention (MHA); if `kv_heads=1`, it will use Multi
+            Query Attention (MQA); otherwise Grouped Query Attention (GQA) is used.
+ dim_head (`int`, *optional*, defaults to 64):
+ The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the attention computation to `float32`.
+ upcast_softmax (`bool`, *optional*, defaults to False):
+ Set to `True` to upcast the softmax computation to `float32`.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the group norm in the cross attention.
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
+ norm_num_groups (`int`, *optional*, defaults to `None`):
+ The number of groups to use for the group norm in the attention.
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
+ The number of channels to use for the spatial normalization.
+ out_bias (`bool`, *optional*, defaults to `True`):
+ Set to `True` to use a bias in the output linear layer.
+ scale_qk (`bool`, *optional*, defaults to `True`):
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+ `added_kv_proj_dim` is not `None`.
+ eps (`float`, *optional*, defaults to 1e-5):
+ An additional value added to the denominator in group normalization that is used for numerical stability.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
+ A factor to rescale the output by dividing it with this value.
+ residual_connection (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add the residual connection to the output.
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+ Set to `True` if the attention block is loaded from a deprecated state dict.
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
+ `AttnProcessor` otherwise.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ norm_num_groups: Optional[int] = None,
+ spatial_norm_dim: Optional[int] = None,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ _from_deprecated_attn_block: bool = False,
+ processor: Optional["AttnProcessor"] = None,
+ out_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ ):
+ super().__init__()
+
+ # To prevent circular import.
+ from .normalization import FP32LayerNorm, RMSNorm
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+
+ # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
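+        # scale = 1/sqrt(dim_head), the standard scaled-dot-product attention factor.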
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+ else:
+ self.group_norm = None
+
+ if spatial_norm_dim is not None:
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+ else:
+ self.spatial_norm = None
+
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
+ elif qk_norm == "fp32_layer_norm":
+ self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "layer_norm_across_heads":
+            # Lumina applies qk norm across all heads
+ self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+ self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_q = RMSNorm(dim_head, eps=eps)
+ self.norm_k = RMSNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
+ if cross_attention_norm is None:
+ self.norm_cross = None
+ elif cross_attention_norm == "layer_norm":
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+ elif cross_attention_norm == "group_norm":
+ if self.added_kv_proj_dim is not None:
+ # The given `encoder_hidden_states` are initially of shape
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+ # before the projection, so we need to use `added_kv_proj_dim` as
+ # the number of channels for the group norm.
+ norm_cross_num_channels = added_kv_proj_dim
+ else:
+ norm_cross_num_channels = self.cross_attention_dim
+
+ self.norm_cross = nn.GroupNorm(
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+ )
+ else:
+ raise ValueError(
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+ )
+
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+ if self.context_pre_only is not None:
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(dropout))
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
+ if qk_norm is not None and added_kv_proj_dim is not None:
+ if qk_norm == "fp32_layer_norm":
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+ elif qk_norm == "rms_norm":
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ self.norm_added_q = None
+ self.norm_added_k = None
+
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ r"""
+ Set whether to use npu flash attention from `torch_npu` or not.
+
+ """
+ if use_npu_flash_attention:
+ processor = AttnProcessorNPU()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ r"""
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
+ )
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
+ self.processor,
+ (
+ AttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ SlicedAttnAddedKVProcessor,
+ XFormersAttnAddedKVProcessor,
+ ),
+ )
+
+ if use_memory_efficient_attention_xformers:
+ if is_added_kv_processor and is_custom_diffusion:
+ raise NotImplementedError(
+ f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
+ )
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ (
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers"
+ ),
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+
+ if is_custom_diffusion:
+ processor = CustomDiffusionXFormersAttnProcessor(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ attention_op=attention_op,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ elif is_added_kv_processor:
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+ # which uses this type of cross attention ONLY because the attention mask of format
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
+ # throw warning
+ logger.info(
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+ )
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+ else:
+ processor = XFormersAttnProcessor(attention_op=attention_op)
+ else:
+ if is_custom_diffusion:
+ attn_processor_class = (
+ CustomDiffusionAttnProcessor2_0
+ if hasattr(F, "scaled_dot_product_attention")
+ else CustomDiffusionAttnProcessor
+ )
+ processor = attn_processor_class(
+ train_kv=self.processor.train_kv,
+ train_q_out=self.processor.train_q_out,
+ hidden_size=self.processor.hidden_size,
+ cross_attention_dim=self.processor.cross_attention_dim,
+ )
+ processor.load_state_dict(self.processor.state_dict())
+ if hasattr(self.processor, "to_k_custom_diffusion"):
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0()
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
+ else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ r"""
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ if slice_size is not None and self.added_kv_proj_dim is not None:
+ processor = SlicedAttnAddedKVProcessor(slice_size)
+ elif slice_size is not None:
+ processor = SlicedAttnProcessor(slice_size)
+ elif self.added_kv_proj_dim is not None:
+ processor = AttnAddedKVProcessor()
+ else:
+ # set attention processor
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+
+ self.set_processor(processor)
+
+ def set_processor(self, processor: "AttnProcessor") -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+ ]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
+ is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
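+        # Sketch: with heads=8, a (16, 77, 40) tensor becomes (2, 77, 320).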
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ r"""
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, heads, seq_len, dim // heads]`. `heads`
+        is the number of heads initialized while constructing the `Attention` class.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ r"""
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
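+        # torch.baddbmm computes beta * input + alpha * (query @ key^T), i.e.
+        # scores = attention_mask + scale * Q @ K^T (the mask term vanishes when beta=0).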
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ r"""
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ The attention mask to prepare.
+ target_length (`int`):
+ The target length of the attention mask. This is the length of the attention mask after padding.
+ batch_size (`int`):
+ The batch size, which is used to repeat the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`):
+ The output dimension of the attention mask. Can be either `3` or `4`.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ r"""
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
+ `Attention` class.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+ @torch.no_grad()
+ def fuse_projections(self, fuse=True):
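+        # Concatenate the separate q/k/v (or k/v, for cross-attention) projection
+        # weights into one linear layer so they run as a single matmul; fused
+        # attention processors can then read from `to_qkv` / `to_kv` directly.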
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if not self.is_cross_attention:
+ # fetch weight matrices.
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ # create a new single projection layer and copy over the weights.
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+
+ # handle added projections for SD3 and others.
+ if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = fuse
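+
+    # Editorial note (inferred from the code above): after fusing, self-attention
+    # exposes a single `to_qkv` linear whose weight is the row-wise concatenation
+    # [3 * inner_dim, query_dim] of `to_q`/`to_k`/`to_v`, while cross-attention
+    # keeps `to_q` and gains `to_kv` with weight [2 * inner_dim, cross_attention_dim].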
+
+
+class AttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
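+
+
+# Hedged usage sketch (illustrative; assumes the `Attention` module defined earlier
+# in this file, whose constructor accepts a `processor`):
+#
+#     attn = Attention(query_dim=64, heads=8, dim_head=8, processor=AttnProcessor())
+#     out = attn(torch.randn(2, 16, 64))  # -> [2, 16, 64]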
+
+
+class CustomDiffusionAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include the bias parameter in `train_q_out`.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
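+        # Descriptive note (editorial): for cross-attention the block below freezes
+        # gradients for the first text token only; `detach` is 1 everywhere except
+        # position 0 along the token axis, so tokens 1..n keep gradient flow while
+        # token 0 falls back to `key.detach()` / `value.detach()`.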
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor:
+ r"""
+ Processor for performing attention-related computations with extra learnable key and value matrices for the text
+ encoder.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class AttnAddedKVProcessor2_0:
+ r"""
+ Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra
+ learnable key and value matrices for the text encoder.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query, out_dim=4)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key, out_dim=4)
+ value = attn.head_to_batch_dim(value, out_dim=4)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
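+
+
+# Editorial note: the 2.0 variant below mirrors `AttnAddedKVProcessor` but keeps
+# tensors in the (batch, heads, seq_len, head_dim) layout expected by
+# `F.scaled_dot_product_attention` (via `out_dim=4`), instead of folding heads
+# into the batch dimension for `torch.bmm`.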
+
+
+class JointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
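+
+
+# Hedged shape sketch for the joint processors (illustrative numbers): with 4096
+# image tokens and 77 text tokens, Q/K/V are concatenated along the sequence axis
+# to length 4173; after attention, `hidden_states[:, :4096]` recovers the image
+# stream and `hidden_states[:, 4096:]` the text stream (the split at
+# `residual.shape[1]` above).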
+
+
+class PAGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # store the length of the image patch sequence to build a mask that blocks
+        # interaction between patches, making the patch self-attention map an identity matrix
+ identity_block_size = hidden_states.shape[1]
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+ encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
+
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]
+
+ # `sample` projections.
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)
+
+ # `context` projections.
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+ # attention
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+ inner_dim = key_org.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+ # Split the attention outputs.
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+ if not attn.context_pre_only:
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################## perturbed path ##################
+
+ batch_size = encoder_hidden_states_ptb.shape[0]
+
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)
+
+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+ # attention
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+ inner_dim = key_ptb.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # create a full mask with all entries set to 0
+ seq_len = query_ptb.size(2)
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+ # set the attention value between image patches to -inf
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+ # set the diagonal of the attention value between image patches to 0
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+ # expand the mask to match the attention weights shape
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
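+
+        # Worked example (illustrative): with 3 image patches and 2 text tokens,
+        # seq_len = 5 and the additive bias is
+        #     [[   0, -inf, -inf, 0, 0],
+        #      [-inf,    0, -inf, 0, 0],
+        #      [-inf, -inf,    0, 0, 0],
+        #      [   0,    0,    0, 0, 0],
+        #      [   0,    0,    0, 0, 0]]
+        # i.e. each image patch attends only to itself among patches, while text
+        # tokens and patch-to-text attention are unaffected.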
+
+ hidden_states_ptb = F.scaled_dot_product_attention(
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+ # split the attention outputs.
+ hidden_states_ptb, encoder_hidden_states_ptb = (
+ hidden_states_ptb[:, : residual.shape[1]],
+ hidden_states_ptb[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+ if not attn.context_pre_only:
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################ concat ###############
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+ return hidden_states, encoder_hidden_states
+
+
+class PAGCFGJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        # patch embedding count (width * height); corresponds to the self-attention map size
+        identity_block_size = hidden_states.shape[1]
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ (
+ encoder_hidden_states_uncond,
+ encoder_hidden_states_org,
+ encoder_hidden_states_ptb,
+ ) = encoder_hidden_states.chunk(3)
+ encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
+
+ ################## original path ##################
+ batch_size = encoder_hidden_states_org.shape[0]
+
+ # `sample` projections.
+ query_org = attn.to_q(hidden_states_org)
+ key_org = attn.to_k(hidden_states_org)
+ value_org = attn.to_v(hidden_states_org)
+
+ # `context` projections.
+ encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
+ encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
+
+ # attention
+ query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
+ key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
+ value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
+
+ inner_dim = key_org.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states_org = F.scaled_dot_product_attention(
+ query_org, key_org, value_org, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query_org.dtype)
+
+ # Split the attention outputs.
+ hidden_states_org, encoder_hidden_states_org = (
+ hidden_states_org[:, : residual.shape[1]],
+ hidden_states_org[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+ if not attn.context_pre_only:
+ encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################## perturbed path ##################
+
+ batch_size = encoder_hidden_states_ptb.shape[0]
+
+ # `sample` projections.
+ query_ptb = attn.to_q(hidden_states_ptb)
+ key_ptb = attn.to_k(hidden_states_ptb)
+ value_ptb = attn.to_v(hidden_states_ptb)
+
+ # `context` projections.
+ encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
+ encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
+
+ # attention
+ query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
+ key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
+ value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
+
+ inner_dim = key_ptb.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # create a full mask with all entries set to 0
+ seq_len = query_ptb.size(2)
+ full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
+
+ # set the attention value between image patches to -inf
+ full_mask[:identity_block_size, :identity_block_size] = float("-inf")
+
+ # set the diagonal of the attention value between image patches to 0
+ full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
+
+ # expand the mask to match the attention weights shape
+ full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
+
+ hidden_states_ptb = F.scaled_dot_product_attention(
+ query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
+
+ # split the attention outputs.
+ hidden_states_ptb, encoder_hidden_states_ptb = (
+ hidden_states_ptb[:, : residual.shape[1]],
+ hidden_states_ptb[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+ if not attn.context_pre_only:
+ encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
+ batch_size, channel, height, width
+ )
+
+ ################ concat ###############
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+ encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
+
+ return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FusedJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = encoder_hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # attention
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ return hidden_states, encoder_hidden_states
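+
+
+# Editorial note: the `Fused*` processors in this file assume
+# `Attention.fuse_projections()` has been called, so that `to_qkv` /
+# `to_added_qkv` (and `to_kv` for cross-attention) exist; in diffusers this is
+# typically triggered through a model-level `fuse_qkv_projections()` helper.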
+
+
+class AuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
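+
+
+# Hedged note on token ordering (from the code above): unlike the SD3-style joint
+# processors, the AuraFlow processors put the context tokens *first* in the
+# concatenation (`[encoder, sample]`), so the post-attention split takes
+# `hidden_states[:, encoder_len:]` as the sample stream and
+# `hidden_states[:, :encoder_len]` as the context stream.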
+
+
+class FusedAuraFlowAttnProcessor2_0:
+ """Attention processor used typically in processing Aura Flow with fused projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+ raise ImportError(
+ "FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ # Reshape.
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # Apply QK norm.
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Concatenate the projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ )
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # Attention.
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ if encoder_hidden_states is not None:
+ # `context` projections.
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
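+
+
+# Hedged note (from the code above): the Flux processors apply the rotary
+# embedding to the *entire* concatenated query/key (text tokens first, then image
+# tokens), whereas the CogVideoX processors further below rotate only the image
+# portion, `query[:, :, text_seq_length:]`.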
+
+
+class FusedFluxAttnProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+ # `sample` projections.
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+ split_size = encoder_qkv.shape[-1] // 3
+ (
+ encoder_hidden_states_query_proj,
+ encoder_hidden_states_key_proj,
+ encoder_hidden_states_value_proj,
+ ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ # attention
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query = apply_rotary_emb(query, image_rotary_emb)
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, : encoder_hidden_states.shape[1]],
+ hidden_states[:, encoder_hidden_states.shape[1] :],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
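+        # NOTE (editorial): the stock diffusers processor returns only
+        # (hidden_states, encoder_hidden_states); this vendored variant also returns
+        # the query tensor, presumably for the control branch in this repo.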
+ return query, hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FusedCogVideoXAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
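+        # NOTE (editorial): also returns `query`, mirroring the vendored
+        # CogVideoXAttnProcessor2_0 above.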
+ return query, hidden_states, encoder_hidden_states
+
+
+class XFormersAttnAddedKVProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to leave this as `None` and allow xFormers to
+            choose the best operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class XFormersAttnProcessor:
+ r"""
+ Processor for implementing memory efficient attention using xFormers.
+
+ Args:
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to leave this as `None` and allow xFormers to
+            choose the best operator.
+ """
+
+ def __init__(self, attention_op: Optional[Callable] = None):
+ self.attention_op = attention_op
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, key_tokens, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
+ if attention_mask is not None:
+ # expand our mask's singleton query_tokens dimension:
+ # [batch*heads, 1, key_tokens] ->
+ # [batch*heads, query_tokens, key_tokens]
+ # so that it can be added as a bias onto the attention scores that xformers computes:
+ # [batch*heads, query_tokens, key_tokens]
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
+ _, query_tokens, _ = hidden_states.shape
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
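+
+
+# Hedged usage sketch (illustrative): the xFormers processors are usually enabled
+# indirectly, e.g. via `model.enable_xformers_memory_efficient_attention()`, which
+# swaps an `XFormersAttnProcessor` (or the AddedKV variant) into every attention
+# block; `attention_op` can normally be left as `None`.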
+
+
+class AttnProcessorNPU:
+ r"""
+ Processor for implementing flash attention using torch_npu. torch_npu supports only the fp16 and bf16 data types;
+ if fp32 is used, F.scaled_dot_product_attention is used for the computation instead, and the acceleration effect
+ on the NPU is not significant.
+ """
+
+ def __init__(self):
+ if not is_torch_npu_available():
+ raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ if query.dtype in (torch.float16, torch.bfloat16):
+ hidden_states = torch_npu.npu_fusion_attention(
+ query,
+ key,
+ value,
+ attn.heads,
+ input_layout="BNSD",
+ pse=None,
+ atten_mask=attention_mask,
+ scale=1.0 / math.sqrt(query.shape[-1]),
+ pre_tockens=65536,
+ next_tockens=65536,
+ keep_prob=1.0,
+ sync=False,
+ inner_precise=0,
+ )[0]
+ else:
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
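+# A hedged selection sketch for the processor above: fall back to the default
+# SDPA processor when torch_npu is unavailable (only names already defined in
+# this file are used):
+#
+#     processor = AttnProcessorNPU() if is_torch_npu_available() else AttnProcessor2_0()
+#     attn.set_processor(processor)
+#
+# Inputs should be cast to fp16/bf16 first; the fp32 branch above silently
+# falls back to F.scaled_dot_product_attention with little NPU speed-up.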
+
+class AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
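+# The reshape contract shared by the SDPA processors, as a standalone sketch
+# (all sizes are illustrative):
+#
+#     import torch
+#     import torch.nn.functional as F
+#
+#     b, seq, heads, head_dim = 2, 16, 8, 64
+#     q = torch.randn(b, seq, heads * head_dim)
+#     q = q.view(b, -1, heads, head_dim).transpose(1, 2)  # (b, heads, seq, head_dim)
+#     out = F.scaled_dot_product_attention(q, q, q)       # self-attention, no mask
+#     out = out.transpose(1, 2).reshape(b, -1, heads * head_dim)  # (b, seq, inner_dim)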
+
+class StableAudioAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Stable Audio model. It applies rotary embedding on the query and key vectors, and allows MHA, GQA, or MQA.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def apply_partial_rotary_emb(
+ self,
+ x: torch.Tensor,
+ freqs_cis: Tuple[torch.Tensor],
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ rot_dim = freqs_cis[0].shape[-1]
+ x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]
+
+ x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)
+
+ out = torch.cat((x_rotated, x_unrotated), dim=-1)
+ return out
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ head_dim = query.shape[-1] // attn.heads
+ kv_heads = key.shape[-1] // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2)
+
+ if kv_heads != attn.heads:
+ # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
+ heads_per_kv_head = attn.heads // kv_heads
+ key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
+ value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if rotary_emb is not None:
+ query_dtype = query.dtype
+ key_dtype = key.dtype
+ query = query.to(torch.float32)
+ key = key.to(torch.float32)
+
+ rot_dim = rotary_emb[0].shape[-1]
+ query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:]
+ query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+ query = torch.cat((query_rotated, query_unrotated), dim=-1)
+
+ if not attn.is_cross_attention:
+ key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:]
+ key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+ key = torch.cat((key_rotated, key_unrotated), dim=-1)
+
+ query = query.to(query_dtype)
+ key = key.to(key_dtype)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
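+# How the GQA/MQA branch above expands key/value heads, in isolation (shapes
+# are illustrative):
+#
+#     import torch
+#
+#     heads, kv_heads = 8, 2                  # 4 query heads share each kv head
+#     key = torch.randn(1, kv_heads, 16, 64)  # (batch, kv_heads, seq, head_dim)
+#     key = torch.repeat_interleave(key, heads // kv_heads, dim=1)
+#     assert key.shape[1] == heads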
+
+class HunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class FusedHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0) with fused
+ projection layers. This is used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on
+ the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
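+# The fused projection used above, in isolation: a single matmul produces Q, K
+# and V, and an even split along the last dim recovers them (sizes illustrative):
+#
+#     import torch
+#
+#     inner_dim = 320
+#     to_qkv = torch.nn.Linear(inner_dim, 3 * inner_dim)
+#     qkv = to_qkv(torch.randn(2, 16, inner_dim))
+#     query, key, value = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)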
+
+class PAGHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ # 1. Original Path
+ batch_size, sequence_length, _ = (
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states_org
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # 2. Perturbed Path
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
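+# In PAG the incoming batch stacks the original and the perturbed branch; the
+# processor above splits them, runs full attention on the first half only, and
+# reduces the second half to its value/output projections. A shape-level sketch:
+#
+#     import torch
+#
+#     h = torch.randn(4, 16, 320)    # 2 original + 2 perturbed samples
+#     h_org, h_ptb = h.chunk(2)      # each (2, 16, 320)
+#     # ... full attention for h_org, value + output projection for h_ptb ...
+#     h = torch.cat([h_org, h_ptb])  # reassembled in the original order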
+
+class PAGCFGHunyuanAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the HunyuanDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ This variant of the processor employs [Perturbed Attention Guidance](https://arxiv.org/abs/2403.17377).
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGHunyuanAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ # 1. Original Path
+ batch_size, sequence_length, _ = (
+ hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states_org
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb)
+ if not attn.is_cross_attention:
+ key = apply_rotary_emb(key, image_rotary_emb)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # 2. Perturbed Path
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class LuminaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the LuminaNextDiT model. It applies a normalization layer and rotary embedding on the query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb: Optional[torch.Tensor] = None,
+ key_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ from .embeddings import apply_rotary_emb
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ # Apply Query-Key Norm if needed
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply RoPE if needed
+ if query_rotary_emb is not None:
+ query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
+ if key_rotary_emb is not None:
+ key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Apply proportional attention scaling only when key rotary embeddings are provided
+ if key_rotary_emb is None:
+ softmax_scale = None
+ else:
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
+
+ # perform Grouped-Query Attention (GQA)
+ n_rep = attn.heads // kv_heads
+ if n_rep >= 1:
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+ attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
+ )
+ hidden_states = hidden_states.transpose(1, 2).to(dtype)
+
+ return hidden_states
+
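+# The "proportional attention" scale above grows the softmax scale with the
+# ratio of the current to the training sequence length, e.g.:
+#
+#     import math
+#
+#     base_len, seq_len, attn_scale = 256, 1024, 0.125
+#     softmax_scale = math.sqrt(math.log(seq_len, base_len)) * attn_scale
+#     # log_256(1024) = 1.25, so the effective scale is sqrt(1.25) ~= 1.118x attn.scale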
+
+class FusedAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+ For cross-attention modules, key and value projection matrices are fused.
+
+ This API is currently 🧪 experimental in nature and can change in the future.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+ query = attn.to_q(hidden_states)
+
+ kv = attn.to_kv(encoder_hidden_states)
+ split_size = kv.shape[-1] // 2
+ key, value = torch.split(kv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
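+# FusedAttnProcessor2_0 is usually installed indirectly. Recent diffusers
+# releases expose fuse helpers on the models (hedged sketch; the exact API may
+# differ across versions):
+#
+#     unet.fuse_qkv_projections()    # concatenates to_q/to_k/to_v into to_qkv
+#     image = pipe(prompt).images[0]
+#     unet.unfuse_qkv_projections()  # restores the separate projections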
+
+class CustomDiffusionXFormersAttnProcessor(nn.Module):
+ r"""
+ Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ attention_op (`Callable`, *optional*, defaults to `None`):
+ The base
+ [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
+ as the attention operator. It is recommended to leave this as `None` and allow xFormers to choose the best operator.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = False,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ attention_op: Optional[Callable] = None,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.attention_op = attention_op
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
+ else:
+ query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ query = attn.head_to_batch_dim(query).contiguous()
+ key = attn.head_to_batch_dim(key).contiguous()
+ value = attn.head_to_batch_dim(value).contiguous()
+
+ hidden_states = xformers.ops.memory_efficient_attention(
+ query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+ )
+ hidden_states = hidden_states.to(query.dtype)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
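+# The `detach` mask used by the Custom Diffusion processors keeps gradients
+# flowing through every text token except the first one, whose key/value
+# contribution is detached. In isolation:
+#
+#     import torch
+#
+#     key = torch.randn(1, 77, 320, requires_grad=True)
+#     detach = torch.ones_like(key)
+#     detach[:, :1, :] = 0.0
+#     key = detach * key + (1 - detach) * key.detach()  # token 0 carries no grad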
+
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+ r"""
+ Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0’s memory-efficient scaled
+ dot-product attention.
+
+ Args:
+ train_kv (`bool`, defaults to `True`):
+ Whether to newly train the key and value matrices corresponding to the text features.
+ train_q_out (`bool`, defaults to `True`):
+ Whether to newly train query matrices corresponding to the latent image features.
+ hidden_size (`int`, *optional*, defaults to `None`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`, *optional*, defaults to `None`):
+ The number of channels in the `encoder_hidden_states`.
+ out_bias (`bool`, defaults to `True`):
+ Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability to use.
+ """
+
+ def __init__(
+ self,
+ train_kv: bool = True,
+ train_q_out: bool = True,
+ hidden_size: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ dropout: float = 0.0,
+ ):
+ super().__init__()
+ self.train_kv = train_kv
+ self.train_q_out = train_q_out
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ # `_custom_diffusion` id for easy serialization and loading.
+ if self.train_kv:
+ self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ if self.train_q_out:
+ self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+ self.to_out_custom_diffusion = nn.ModuleList([])
+ self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+ self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ if self.train_q_out:
+ query = self.to_q_custom_diffusion(hidden_states)
+ else:
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ crossattn = False
+ encoder_hidden_states = hidden_states
+ else:
+ crossattn = True
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if self.train_kv:
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+
+ else:
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if crossattn:
+ detach = torch.ones_like(key)
+ detach[:, :1, :] = detach[:, :1, :] * 0.0
+ key = detach * key + (1 - detach) * key.detach()
+ value = detach * value + (1 - detach) * value.detach()
+
+ inner_dim = hidden_states.shape[-1]
+
+ head_dim = inner_dim // attn.heads
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if self.train_q_out:
+ # linear proj
+ hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+ # dropout
+ hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+ else:
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+class SlicedAttnProcessor:
+ r"""
+ Processor for implementing sliced attention.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The size of each attention slice. Attention is computed in `attention_head_dim // slice_size` slices, so
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+
+ def __init__(self, slice_size: int):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
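+# Slicing trades peak memory for extra kernel launches: with batch*heads = 16
+# and slice_size = 4, the loop above runs 4 times and each slice materialises
+# only a (4, query_tokens, key_tokens) score tensor. At the pipeline level the
+# processor is typically enabled via (hedged; API may vary by version):
+#
+#     pipe.enable_attention_slicing(4)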
+
+class SlicedAttnAddedKVProcessor:
+ r"""
+ Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder.
+
+ Args:
+ slice_size (`int`, *optional*):
+ The size of each attention slice. Attention is computed in `attention_head_dim // slice_size` slices, so
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+
+ def __init__(self, slice_size):
+ self.slice_size = slice_size
+
+ def __call__(
+ self,
+ attn: "Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = attn.head_to_batch_dim(query)
+
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+ if not attn.only_cross_attention:
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ key = encoder_hidden_states_key_proj
+ value = encoder_hidden_states_value_proj
+
+ batch_size_attention, query_tokens, _ = query.shape
+ hidden_states = torch.zeros(
+ (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
+ )
+
+ for i in range((batch_size_attention - 1) // self.slice_size + 1):
+ start_idx = i * self.slice_size
+ end_idx = (i + 1) * self.slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+ attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
+
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SpatialNorm(nn.Module):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ """
+
+ def __init__(
+ self,
+ f_channels: int,
+ zq_channels: int,
+ ):
+ super().__init__()
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
+ self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+ self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
+ f_size = f.shape[-2:]
+ zq = F.interpolate(zq, size=f_size, mode="nearest")
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
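+# SpatialNorm in isolation: the quantized latent `zq` is resized to f's spatial
+# size and modulates the group-normalised features (shapes illustrative):
+#
+#     import torch
+#
+#     sn = SpatialNorm(f_channels=64, zq_channels=4)
+#     f = torch.randn(1, 64, 32, 32)
+#     zq = torch.randn(1, 4, 8, 8)  # interpolated up to 32x32 inside forward()
+#     out = sn(f, zq)               # same shape as f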
+
+class IPAdapterAttnProcessor(nn.Module):
+ r"""
+ Attention processor for Multiple IP-Adapters.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+ The context length of the image features.
+ scale (`float` or `List[float]`, defaults to 1.0):
+ The weight scale of the image prompt.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+ self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and len(scale) != mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
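+# A minimal sketch (editor's illustration, not part of the original module) of how the
+# per-adapter masks consumed by the processors below are typically prepared.
+# `IPAdapterMaskProcessor` is assumed to be the diffusers image processor of that name;
+# the target resolution is an arbitrary placeholder.
+def _example_prepare_ip_adapter_masks(mask_images, height=512, width=512):
+    from diffusers.image_processor import IPAdapterMaskProcessor
+
+    mask_processor = IPAdapterMaskProcessor()
+    # Returns a 4-D mask tensor at the requested resolution, ready to be reshaped to
+    # [1, num_images_for_ip_adapter, height, width] as validated by the processors here.
+    return mask_processor.preprocess(mask_images, height=height, width=width)
+
+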
+class IPAdapterAttnProcessor2_0(torch.nn.Module):
+ r"""
+ Attention processor for IP-Adapter for PyTorch 2.0.
+
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
+ The context length of the image features.
+ scale (`float` or `List[float]`, defaults to 1.0):
+            The weight scale of the image prompt.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+ self.num_tokens = num_tokens
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ temb: Optional[torch.Tensor] = None,
+ scale: float = 1.0,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ):
+ residual = hidden_states
+
+ # separate ip_hidden_states from encoder_hidden_states
+ if encoder_hidden_states is not None:
+ if isinstance(encoder_hidden_states, tuple):
+ encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+ else:
+ deprecation_message = (
+ "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
+ )
+ deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ [encoder_hidden_states[:, end_pos:, :]],
+ )
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if ip_adapter_masks is not None:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+ raise ValueError(
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
+ )
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+                    if isinstance(scale, list) and len(scale) != mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
+ else:
+ ip_adapter_masks = [None] * len(self.scale)
+
+ # for ip-adapter
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+ ):
+ skip = False
+ if isinstance(scale, list):
+ if all(s == 0 for s in scale):
+ skip = True
+ elif scale == 0:
+ skip = True
+ if not skip:
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale] * mask.shape[1]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
+
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + scale * current_ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
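+# Hypothetical usage sketch (not exercised anywhere in this repository): attaching the
+# PyTorch-2.0 IP-Adapter processor to a single `Attention` module. The sizes passed in
+# are illustrative placeholders, not values taken from this codebase.
+def _example_set_ip_adapter_processor(attn: "Attention", hidden_size: int, cross_attention_dim: int):
+    processor = IPAdapterAttnProcessor2_0(
+        hidden_size=hidden_size,
+        cross_attention_dim=cross_attention_dim,
+        num_tokens=(4,),  # context length of the image features
+        scale=0.5,        # weight of the image-prompt branch relative to the text branch
+    )
+    attn.set_processor(processor)
+    return attn
+
+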
+class PAGIdentitySelfAttnProcessor2_0:
+ r"""
+ Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ PAG reference: https://arxiv.org/abs/2403.17377
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+ hidden_states_ptb = attn.to_v(hidden_states_ptb)
+ hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
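+# Editor's illustration (hypothetical helper): `PAGIdentitySelfAttnProcessor2_0` assumes
+# the batch is the concatenation [original, perturbed] and splits it with `chunk(2)`.
+def _example_pag_batch_layout(latents: torch.Tensor) -> torch.Tensor:
+    # Duplicate the batch; the second half becomes the perturbed (identity-attention) path.
+    return torch.cat([latents, latents.clone()], dim=0)
+
+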
+class PAGCFGIdentitySelfAttnProcessor2_0:
+ r"""
+ Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ PAG reference: https://arxiv.org/abs/2403.17377
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ # chunk
+ hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
+ hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
+
+ # original path
+ batch_size, sequence_length, _ = hidden_states_org.shape
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states_org)
+ key = attn.to_k(hidden_states_org)
+ value = attn.to_v(hidden_states_org)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_org = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_org = hidden_states_org.to(query.dtype)
+
+ # linear proj
+ hidden_states_org = attn.to_out[0](hidden_states_org)
+ # dropout
+ hidden_states_org = attn.to_out[1](hidden_states_org)
+
+ if input_ndim == 4:
+ hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # perturbed path (identity attention)
+ batch_size, sequence_length, _ = hidden_states_ptb.shape
+
+ if attn.group_norm is not None:
+ hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
+
+        hidden_states_ptb = attn.to_v(hidden_states_ptb)
+        hidden_states_ptb = hidden_states_ptb.to(query.dtype)
+
+ # linear proj
+ hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
+ # dropout
+ hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
+
+ if input_ndim == 4:
+ hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ # cat
+ hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
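+# Likewise (hypothetical helper): the CFG variant above expects a batch laid out as
+# [unconditional, conditional, perturbed] and splits it with `chunk(3)`; the perturbed
+# path reuses the conditional latents.
+def _example_pag_cfg_batch_layout(uncond: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
+    return torch.cat([uncond, cond, cond.clone()], dim=0)
+
+
+# The LoRA processor classes below appear to be kept only as empty placeholders so that
+# older imports keep resolving now that LoRA handling lives in the PEFT backend.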
+class LoRAAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnProcessor2_0:
+ def __init__(self):
+ pass
+
+
+class LoRAXFormersAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnAddedKVProcessor:
+ def __init__(self):
+ pass
+
+
+class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self):
+ deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
+ super().__init__()
+
+
+ADDED_KV_ATTENTION_PROCESSORS = (
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+)
+
+CROSS_ATTENTION_PROCESSORS = (
+ AttnProcessor,
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+)
+
+AttentionProcessor = Union[
+ AttnProcessor,
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ XFormersAttnProcessor,
+ SlicedAttnProcessor,
+ AttnAddedKVProcessor,
+ SlicedAttnAddedKVProcessor,
+ AttnAddedKVProcessor2_0,
+ XFormersAttnAddedKVProcessor,
+ CustomDiffusionAttnProcessor,
+ CustomDiffusionXFormersAttnProcessor,
+ CustomDiffusionAttnProcessor2_0,
+ PAGCFGIdentitySelfAttnProcessor2_0,
+ PAGIdentitySelfAttnProcessor2_0,
+ PAGCFGHunyuanAttnProcessor2_0,
+ PAGHunyuanAttnProcessor2_0,
+]
diff --git a/control_cogvideox/cogvideox_transformer_3d.py b/control_cogvideox/cogvideox_transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ee12d077f7880508dc01d48fe0bc7f85a12022
--- /dev/null
+++ b/control_cogvideox/cogvideox_transformer_3d.py
@@ -0,0 +1,515 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+try:
+ from diffusers.models.controlnet import BaseOutput, zero_module
+except ImportError:
+    # `BaseOutput` is not exported from this module in some diffusers versions.
+    from diffusers.models.controlnet import zero_module
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
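+# A minimal shape sketch (editor's illustration; every dimension below is an assumption
+# chosen for cheapness, not a release configuration) of how a single block is driven by
+# the transformer's forward pass further down.
+def _example_run_block():
+    block = CogVideoXBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)
+    vid = torch.randn(2, 30, 64)   # (batch, video tokens, dim)
+    txt = torch.randn(2, 6, 64)    # (batch, text tokens, dim)
+    temb = torch.randn(2, 32)      # (batch, time_embed_dim)
+    # Returns the updated (video tokens, text tokens); both streams are gated by `temb`.
+    return block(hidden_states=vid, encoder_hidden_states=txt, temb=temb)
+
+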
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+            instead of 13 because CogVideoX processes 13 latent frames at once in its default and recommended settings,
+            and it cannot be changed to the correct value without breaking backwards compatibility. To create a
+            transformer with K latent frames, pass ((K - 1) * temporal_compression_ratio + 1); for example, K = 13
+            with the default temporal_compression_ratio of 4 gives (13 - 1) * 4 + 1 = 49.
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+        logger.debug("rotary embeddings: %s, learned positional embeddings: %s", use_rotary_positional_embeddings, use_learned_positional_embeddings)
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+        interval: int = 5,
+        residual_hidden_states: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
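+            # ControlNet-style residual injection: one residual tensor is shared by every
+            # `interval` consecutive transformer blocks.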
+ if residual_hidden_states is not None:
+                hidden_states = hidden_states + residual_hidden_states[i // interval]
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - Using `channels` would be fine for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, CogVideoX-5b-I2V also takes concatenated input image latents, so its number of input channels is twice the number of output channels
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
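+
+
+# End-to-end shape sketch (editor's illustration using a deliberately tiny, non-release
+# configuration; all values are assumptions chosen to keep the example cheap to run):
+def _example_tiny_forward():
+    model = CogVideoXTransformer3DModel(
+        num_attention_heads=2, attention_head_dim=8, in_channels=4, out_channels=4,
+        time_embed_dim=32, text_embed_dim=16, num_layers=1,
+        sample_width=8, sample_height=8, sample_frames=9, max_text_seq_length=6,
+    )
+    latents = torch.randn(1, 3, 4, 8, 8)  # (batch, latent frames, channels, height, width)
+    text = torch.randn(1, 6, 16)          # (batch, max_text_seq_length, text_embed_dim)
+    timestep = torch.tensor([10])
+    # Output latents keep the input spatial/temporal shape after unpatchification.
+    return model(latents, text, timestep, return_dict=False)[0].shape  # (1, 3, 4, 8, 8)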
diff --git a/control_cogvideox/cogvideox_transformer_3d_ipadapter.py b/control_cogvideox/cogvideox_transformer_3d_ipadapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..61f39e5a06895c00331e59f126d36ea809aad780
--- /dev/null
+++ b/control_cogvideox/cogvideox_transformer_3d_ipadapter.py
@@ -0,0 +1,566 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+import torch.nn.functional as F
+from einops import rearrange
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from .attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from diffusers.models.controlnet import BaseOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+ self.num_attention_heads = num_attention_heads
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
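+        # Decoupled IP-Adapter key/value projections, zero-initialized so that the image
+        # branch starts as a no-op and only takes effect once trained.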
+ self.ipadapter_k_proj = nn.Linear(dim, dim)
+ self.ipadapter_v_proj = nn.Linear(dim, dim)
+ nn.init.zeros_(self.ipadapter_k_proj.weight)
+ nn.init.zeros_(self.ipadapter_k_proj.bias)
+ nn.init.zeros_(self.ipadapter_v_proj.weight)
+ nn.init.zeros_(self.ipadapter_v_proj.bias)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ip_scale: float = 1.0,
+        image_embed: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ video_q, attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ if image_embed is not None:
+ #print("hidden_states", hidden_states.shape)
+ #print("video_q.shape", video_q.shape, "num_attention_heads", self.num_attention_heads)
+ ip_query = video_q[:, :, text_seq_length:] # b h l d
+ ip_key = self.ipadapter_k_proj(image_embed) # b l2 d
+ ip_value = self.ipadapter_v_proj(image_embed) # b l2 d
+ ip_key = rearrange(ip_key, " b l (h d) -> b h l d", h=self.num_attention_heads)
+ ip_value = rearrange(ip_value, " b l (h d) -> b h l d", h=self.num_attention_heads)
+
+ ip_attention = F.scaled_dot_product_attention(
+ ip_query, ip_key, ip_value, dropout_p=0, is_causal=False
+ )
+
+ ip_attention = rearrange(ip_attention, "b h l d -> b l (h d)")
+
+ hidden_states += ip_scale*ip_attention
+
+
+ return hidden_states, encoder_hidden_states
+
+
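+# Editor's illustration (hypothetical; dimensions are assumptions): the block above also
+# accepts `image_embed`, shaped (batch, num image tokens, dim). With the zero-initialized
+# k/v projections the image branch contributes nothing until it is trained.
+def _example_block_with_image_embed():
+    block = CogVideoXBlock(dim=64, num_attention_heads=4, attention_head_dim=16, time_embed_dim=32)
+    vid = torch.randn(1, 30, 64)         # (batch, video tokens, dim)
+    txt = torch.randn(1, 6, 64)          # (batch, text tokens, dim)
+    temb = torch.randn(1, 32)            # (batch, time_embed_dim)
+    image_embed = torch.randn(1, 4, 64)  # e.g. `num_extra_tokens` image tokens
+    return block(hidden_states=vid, encoder_hidden_states=txt, temb=temb,
+                 ip_scale=1.0, image_embed=image_embed)
+
+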
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+            instead of 13 because CogVideoX processes 13 latent frames at once in its default and recommended settings,
+            and it cannot be changed to the correct value without breaking backwards compatibility. To create a
+            transformer with K latent frames, pass ((K - 1) * temporal_compression_ratio + 1); for example, K = 13
+            with the default temporal_compression_ratio of 4 gives (13 - 1) * 4 + 1 = 49.
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ num_extra_tokens: int = 4,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ self.proj = nn.Linear(768, num_extra_tokens*inner_dim)
+ self.norm = nn.LayerNorm(inner_dim)
+
+ self.inner_dim = inner_dim
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+        interval: int = 5,
+        residual_hidden_states: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+        image_conditions: Optional[torch.Tensor] = None,
+ inference_with_multi_images: bool = False,
+ ip_scale: float = 1.0,
+ control_scale: float = 1.0
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+        # Project the image condition embedding(s) into `num_extra_tokens` transformer tokens.
+        image_conditions = self.proj(image_conditions)
+        if inference_with_multi_images:
+            # (batch, num_images, num_extra_tokens * dim) -> (batch, num_images * num_extra_tokens, dim)
+            image_conditions = rearrange(image_conditions, "b m (l d) -> b (m l) d", d=self.inner_dim)
+        else:
+            # (batch, num_extra_tokens * dim) -> (batch, num_extra_tokens, dim)
+            image_conditions = rearrange(image_conditions, "b (l d) -> b l d", d=self.inner_dim)
+        image_conditions = self.norm(image_conditions)
+        logger.debug("image_conditions shape: %s", image_conditions.shape)
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ #print("ip_scale", ip_scale)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ ip_scale,
+ image_conditions,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ ip_scale=ip_scale,
+ image_embed=image_conditions
+ )
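+            # ControlNet guidance: each residual is shared by `interval` consecutive base
+            # blocks (e.g. a 6-block ControlNet steering 30 base blocks with interval=5),
+            # scaled by `control_scale`.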
+ if residual_hidden_states is not None:
+                hidden_states += control_scale * residual_hidden_states[i // interval]
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
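+        # Shape flow: (B, F * H/p * W/p, p*p*C) -> (B, F, H/p, W/p, C, p, p) -> (B, F, C, H, W)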
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/control_cogvideox/cogvideox_transformer_3d_new_version.py b/control_cogvideox/cogvideox_transformer_3d_new_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dd759bcc1805d1aa7a824bc41b89770b39db8f0
--- /dev/null
+++ b/control_cogvideox/cogvideox_transformer_3d_new_version.py
@@ -0,0 +1,513 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from diffusers.models.controlnet import zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
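+
+    Example:
+        With the default `temporal_compression_ratio=4`, a model that processes `K = 13` latent frames
+        corresponds to `sample_frames = (13 - 1) * 4 + 1 = 49`, which is the default above.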
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 30,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
+
+ # 4. Output blocks
+ self.norm_out = AdaLayerNorm(
+ embedding_dim=time_embed_dim,
+ output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ chunk_dim=1,
+ )
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+        interval: int = 5,
+        residual_hidden_states: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
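+            # Each ControlNet residual is reused for `interval` consecutive blocks, so a
+            # 6-block ControlNet can steer all 30 base blocks when interval=5.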
+ if residual_hidden_states is not None:
+                hidden_states += residual_hidden_states[i // interval]
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (the number of input channels equals the number of output channels)
+        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (the number of input channels is twice the number of output channels)
+ p = self.config.patch_size
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/control_cogvideox/controlnet_cogvideox_transformer_3d.py b/control_cogvideox/controlnet_cogvideox_transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..77010cef9296910281851eade91027acd20e75f2
--- /dev/null
+++ b/control_cogvideox/controlnet_cogvideox_transformer_3d.py
@@ -0,0 +1,504 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+try:
+ from diffusers.models.controlnet import BaseOutput, zero_module
+except ImportError:
+ from diffusers.models.controlnet import zero_module
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class ControlCogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
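+
+    Example:
+        A minimal construction and forward sketch (the tiny config and tensor shapes below are
+        illustrative only, and assume the bundled `embeddings` module mirrors the upstream
+        diffusers patch embeds):
+
+        ```python
+        import torch
+
+        from control_cogvideox.controlnet_cogvideox_transformer_3d import (
+            ControlCogVideoXTransformer3DModel,
+        )
+
+        controlnet = ControlCogVideoXTransformer3DModel(
+            num_attention_heads=2,
+            attention_head_dim=16,
+            num_layers=2,
+            use_rotary_positional_embeddings=True,
+        )
+        latents = torch.randn(1, 2, 16, 32, 48)  # (B, F, C, H, W) noisy video latents
+        control = torch.randn(1, 2, 16, 32, 48)  # control latents on the same grid
+        text_emb = torch.randn(1, 226, 4096)  # T5 text embeddings
+        residuals = controlnet(latents, control, text_emb, timestep=torch.tensor([999]))
+        assert len(residuals) == 2  # one residual per ControlNet block
+        ```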
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ control_in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 6,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.control_patch_embed = CogVideoXPatchEmbed_No_Textual(
+ patch_size=patch_size,
+ in_channels=control_in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=False,
+ use_learned_positional_embeddings=False,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
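+        # ControlNet-style zero initialization: the per-block projections and the control
+        # patch embed start as zero maps, so the control branch is a no-op before training.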
+ self.controlnet_blocks = torch.nn.ModuleList([])
+ for i in range(num_layers):
+ control_block = nn.Linear(inner_dim, inner_dim)
+ control_block = zero_module(control_block)
+ self.controlnet_blocks.append(control_block)
+ self.control_patch_embed.proj = zero_module(self.control_patch_embed.proj)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+        This API is 🧪 experimental.
+        """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ control_condition: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+        dropout: bool = False,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ before_hidden_states = hidden_states.detach()
+ control_hidden_states = self.control_patch_embed(encoder_hidden_states, control_condition)
+ hidden_states = hidden_states + control_hidden_states
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ residual_samples = []
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
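+            # Project each block's output through its (initially zero) linear head and
+            # collect it; the base transformer adds these residuals at matching depths.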
+ hidden_states = self.controlnet_blocks[i](hidden_states)
+ residual_samples.append(hidden_states)
+
+ return residual_samples
diff --git a/control_cogvideox/controlnet_cogvideox_transformer_3d_condition.py b/control_cogvideox/controlnet_cogvideox_transformer_3d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..9737b0c27cbd81eb1a36867e7dbb041749113722
--- /dev/null
+++ b/control_cogvideox/controlnet_cogvideox_transformer_3d_condition.py
@@ -0,0 +1,521 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from einops import rearrange
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from diffusers.models.controlnet import BaseOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+        final_dropout (`bool`, defaults to `True`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class ControlCogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ control_in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 6,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.control_patch_embed = CogVideoXPatchEmbed_No_Textual(
+ patch_size=patch_size,
+ in_channels=control_in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=False,
+ use_learned_positional_embeddings=False,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+        # 3. Define spatio-temporal transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.controlnet_blocks = torch.nn.ModuleList([])
+ for i in range(num_layers):
+ control_block = nn.Linear(inner_dim, inner_dim)
+ control_block = zero_module(control_block)
+ self.controlnet_blocks.append(control_block)
+
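+        # Hint encoder: maps a 1-channel hint video (B, 1, F, H, W) to 16 channels while
+        # downsampling T by 4x and H, W by 8x to match the latent grid; the final conv is
+        # zero-initialized so the hint contributes nothing at initialization.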
+ self.input_hint_block = nn.Sequential(
+ nn.Conv3d(1, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv3d(16, 16, 3, padding=(1,1,1), stride=(1,2,2)),
+ nn.SiLU(),
+ zero_module(nn.Conv3d(16, 16, 3, padding=1))
+ )
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        This API is 🧪 experimental.
+        """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ control_condition: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ before_hidden_states = hidden_states.detach()  # note: computed but not used below
+ control_condition = self.input_hint_block(control_condition)
+ control_condition = rearrange(control_condition, "b c f w h -> b f c w h")
+ control_hidden_states = self.control_patch_embed(encoder_hidden_states, control_condition)
+ hidden_states = hidden_states + control_hidden_states
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ residual_samples = []
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = self.controlnet_blocks[i](hidden_states)
+ residual_samples.append(hidden_states)
+
+ return residual_samples
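+
+# Usage sketch (illustrative only; the names below are assumptions, not this
+# repo's API): a ControlNet-style consumer adds each per-block residual,
+# optionally scaled by a conditioning strength, to the base transformer's
+# hidden states:
+#
+#     residuals = controlnet(latents, control_condition, text_embeds, t)
+#     for i, block in enumerate(base_transformer.transformer_blocks):
+#         hidden_states, encoder_hidden_states = block(
+#             hidden_states, encoder_hidden_states, emb, image_rotary_emb
+#         )
+#         hidden_states = hidden_states + conditioning_scale * residuals[i]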
diff --git a/control_cogvideox/controlnet_cogvideox_transformer_3d_multi_controlnets.py b/control_cogvideox/controlnet_cogvideox_transformer_3d_multi_controlnets.py
new file mode 100644
index 0000000000000000000000000000000000000000..949a568f5893f3dc7db213e206666dd1018bf33e
--- /dev/null
+++ b/control_cogvideox/controlnet_cogvideox_transformer_3d_multi_controlnets.py
@@ -0,0 +1,501 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.attention import Attention, FeedForward
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, CogVideoXPatchEmbed_No_Textual, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from diffusers.models.controlnet import BaseOutput, zero_module
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(nn.Module):
+ r"""
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float = 0.0,
+ activation_fn: str = "gelu-approximate",
+ attention_bias: bool = False,
+ qk_norm: bool = True,
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ final_dropout: bool = True,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm="layer_norm" if qk_norm else None,
+ eps=1e-6,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
+ )
+
+ # 2. Feed Forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm & modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed-forward
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class ControlCogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 (CogVideoX processes 13 latent frames at once in its default and recommended settings) and
+ cannot now be corrected without breaking backwards compatibility. To create a transformer with K latent
+ frames, the correct value to pass here is ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 30,
+ attention_head_dim: int = 64,
+ in_channels: int = 16,
+ control_in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ time_embed_dim: int = 512,
+ text_embed_dim: int = 4096,
+ num_layers: int = 6,
+ dropout: float = 0.0,
+ attention_bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ patch_size: int = 2,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ activation_fn: str = "gelu-approximate",
+ timestep_activation_fn: str = "silu",
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
+ use_learned_positional_embeddings: bool = False,
+ ):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+ "issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.control_patch_embed = CogVideoXPatchEmbed_No_Textual(
+ patch_size=patch_size,
+ in_channels=control_in_channels,
+ embed_dim=inner_dim,
+ text_embed_dim=text_embed_dim,
+ bias=True,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=False,
+ use_learned_positional_embeddings=False,
+ )
+ self.embedding_dropout = nn.Dropout(dropout)
+
+ # 2. Time embeddings
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+
+ # 3. Define spatio-temporal transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ CogVideoXBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.controlnet_blocks = torch.nn.ModuleList([])
+ for i in range(num_layers):
+ control_block = nn.Linear(inner_dim, inner_dim)
+ control_block = zero_module(control_block)
+ self.controlnet_blocks.append(control_block)
+ self.control_patch_embed.proj = zero_module(self.control_patch_embed.proj)
+
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ control_condition: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: Union[int, float, torch.LongTensor],
+ timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ):
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ before_hidden_states = hidden_states.detach()  # note: computed but not used below
+ control_hidden_states = self.control_patch_embed(encoder_hidden_states, control_condition)
+ hidden_states = hidden_states + control_hidden_states
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ residual_samples = []
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ emb,
+ image_rotary_emb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = self.controlnet_blocks[i](hidden_states)
+ residual_samples.append(hidden_states)
+
+ return residual_samples
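+
+# Usage sketch for the multi-controlnet variant (illustrative, assumed API):
+# residuals from several controlnets would typically be summed per block before
+# being injected into the base transformer:
+#
+#     residual_lists = [cn(latents, cond, text_embeds, t) for cn, cond in zip(controlnets, conditions)]
+#     residuals = [sum(per_block) for per_block in zip(*residual_lists)]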
diff --git a/control_cogvideox/embeddings.py b/control_cogvideox/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd666a55a489d85c6f019a452423044a9ecb8e05
--- /dev/null
+++ b/control_cogvideox/embeddings.py
@@ -0,0 +1,1843 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+from diffusers.models.activations import FP32SiLU, get_activation
+from diffusers.models.attention_processor import Attention
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ Args:
+ timesteps (torch.Tensor):
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ embedding_dim (int):
+ the dimension of the output.
+ flip_sin_to_cos (bool):
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+ downscale_freq_shift (float):
+ Controls the delta between frequencies between dimensions
+ scale (float):
+ Scaling factor applied to the embeddings.
+ max_period (int):
+ Controls the maximum frequency of the embeddings
+ Returns:
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+ )
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
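+# Minimal shape sketch for the helper above (illustrative; not used by the model code):
+#
+#     ts = torch.tensor([0, 10, 999])                                  # [N]
+#     emb = get_timestep_embedding(ts, embedding_dim=8)                # [N, 8] = [sin | cos]
+#     flipped = get_timestep_embedding(ts, 8, flip_sin_to_cos=True)    # [N, 8] = [cos | sin]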
+
+def get_3d_sincos_pos_embed(
+ embed_dim: int,
+ spatial_size: Union[int, Tuple[int, int]],
+ temporal_size: int,
+ spatial_interpolation_scale: float = 1.0,
+ temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+ r"""
+ Args:
+ embed_dim (`int`):
+ spatial_size (`int` or `Tuple[int, int]`):
+ temporal_size (`int`):
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ """
+ if embed_dim % 4 != 0:
+ raise ValueError("`embed_dim` must be divisible by 4")
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+
+ # 1. Spatial
+ grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+ grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+ # 2. Temporal
+ grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+ # 3. Concat
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
+
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
+
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
+ return pos_embed
+
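+# Shape sketch (illustrative values): note that `spatial_size` is (width, height),
+# matching how CogVideoXPatchEmbed calls this helper.
+#
+#     pe = get_3d_sincos_pos_embed(64, spatial_size=(45, 30), temporal_size=13)
+#     pe.shape  # (13, 30 * 45, 64): the first 16 dims are temporal, the last 48 spatial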
+
+def get_2d_sincos_pos_embed(
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+ """
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_size = (grid_size, grid_size)
+
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be divisible by 2")
+
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
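+# Sketch (illustrative): get_1d_sincos_pos_embed_from_grid(16, np.arange(10)) has shape (10, 16).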
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for SD3 cropping."""
+
+ def __init__(
+ self,
+ height=224,
+ width=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dim=768,
+ layer_norm=False,
+ flatten=True,
+ bias=True,
+ interpolation_scale=1,
+ pos_embed_type="sincos",
+ pos_embed_max_size=None, # For SD3 cropping
+ ):
+ super().__init__()
+
+ num_patches = (height // patch_size) * (width // patch_size)
+ self.flatten = flatten
+ self.layer_norm = layer_norm
+ self.pos_embed_max_size = pos_embed_max_size
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ if layer_norm:
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ self.norm = None
+
+ self.patch_size = patch_size
+ self.height, self.width = height // patch_size, width // patch_size
+ self.base_size = height // patch_size
+ self.interpolation_scale = interpolation_scale
+
+ # Calculate positional embeddings based on max size or default
+ if pos_embed_max_size:
+ grid_size = pos_embed_max_size
+ else:
+ grid_size = int(num_patches**0.5)
+
+ if pos_embed_type is None:
+ self.pos_embed = None
+ elif pos_embed_type == "sincos":
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+ )
+ persistent = True if pos_embed_max_size else False
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+ else:
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
+
+ def cropped_pos_embed(self, height, width):
+ """Crops positional embeddings for SD3 compatibility."""
+ if self.pos_embed_max_size is None:
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+ height = height // self.patch_size
+ width = width // self.patch_size
+ if height > self.pos_embed_max_size:
+ raise ValueError(
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+ if width > self.pos_embed_max_size:
+ raise ValueError(
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+ )
+
+ top = (self.pos_embed_max_size - height) // 2
+ left = (self.pos_embed_max_size - width) // 2
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+ return spatial_pos_embed
+
+ def forward(self, latent):
+ if self.pos_embed_max_size is not None:
+ height, width = latent.shape[-2:]
+ else:
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+
+ latent = self.proj(latent)
+ if self.flatten:
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ if self.layer_norm:
+ latent = self.norm(latent)
+ if self.pos_embed is None:
+ return latent.to(latent.dtype)
+ # Interpolate or crop positional embeddings as needed
+ if self.pos_embed_max_size:
+ pos_embed = self.cropped_pos_embed(height, width)
+ else:
+ if self.height != height or self.width != width:
+ pos_embed = get_2d_sincos_pos_embed(
+ embed_dim=self.pos_embed.shape[-1],
+ grid_size=(height, width),
+ base_size=self.base_size,
+ interpolation_scale=self.interpolation_scale,
+ )
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+ else:
+ pos_embed = self.pos_embed
+
+ return (latent + pos_embed).to(latent.dtype)
+
+
+class LuminaPatchEmbed(nn.Module):
+ """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+ def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+ super().__init__()
+ self.patch_size = patch_size
+ self.proj = nn.Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=embed_dim,
+ bias=bias,
+ )
+
+ def forward(self, x, freqs_cis):
+ """
+ Patchifies and embeds the input tensor(s).
+
+ Args:
+ x (`torch.Tensor`): The input tensor to be patchified and embedded, of shape (batch, channels, height, width).
+ freqs_cis (`torch.Tensor`): Precomputed rotary frequency tensor, cropped to the token grid and returned with the embeddings.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+ and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+ frequency tensor(s).
+ """
+ freqs_cis = freqs_cis.to(x[0].device)
+ patch_height = patch_width = self.patch_size
+ batch_size, channel, height, width = x.size()
+ height_tokens, width_tokens = height // patch_height, width // patch_width
+
+ x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+ 0, 2, 4, 1, 3, 5
+ )
+ x = x.flatten(3)
+ x = self.proj(x)
+ x = x.flatten(1, 2)
+
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+ return (
+ x,
+ mask,
+ [(height, width)] * batch_size,
+ freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+ )
+
+
+class CogVideoXPatchEmbed_No_Textual(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ self.max_text_seq_length = max_text_seq_length
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ #self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ )
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+ joint_pos_embedding = torch.zeros(
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+ return joint_pos_embedding
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ # the control branch carries no textual information: replace the text tokens
+ # with zeros of the same sequence length so they stay aligned with the base embedding
+ b, f, _ = text_embeds.shape
+ text_embeds = torch.zeros([b, f, self.embed_dim]).to(dtype=self.proj.weight.dtype, device=self.proj.weight.device)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+ raise ValueError(
+ "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+ "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ if (
+ self.sample_height != height
+ or self.sample_width != width
+ or self.sample_frames != pre_time_compression_frames
+ ):
+ pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+ pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ else:
+ pos_embedding = self.pos_embedding
+
+ #embeds = embeds + pos_embedding  # note: positional embeddings are not added to the control tokens
+
+ return embeds
+
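+# Shape sketch for the control branch above (illustrative): the output is
+# [batch, text_seq_len + num_frames * (H / p) * (W / p), embed_dim], where the
+# first text_seq_len positions are zeros so the control tokens stay aligned
+# with the text tokens of the base patch embedding.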
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ embed_dim: int = 1920,
+ text_embed_dim: int = 4096,
+ bias: bool = True,
+ sample_width: int = 90,
+ sample_height: int = 60,
+ sample_frames: int = 49,
+ temporal_compression_ratio: int = 4,
+ max_text_seq_length: int = 226,
+ spatial_interpolation_scale: float = 1.875,
+ temporal_interpolation_scale: float = 1.0,
+ use_positional_embeddings: bool = True,
+ use_learned_positional_embeddings: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ self.max_text_seq_length = max_text_seq_length
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ print("self.use_positional_embeddings", self.use_positional_embeddings)
+ print("self.use_learned_positional_embeddings", self.use_learned_positional_embeddings)
+ self.proj = nn.Conv2d(
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+ )
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.embed_dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ )
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+ joint_pos_embedding = torch.zeros(
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
+
+ return joint_pos_embedding
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ r"""
+ Args:
+ text_embeds (`torch.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`torch.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = torch.cat(
+ [text_embeds, image_embeds], dim=1
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
+
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ #if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+ # raise ValueError(
+ # "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+ # "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+ # )
+
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+ if (
+ self.sample_height != height
+ or self.sample_width != width
+ or self.sample_frames != pre_time_compression_frames
+ ):
+ pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+ pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+ #print("new size")
+ else:
+ pos_embedding = self.pos_embedding
+
+ #print("embeds = embeds + pos_embedding")
+ embeds = embeds + pos_embedding
+
+ return embeds
+
+
+def get_3d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+
+ Returns:
+ `Tuple[torch.Tensor, torch.Tensor]`: cos and sin embeddings, each with shape
+ `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
+ """
+ if use_real is not True:
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
+ start, stop = crops_coords
+ grid_size_h, grid_size_w = grid_size
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+
+ # Temporal frequencies
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+ # Spatial frequencies for height and width
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
+ freqs_t = freqs_t[:, None, None, :].expand(
+ -1, grid_size_h, grid_size_w, -1
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
+ freqs_h = freqs_h[None, :, None, :].expand(
+ temporal_size, -1, grid_size_w, -1
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
+ freqs_w = freqs_w[None, None, :, :].expand(
+ temporal_size, grid_size_h, -1, -1
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
+
+ freqs = torch.cat(
+ [freqs_t, freqs_h, freqs_w], dim=-1
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
+ freqs = freqs.view(
+ temporal_size * grid_size_h * grid_size_w, -1
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
+ return freqs
+
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
+ return cos, sin
+
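+# Shape sketch (illustrative values) for a 30x45 latent token grid and 13 frames:
+#
+#     cos, sin = get_3d_rotary_pos_embed(64, ((0, 0), (30, 45)), (30, 45), 13)
+#     cos.shape  # torch.Size([13 * 30 * 45, 64]); dims split as 16 (t) + 24 (h) + 24 (w)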
+
+def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+ """
+ RoPE for image tokens with 2d structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size
+ crops_coords (`Tuple[int]`)
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the positional embedding.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor`: positional embedding with shape `(grid_size[0] * grid_size[1], embed_dim/2)` as complex
+ values, or a (cos, sin) tuple of tensors with shape `(grid_size[0] * grid_size[1], embed_dim)` if `use_real=True`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0) # [2, W, H]
+
+ grid = grid.reshape([2, 1, *grid.shape[1:]])
+ pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+ return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+ assert embed_dim % 4 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+ ) # (H*W, D/2) if use_real else (H*W, D/4)
+
+ if use_real:
+ cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D)
+ sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D)
+ return cos, sin
+ else:
+ emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
+ return emb
+
+
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+ assert embed_dim % 4 == 0
+
+ emb_h = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (H, D/4)
+ emb_w = get_1d_rotary_pos_embed(
+ embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+ ) # (W, D/4)
+ emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1) # (H, W, D/4, 1)
+ emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1) # (H, W, D/4, 1)
+
+ emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2) # (H, W, D/2)
+ return emb
+
+
+def get_1d_rotary_pos_embed(
+ dim: int,
+ pos: Union[np.ndarray, int],
+ theta: float = 10000.0,
+ use_real=False,
+ linear_factor=1.0,
+ ntk_factor=1.0,
+ repeat_interleave_real=True,
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
+):
+ """
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
+ data type.
+
+ Args:
+ dim (`int`): Dimension of the frequency tensor.
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
+ theta (`float`, *optional*, defaults to 10000.0):
+ Scaling factor for frequency computation. Defaults to 10000.0.
+ use_real (`bool`, *optional*):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+ linear_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the context extrapolation. Defaults to 1.0.
+ ntk_factor (`float`, *optional*, defaults to 1.0):
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+ Otherwise, they are concatenated with themselves.
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
+ the dtype of the frequency tensor.
+ Returns:
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
+ """
+ assert dim % 2 == 0
+
+ if isinstance(pos, int):
+ pos = torch.arange(pos)
+ if isinstance(pos, np.ndarray):
+ pos = torch.from_numpy(pos) # type: ignore # [S]
+
+ theta = theta * ntk_factor
+ freqs = (
+ 1.0
+ / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
+ / linear_factor
+ ) # [D/2]
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
+ if use_real and repeat_interleave_real:
+ # flux, hunyuan-dit, cogvideox
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ elif use_real:
+ # stable audio
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
+ return freqs_cos, freqs_sin
+ else:
+ # lumina
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
+ return freqs_cis
+
+
+def apply_rotary_emb(
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+ use_real: bool = True,
+ use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+ Query or key tensor to apply rotary embeddings to. Expected shape: [B, H, S, D].
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials, as a (cos, sin) tuple of shapes ([S, D], [S, D]).
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+ """
+ if use_real:
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ if use_real_unbind_dim == -1:
+ # Used for flux, cogvideox, hunyuan-dit
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ elif use_real_unbind_dim == -2:
+ # Used for Stable Audio
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+ else:
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+ else:
+ # used for lumina
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+ return x_out.type_as(x)
+
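+# Round-trip sketch for the two RoPE helpers above (illustrative shapes):
+#
+#     cos, sin = get_1d_rotary_pos_embed(64, 16, use_real=True)    # each [16, 64]
+#     q = torch.randn(1, 2, 16, 64)                                # [B, H, S, D]
+#     q_rot = apply_rotary_emb(q, (cos, sin))                      # same shape as q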
+
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ time_embed_dim: int,
+ act_fn: str = "silu",
+ out_dim: int = None,
+ post_act_fn: Optional[str] = None,
+ cond_proj_dim=None,
+ sample_proj_bias=True,
+ ):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
+
+ if cond_proj_dim is not None:
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+ else:
+ self.cond_proj = None
+
+ self.act = get_activation(act_fn)
+
+ if out_dim is not None:
+ time_embed_dim_out = out_dim
+ else:
+ time_embed_dim_out = time_embed_dim
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+
+ if post_act_fn is None:
+ self.post_act = None
+ else:
+ self.post_act = get_activation(post_act_fn)
+
+ def forward(self, sample, condition=None):
+ if condition is not None:
+ sample = sample + self.cond_proj(condition)
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+
+ if self.post_act is not None:
+ sample = self.post_act(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ )
+ return t_emb
+
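+# Sketch of the timestep pathway used in the transformers above (illustrative sizes):
+#
+#     time_proj = Timesteps(num_channels=64, flip_sin_to_cos=True, downscale_freq_shift=0)
+#     time_embedding = TimestepEmbedding(in_channels=64, time_embed_dim=512)
+#     emb = time_embedding(time_proj(torch.tensor([10, 999])))     # [2, 512]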
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(
+ self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
+ ):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.log = log
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ if set_W_to_weight:
+ # to delete later
+ del self.weight
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+ self.weight = self.W
+ del self.W
+
+ def forward(self, x):
+ if self.log:
+ x = torch.log(x)
+
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+
+ if self.flip_sin_to_cos:
+ out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
+ else:
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """Apply positional information to a sequence of embeddings.
+
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+ them
+
+ Args:
+ embed_dim: (int): Dimension of the positional embedding.
+ max_seq_length: Maximum sequence length to apply positional embeddings
+
+ """
+
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
+ super().__init__()
+ position = torch.arange(max_seq_length).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
+ pe = torch.zeros(1, max_seq_length, embed_dim)
+ pe[0, :, 0::2] = torch.sin(position * div_term)
+ pe[0, :, 1::2] = torch.cos(position * div_term)
+ self.register_buffer("pe", pe)
+
+ def forward(self, x):
+ _, seq_length, _ = x.shape
+ x = x + self.pe[:, :seq_length]
+ return x
+
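+# Usage sketch (illustrative shapes); sequences longer than max_seq_length are
+# not supported, since only the first seq_length buffer positions are added:
+#
+#   pe = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)
+#   out = pe(torch.randn(2, 16, 64))   # (2, 16, 64)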
+
+class ImagePositionalEmbeddings(nn.Module):
+ """
+ Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
+ height and width of the latent space.
+
+ For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+
+ For VQ-diffusion:
+
+ Output vector embeddings are used as input for the transformer.
+
+ Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
+
+ Args:
+ num_embed (`int`):
+ Number of embeddings for the latent pixels embeddings.
+ height (`int`):
+ Height of the latent image i.e. the number of height embeddings.
+ width (`int`):
+ Width of the latent image i.e. the number of width embeddings.
+ embed_dim (`int`):
+ Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
+ """
+
+ def __init__(
+ self,
+ num_embed: int,
+ height: int,
+ width: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+
+ self.height = height
+ self.width = width
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+
+ self.emb = nn.Embedding(self.num_embed, embed_dim)
+ self.height_emb = nn.Embedding(self.height, embed_dim)
+ self.width_emb = nn.Embedding(self.width, embed_dim)
+
+ def forward(self, index):
+ emb = self.emb(index)
+
+ height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
+
+ # 1 x H x D -> 1 x H x 1 x D
+ height_emb = height_emb.unsqueeze(2)
+
+ width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
+
+ # 1 x W x D -> 1 x 1 x W x D
+ width_emb = width_emb.unsqueeze(1)
+
+ pos_emb = height_emb + width_emb
+
+ # 1 x H x W x D -> 1 x L x D
+ pos_emb = pos_emb.view(1, self.height * self.width, -1)
+
+ emb = emb + pos_emb[:, : emb.shape[1], :]
+
+ return emb
+
+
+class LabelEmbedding(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+
+ Args:
+ num_classes (`int`): The number of classes.
+ hidden_size (`int`): The size of the vector embeddings.
+ dropout_prob (`float`): The probability of dropping a label.
+ """
+
+ def __init__(self, num_classes, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+ self.num_classes = num_classes
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, labels, force_drop_ids=None):
+ """
+ Drops labels to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+ else:
+ drop_ids = torch.tensor(force_drop_ids == 1)
+ labels = torch.where(drop_ids, self.num_classes, labels)
+ return labels
+
+ def forward(self, labels: torch.LongTensor, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (self.training and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ embeddings = self.embedding_table(labels)
+ return embeddings
+
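+# Classifier-free guidance sketch (illustrative): when dropout_prob > 0 the
+# table holds one extra row at index num_classes that acts as the learned
+# "null" label; token_drop swaps real labels for it during training, and the
+# unconditional branch can request it explicitly at sampling time:
+#
+#   le = LabelEmbedding(num_classes=1000, hidden_size=768, dropout_prob=0.1)
+#   uncond = le(torch.tensor([1000]))   # embedding of the null label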
+
+class TextImageProjection(nn.Module):
+ def __init__(
+ self,
+ text_embed_dim: int = 1024,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 10,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ batch_size = text_embeds.shape[0]
+
+ # image
+ image_text_embeds = self.image_embeds(image_embeds)
+ image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+
+ # text
+ text_embeds = self.text_proj(text_embeds)
+
+ return torch.cat([image_text_embeds, text_embeds], dim=1)
+
+
+class ImageProjection(nn.Module):
+ def __init__(
+ self,
+ image_embed_dim: int = 768,
+ cross_attention_dim: int = 768,
+ num_image_text_embeds: int = 32,
+ ):
+ super().__init__()
+
+ self.num_image_text_embeds = num_image_text_embeds
+ self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ batch_size = image_embeds.shape[0]
+
+ # image
+ image_embeds = self.image_embeds(image_embeds)
+ image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
+ image_embeds = self.norm(image_embeds)
+ return image_embeds
+
+
+class IPAdapterFullImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ return self.norm(self.ff(image_embeds))
+
+
+class IPAdapterFaceIDImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.cross_attention_dim = cross_attention_dim
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ x = self.ff(image_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ return self.norm(x)
+
+
+class CombinedTimestepLabelEmbeddings(nn.Module):
+ def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
+
+ def forward(self, timestep, class_labels, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ class_labels = self.class_embedder(class_labels) # (N, D)
+
+ conditioning = timesteps_emb + class_labels # (N, D)
+
+ return conditioning
+
+
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ pooled_projections = self.text_embedder(pooled_projection)
+
+ conditioning = timesteps_emb + pooled_projections
+
+ return conditioning
+
+
+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, pooled_projection_dim):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+ def forward(self, timestep, guidance, pooled_projection):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
+
+ time_guidance_emb = timesteps_emb + guidance_emb
+
+ pooled_projections = self.text_embedder(pooled_projection)
+ conditioning = time_guidance_emb + pooled_projections
+
+ return conditioning
+
+
+class HunyuanDiTAttentionPool(nn.Module):
+ # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
+
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.permute(1, 0, 2) # NLC -> LNC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x[:1],
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+ return x.squeeze(0)
+
+
+class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim,
+ pooled_projection_dim=1024,
+ seq_len=256,
+ cross_attention_dim=2048,
+ use_style_cond_and_image_meta_size=True,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+ self.pooler = HunyuanDiTAttentionPool(
+ seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
+ )
+
+ # Here we use a default learned embedder layer for future extension.
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+ if use_style_cond_and_image_meta_size:
+ self.style_embedder = nn.Embedding(1, embedding_dim)
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+ else:
+ extra_in_dim = pooled_projection_dim
+
+ self.extra_embedder = PixArtAlphaTextProjection(
+ in_features=extra_in_dim,
+ hidden_size=embedding_dim * 4,
+ out_features=embedding_dim,
+ act_fn="silu_fp32",
+ )
+
+ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
+
+ # extra condition1: text
+ pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
+
+ if self.use_style_cond_and_image_meta_size:
+ # extra condition2: image meta size embedding
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
+
+ # extra condition3: style embedding
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
+
+ # Concatenate all extra vectors
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+ else:
+ extra_cond = torch.cat([pooled_projections], dim=1)
+
+ conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
+
+ return conditioning
+
+
+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+ super().__init__()
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+ )
+
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+ self.caption_embedder = nn.Sequential(
+ nn.LayerNorm(cross_attention_dim),
+ nn.Linear(
+ cross_attention_dim,
+ hidden_size,
+ bias=True,
+ ),
+ )
+
+ def forward(self, timestep, caption_feat, caption_mask):
+ # timestep embedding:
+ time_freq = self.time_proj(timestep)
+ time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+ # caption condition embedding:
+ caption_mask_float = caption_mask.float().unsqueeze(-1)
+ caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+ caption_feats_pool = caption_feats_pool.to(caption_feat)
+ caption_embed = self.caption_embedder(caption_feats_pool)
+
+ conditioning = time_embed + caption_embed
+
+ return conditioning
+
+
+class TextTimeEmbedding(nn.Module):
+ def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(encoder_dim)
+ self.pool = AttentionPooling(num_heads, encoder_dim)
+ self.proj = nn.Linear(encoder_dim, time_embed_dim)
+ self.norm2 = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, hidden_states):
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.pool(hidden_states)
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class TextImageTimeEmbedding(nn.Module):
+ def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
+ self.text_norm = nn.LayerNorm(time_embed_dim)
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+ # text
+ time_text_embeds = self.text_proj(text_embeds)
+ time_text_embeds = self.text_norm(time_text_embeds)
+
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+
+ return time_image_embeds + time_text_embeds
+
+
+class ImageTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+
+ def forward(self, image_embeds: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ return time_image_embeds
+
+
+class ImageHintTimeEmbedding(nn.Module):
+ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
+ super().__init__()
+ self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
+ self.image_norm = nn.LayerNorm(time_embed_dim)
+ self.input_hint_block = nn.Sequential(
+ nn.Conv2d(3, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 16, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(16, 32, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(32, 96, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(96, 96, 3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(96, 256, 3, padding=1, stride=2),
+ nn.SiLU(),
+ nn.Conv2d(256, 4, 3, padding=1),
+ )
+
+ def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
+ # image
+ time_image_embeds = self.image_proj(image_embeds)
+ time_image_embeds = self.image_norm(time_image_embeds)
+ hint = self.input_hint_block(hint)
+ return time_image_embeds, hint
+
+
+class AttentionPooling(nn.Module):
+ # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54
+
+ def __init__(self, num_heads, embed_dim, dtype=None):
+ super().__init__()
+ self.dtype = dtype
+ self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype)
+ self.num_heads = num_heads
+ self.dim_per_head = embed_dim // self.num_heads
+
+ def forward(self, x):
+ bs, length, width = x.size()
+
+ def shape(x):
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, -1, self.num_heads, self.dim_per_head)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
+ x = x.reshape(bs * self.num_heads, -1, self.dim_per_head)
+ # (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length)
+ x = x.transpose(1, 2)
+ return x
+
+ class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype)
+ x = torch.cat([class_token, x], dim=1) # (bs, length+1, width)
+
+ # (bs*n_heads, class_token_length, dim_per_head)
+ q = shape(self.q_proj(class_token))
+ # (bs*n_heads, length+class_token_length, dim_per_head)
+ k = shape(self.k_proj(x))
+ v = shape(self.v_proj(x))
+
+ # (bs*n_heads, class_token_length, length+class_token_length):
+ scale = 1 / math.sqrt(math.sqrt(self.dim_per_head))
+ weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+
+ # (bs*n_heads, dim_per_head, class_token_length)
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+ # (bs, 1, width)
+ a = a.reshape(bs, -1, 1).transpose(1, 2)
+
+ return a[:, 0, :] # cls_token
+
+
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+ """
+ Args:
+ embed_dim: int
+ box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+ Returns:
+ [B x N x embed_dim] tensor of positional embeddings
+ """
+
+ batch_size, num_boxes = box.shape[:2]
+
+ emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+ emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+ emb = emb * box.unsqueeze(-1)
+
+ emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+ emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+ return emb
+
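+# Dimension check (worked example): each of the 4 box coordinates contributes
+# embed_dim sin values and embed_dim cos values, so with embed_dim=8 the output
+# is [B, N, 8 * 2 * 4] = [B, N, 64], matching position_dim in the class below.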
+
+class GLIGENTextBoundingboxProjection(nn.Module):
+ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
+ super().__init__()
+ self.positive_len = positive_len
+ self.out_dim = out_dim
+
+ self.fourier_embedder_dim = fourier_freqs
+ self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy
+
+ if isinstance(out_dim, tuple):
+ out_dim = out_dim[0]
+
+ if feature_type == "text-only":
+ self.linears = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ elif feature_type == "text-image":
+ self.linears_text = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.linears_image = nn.Sequential(
+ nn.Linear(self.positive_len + self.position_dim, 512),
+ nn.SiLU(),
+ nn.Linear(512, 512),
+ nn.SiLU(),
+ nn.Linear(512, out_dim),
+ )
+ self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+ self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+
+ self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+ def forward(
+ self,
+ boxes,
+ masks,
+ positive_embeddings=None,
+ phrases_masks=None,
+ image_masks=None,
+ phrases_embeddings=None,
+ image_embeddings=None,
+ ):
+ masks = masks.unsqueeze(-1)
+
+ # embed the box positions (the boxes may include padding used as a placeholder)
+ xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C
+
+ # learnable null embedding
+ xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+ # positionet with text only information
+ if positive_embeddings is not None:
+ # learnable null embedding
+ positive_null = self.null_positive_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+
+ objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+
+ # positionet with text and image information
+ else:
+ phrases_masks = phrases_masks.unsqueeze(-1)
+ image_masks = image_masks.unsqueeze(-1)
+
+ # learnable null embedding
+ text_null = self.null_text_feature.view(1, 1, -1)
+ image_null = self.null_image_feature.view(1, 1, -1)
+
+ # replace padding with learnable null embedding
+ phrases_embeddings = phrases_embeddings * phrases_masks + (1 - phrases_masks) * text_null
+ image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null
+
+ objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1))
+ objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1))
+ objs = torch.cat([objs_text, objs_image], dim=1)
+
+ return objs
+
+
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+ """
+ For PixArt-Alpha.
+
+ Reference:
+ https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+ """
+
+ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+ super().__init__()
+
+ self.outdim = size_emb_dim
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_additional_conditions = use_additional_conditions
+ if use_additional_conditions:
+ self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+ self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+ def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ if self.use_additional_conditions:
+ resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+ resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+ aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+ aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+ conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+class PixArtAlphaTextProjection(nn.Module):
+ """
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
+
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ """
+
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
+ super().__init__()
+ if out_features is None:
+ out_features = hidden_size
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+ if act_fn == "gelu_tanh":
+ self.act_1 = nn.GELU(approximate="tanh")
+ elif act_fn == "silu":
+ self.act_1 = nn.SiLU()
+ elif act_fn == "silu_fp32":
+ self.act_1 = FP32SiLU()
+ else:
+ raise ValueError(f"Unknown activation function: {act_fn}")
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
+
+ def forward(self, caption):
+ hidden_states = self.linear_1(caption)
+ hidden_states = self.act_1(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ dim_head: int = 64,
+ heads: int = 16,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(embed_dims)
+ self.ln1 = nn.LayerNorm(embed_dims)
+ self.attn = Attention(
+ query_dim=embed_dims,
+ dim_head=dim_head,
+ heads=heads,
+ out_bias=False,
+ )
+ self.ff = nn.Sequential(
+ nn.LayerNorm(embed_dims),
+ FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+ )
+
+ def forward(self, x, latents, residual):
+ encoder_hidden_states = self.ln0(x)
+ latents = self.ln1(latents)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+ latents = self.attn(latents, encoder_hidden_states) + residual
+ latents = self.ff(latents) + latents
+ return latents
+
+
+class IPAdapterPlusImageProjection(nn.Module):
+ """Resampler of IP-Adapter Plus.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, i.e. the same number of channels as
+ `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_queries (int): The number of learned query latents. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of the feed-forward hidden layer channels. Defaults to 4.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 1024,
+ hidden_dims: int = 1280,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
+
+ self.proj_in = nn.Linear(embed_dims, hidden_dims)
+
+ self.proj_out = nn.Linear(hidden_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ x (torch.Tensor): Input Tensor.
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
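+# Resampler sketch (illustrative shapes): a fixed set of learned query latents
+# attends over the projected image features, compressing a variable-length
+# sequence into num_queries tokens:
+#
+#   resampler = IPAdapterPlusImageProjection(embed_dims=768, output_dims=1024, num_queries=8)
+#   tokens = resampler(torch.randn(2, 257, 768))   # (2, 8, 1024)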
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+ """FacePerceiverResampler of IP-Adapter Plus.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, i.e. the same number of channels as
+ `unet.config.cross_attention_dim`. Defaults to 768.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ id_embeddings_dim (int): The dimension of the input ID embeddings. Defaults to 512.
+ depth (int): The number of blocks. Defaults to 4.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_tokens (int): The number of ID tokens. Defaults to 4.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of the feed-forward hidden layer channels. Defaults to 4.
+ ffproj_ratio (int): The expansion ratio of the feed-forward projection for the ID embeddings. Defaults to 2.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 768,
+ hidden_dims: int = 1280,
+ id_embeddings_dim: int = 512,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_tokens: int = 4,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ffproj_ratio: int = 2,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.embed_dim = embed_dims
+ self.clip_embeds = None
+ self.shortcut = False
+ self.shortcut_scale = 1.0
+
+ self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+ self.norm = nn.LayerNorm(embed_dims)
+
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+ self.proj_out = nn.Linear(embed_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
+ id_embeds = self.proj(id_embeds)
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+ id_embeds = self.norm(id_embeds)
+ latents = id_embeds
+
+ clip_embeds = self.proj_in(self.clip_embeds)
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ out = self.norm_out(latents)
+ if self.shortcut:
+ out = id_embeds + self.shortcut_scale * out
+ return out
+
+
+class MultiIPAdapterImageProjection(nn.Module):
+ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
+ super().__init__()
+ self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
+
+ def forward(self, image_embeds: List[torch.Tensor]):
+ projected_image_embeds = []
+
+ # currently, we accept `image_embeds` as
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+ # 2. a list of `n` tensors where `n` is the number of IP-Adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+ if not isinstance(image_embeds, list):
+ deprecation_message = (
+ "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
+ " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
+ )
+ deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
+ image_embeds = [image_embeds.unsqueeze(1)]
+
+ if len(image_embeds) != len(self.image_projection_layers):
+ raise ValueError(
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+ )
+
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+ image_embed = image_projection_layer(image_embed)
+ image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+ projected_image_embeds.append(image_embed)
+
+ return projected_image_embeds
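+
+# Per-adapter flow (note): each entry of `image_embeds` is flattened from
+# (batch, num_images, ...) to (batch * num_images, ...), projected by its own
+# layer, and reshaped back, so every IP-Adapter keeps an independent projector.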
diff --git a/dataset_demo_videos.py b/dataset_demo_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..01da5ff1d88e4c6e11553a74db941535264b442d
--- /dev/null
+++ b/dataset_demo_videos.py
@@ -0,0 +1,96 @@
+import os
+import cv2
+import decord
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+
+class VideoDataset(Dataset):
+ def __init__(self, root_dir):
+ self.root_dir = root_dir
+ self.data = []
+
+ # Collect (edited PNG, source MP4, prompt TXT) triplets from the data root.
+ # The per-folder loop is disabled, so only root_dir itself is scanned.
+ #for folder_name in os.listdir(root_dir):
+ folder_path = os.path.join(root_dir, "")
+ if os.path.isdir(folder_path):
+ for file_name in os.listdir(folder_path):
+ if 'edit' in file_name.lower() and file_name.lower().endswith('.png'):
+ number = file_name.split('_edit')[0]
+ video_file = os.path.join(folder_path, f"{number}.mp4")
+ png_file = os.path.join(folder_path, file_name)
+ txt_file = os.path.join(folder_path, f"{number}.txt")
+ self.data.append((png_file, video_file, txt_file))
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ png_file, video_file, txt_file = self.data[idx]
+
+ try:
+ # Read the PNG file and resize it
+ image = cv2.imread(png_file)
+ resized_image = cv2.resize(image, (768, 448))
+ resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
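+ # Special case: sample 6 is pillarboxed onto a wider black canvas and then
+ # resized back to 768x448, mirroring how its video frames are handled below.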
+ if "6.mp4" in video_file:
+ frame = np.zeros([448, 768+448, 3]).astype(np.uint8)
+ frame[:, 448//2: 448//2+768, :] = np.array(resized_image)
+ resized_image = cv2.resize(frame, (768, 448))
+
+ except Exception as e:
+ print("*"*200)
+ print(f"Error reading or resizing image {png_file}: {e}")
+ resized_image = np.zeros((448, 768, 3), dtype=np.uint8)
+
+ try:
+ # Read the corresponding MP4 file
+ #if "6.mp4" in video_file:
+
+ vr = decord.VideoReader(video_file)
+ frames = vr.get_batch(list(range(33))).asnumpy()
+ if "6.mp4" in video_file:
+ resized_frames = [cv2.resize(frame, (768, 448)) for frame in frames]
+ for i in range(len(resized_frames)):
+ frame = np.zeros([448, 768+448, 3]).astype(np.uint8)
+ frame[:, 448//2: 448//2+768, :] = np.array(resized_frames[i])
+ resized_frames[i] = frame
+ resized_frames = [cv2.resize(frame, (768, 448)) for frame in resized_frames]
+ else:
+ resized_frames = [cv2.resize(frame, (768, 448)) for frame in frames]
+ except Exception as e:
+ print("*"*200, video_file, "*"*200)
+ print(f"Error reading or resizing video {video_file}: {e}")
+ resized_frames = [np.zeros((448, 768, 3), dtype=np.uint8) for _ in range(33)]
+
+ try:
+ # Read the corresponding TXT file (line 1: positive prompt, line 2: negative prompt)
+ with open(txt_file, 'r') as f:
+ pos_prompt = f.readline().strip()
+ neg_prompt = f.readline().strip()
+ except Exception as e:
+ print(f"Error reading text file {txt_file}: {e}")
+ pos_prompt = ""
+ neg_prompt = ""
+
+ return {
+ 'image': torch.from_numpy(resized_image),
+ 'frames': torch.from_numpy(np.array(resized_frames)),
+ 'pos_prompt': pos_prompt,
+ 'neg_prompt': neg_prompt,
+ 'image_path': png_file # return the image path
+ }
+"""
+# Example usage
+root_dir = 'demo_videos/videos'
+dataset = VideoDataset(root_dir)
+
+# Read the first sample
+sample = dataset[0]
+if sample:
+ print(sample['image'].shape)
+ print(sample['frames'].shape)
+ print(sample['pos_prompt'])
+ print(sample['neg_prompt'])
+ print(sample['image_path'])
+"""
diff --git a/docker.sh b/docker.sh
new file mode 100644
index 0000000000000000000000000000000000000000..73747eda8d5525cfbcaa7511c7bf415cd425d230
--- /dev/null
+++ b/docker.sh
@@ -0,0 +1 @@
+docker run --cpus=96 --shm-size=200g --gpus=all --name=pwx -p 10092:10092 -v /:/data -it 192.168.99.199:8080/aigc/cogvideox:latest /bin/bash
diff --git a/gradio_title.md b/gradio_title.md
new file mode 100644
index 0000000000000000000000000000000000000000..d82ce0b2576148f5795cd72a9af70bdc28123d2a
--- /dev/null
+++ b/gradio_title.md
@@ -0,0 +1,19 @@
+
+LivePortrait: Efficient Portrait Animation with Stitching and Retargeting Control
+
diff --git a/pipeline_cogvideox_controlnet_5b_i2v_instruction2.py b/pipeline_cogvideox_controlnet_5b_i2v_instruction2.py
new file mode 100644
index 0000000000000000000000000000000000000000..132c4bbbd2f1a4d0a4b2975fe19d10bc5a147d36
--- /dev/null
+++ b/pipeline_cogvideox_controlnet_5b_i2v_instruction2.py
@@ -0,0 +1,783 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import T5EncoderModel, T5Tokenizer
+
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.loaders import CogVideoXLoraLoaderMixin
+from diffusers.models import AutoencoderKLCogVideoX
+from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from diffusers.utils import logging, replace_example_docstring
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
+ >>> prompt = (
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ ... "atmosphere of this unique musical performance."
+ ... )
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
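+# Worked example: src=(60, 40) as (h, w) with tgt_width=40, tgt_height=30 gives
+# r = 1.5 > 30/40, so the grid is resized to 30 x 20 and centered, yielding the
+# crop region ((0, 10), (30, 30)).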
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
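+# Usage sketch (illustrative): callers pass exactly one of num_inference_steps,
+# timesteps, or sigmas to drive the schedule, e.g.
+#
+#   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda")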
+
+class ControlCogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using CogVideoX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "vae->text_encoder->controlnet_transformer->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ controlnet_transformer: ControlCogVideoXTransformer3DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler, controlnet_transformer=controlnet_transformer
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.vae_scaling_factor_image = (
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ #self.controlnet_transformer = controlnet_transformer
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
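+ # e.g. with the defaults num_frames=49, height=480, width=720: a temporal
+ # factor of 4 gives (49 - 1) // 4 + 1 = 13 latent frames, and a spatial
+ # factor of 8 gives a 60 x 90 grid, i.e. (B, 13, C, 60, 90).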
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ print('vae_scale_factor_spatial', self.vae_scale_factor_spatial)
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
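+ # e.g. 480 x 720 with vae_scale_factor_spatial=8 and patch_size=2 yields a
+ # 30 x 45 token grid per latent frame.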
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ video_condition: Optional[torch.FloatTensor] = None, # source video latents (16 channels)
+ video_condition2: Optional[torch.FloatTensor] = None, # conditioning image latents (16 channels)
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ interval: int = 6,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
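+            video_condition (`torch.FloatTensor`, *optional*):
+                VAE-encoded latents of the source video (16 latent channels), used as the controlnet condition.
+            video_condition2 (`torch.FloatTensor`, *optional*):
+                VAE-encoded latents of the conditioning first frame (16 latent channels, zero-padded along the
+                frame dimension), concatenated with the noise latents along the channel dimension.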
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+            width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+            num_frames (`int`, defaults to `49`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to `6`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length`; otherwise, generation may produce poor results.
+
+ Examples:
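+
+            A minimal call sketch (latents are assumed to be pre-encoded with the VAE, as in the accompanying
+            test scripts; the prompt is illustrative only):
+
+            >>> video = pipe(
+            ...     prompt="make the car red",
+            ...     negative_prompt="blurry, distorted",
+            ...     video_condition=source_latents,  # controlnet condition (16 channels)
+            ...     video_condition2=image_latents,  # concatenated with the noise latents (16 channels)
+            ...     height=480,
+            ...     width=720,
+            ...     num_frames=33,
+            ...     guidance_scale=6.0,
+            ... ).frames[0]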
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        num_videos_per_prompt = 1  # multiple videos per prompt are not supported; force a single video
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+        latent_channels = 16  # noise latents; in i2v mode the image-condition latents concatenated later bring the transformer input to 32 channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ print('height width', height, width)
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ video_condition_model_input1 = torch.cat([video_condition] * 2) if do_classifier_free_guidance else video_condition
+ if video_condition2 is not None:
+                    video_condition_model_input2 = torch.cat([video_condition2] * 2) if do_classifier_free_guidance else video_condition2
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ if video_condition2 is not None:
+ print("latent_model_input", latent_model_input.shape, "video_condition_model_input2", video_condition_model_input2.shape)
+ latent_model_input2 = torch.cat([latent_model_input, video_condition_model_input2], dim=2)
+ else:
+ print("latent_model_input", latent_model_input.shape)
+ latent_model_input2 = latent_model_input
+
+ print("video_condition_model_input1", video_condition_model_input1.shape)
+ print("latent_model_input2", latent_model_input2.shape)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ print('image rotary emb is None:', image_rotary_emb is None)
+
+ residual_hidden_states = self.controlnet_transformer(
+ hidden_states=latent_model_input2,# 32
+ control_condition=video_condition_model_input1, # 16
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb, # 2b is False, 5b is True
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )
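+                # `residual_hidden_states` holds one residual per controlnet block (6 here). Sketch of the
+                # assumed injection inside the base transformer (see control_cogvideox for the actual code):
+                # every `interval`-th block adds the matching residual, roughly
+                #     hidden_states = hidden_states + residual_hidden_states[block_idx // interval]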
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input2,
+ encoder_hidden_states=prompt_embeds,
+ residual_hidden_states=residual_hidden_states,
+                    interval=42 // 6,  # 42 transformer layers / 6 control blocks: inject a residual every 7 layers
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ print(noise_pred.shape)
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
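+                    # chunk order follows torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) in step 3:
+                    # the first half of the batch is the unconditional branch, the second half the text branch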
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
diff --git a/test_demo_videos_controlnet.py b/test_demo_videos_controlnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..57027d04b8ff4e682d9e1d9c1dc1c13c13340a1c
--- /dev/null
+++ b/test_demo_videos_controlnet.py
@@ -0,0 +1,231 @@
+import cv2
+import torch
+import argparse
+import numpy as np
+import os
+from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
+from pipeline_cogvideox_controlnet_5b_i2v_instruction2 import ControlCogVideoXPipeline
+from diffusers.utils import export_to_video
+from diffusers import AutoencoderKLCogVideoX
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers.schedulers import CogVideoXDDIMScheduler
+from safetensors.torch import load_file
+from omegaconf import OmegaConf
+from einops import rearrange
+from decord import VideoReader
+import transformers
+from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer
+from PIL import Image
+import torch.nn.functional as F
+
+from dataset_demo_videos import VideoDataset
+
+def unwarp_model(state_dict):
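+    # checkpoint keys carry a "module." prefix (typically added by DistributedDataParallel);
+    # strip it so the weights load into an unwrapped model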
+ new_state_dict = {}
+ for key in state_dict:
+ new_state_dict[key.split('module.')[1]] = state_dict[key]
+ return new_state_dict
+
+"""
+def transform_tensor_to_images(images):
+ images = images.cpu().detach().numpy()
+ images = np.uint8(images)
+ images2 = []
+ for image in images:
+ image = Image.fromarray(image)
+ images2.append(image)
+ return images2
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--pos_prompt", type=str, default="")
+parser.add_argument("--neg_prompt", type=str, default="")
+parser.add_argument("--training_steps", type=int, default=30001)
+parser.add_argument("--root_path", type=str, default="./models_half")
+parser.add_argument("--i2v", action="store_true",default=True)
+parser.add_argument("--guidance_scale", type=float, default=4.0)
+parser.add_argument("--random_seed", type=int, default=0)
+args = parser.parse_args()
+
+#-----------------------------------------------------------------
+prefix = args.root_path.replace("/","_").replace(".","_") + "_" + args.pos_prompt.replace(" ","_").replace(".","_")
+#-----------------------------------------------------------------
+
+
+if args.i2v:
+ key = "i2v"
+else:
+ key = "t2v"
+noise_scheduler = CogVideoXDDIMScheduler(
+ **OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/scheduler/scheduler_config.json")
+ )
+)
+
+text_encoder = T5EncoderModel.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)#.to("cuda:0")
+vae = AutoencoderKLCogVideoX.from_pretrained(f"./cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16).to("cuda:0")
+tokenizer = T5Tokenizer.from_pretrained(f"./cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)
+
+
+config = OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
+)
+if args.i2v:
+ config["in_channels"] = 32
+else:
+ config["in_channels"] = 16
+transformer = CogVideoXTransformer3DModel(**config)
+
+control_config = OmegaConf.to_container(
+ OmegaConf.load(f"./cogvideox-5b-{key}/transformer/config.json")
+)
+if args.i2v:
+ control_config["in_channels"] = 32
+else:
+ control_config["in_channels"] = 16
+control_config['num_layers'] = 6
+control_config['control_in_channels'] = 16
+controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)
+
+all_state_dicts = torch.load("{args.root_path}/ff_controlnet_half.pth", map_location="cpu",weights_only=True)
+transformer_state_dict = unwarp_model(all_state_dicts["transformer_state_dict"])
+controlnet_transformer_state_dict = unwarp_model(all_state_dicts["controlnet_transformer_state_dict"])
+
+transformer.load_state_dict(transformer_state_dict, strict=True)
+controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
+
+transformer = transformer.half().to("cuda:0")
+controlnet_transformer = controlnet_transformer.half().to("cuda:0")
+
+vae = vae.eval()
+text_encoder = text_encoder.eval()
+transformer = transformer.eval()
+controlnet_transformer = controlnet_transformer.eval()
+
+pipe = ControlCogVideoXPipeline(tokenizer,
+ text_encoder,
+ vae,
+ transformer,
+ noise_scheduler,
+ controlnet_transformer,
+)#.to("cuda:0")
+
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+pipe.enable_model_cpu_offload()
+
+def inference(prefix, source_images,
+              target_images,
+              text_prompt, negative_prompt,
+              pipe, vae,
+              step, guidance_scale,
+              target_path, video_dir,
+              h, w, random_seed):
+
+ source_pixel_values = source_images/127.5 - 1.0
+ source_pixel_values = source_pixel_values.to(torch.float16).to("cuda:0")
+ if target_images is not None:
+ target_pixel_values = target_images/127.5 - 1.0
+ target_pixel_values = target_pixel_values.to(torch.float16).to("cuda:0")
+    bsz, f, h, w, c = source_pixel_values.shape
+
+ with torch.no_grad():
+ source_pixel_values = rearrange(source_pixel_values, "b f w h c -> b c f w h")
+ source_latents = vae.encode(source_pixel_values).latent_dist.sample()
+ source_latents = source_latents.to(torch.float16)
+ source_latents = source_latents * vae.config.scaling_factor
+ source_latents = rearrange(source_latents, "b c f h w -> b f c h w")
+
+ if target_images is not None:
+ target_pixel_values = rearrange(target_pixel_values, "b f w h c -> b c f w h")
+ images = target_pixel_values[:,:,:1,...]
+ image_latents = vae.encode(images).latent_dist.sample()
+ image_latents = image_latents.to(torch.float16)
+ image_latents = image_latents * vae.config.scaling_factor
+ image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
+ image_latents = torch.cat([image_latents, torch.zeros_like(source_latents)[:,1:]],dim=1)
+ latents = torch.cat([image_latents, source_latents], dim=2)
+ else:
+ image_latents = None
+ latents = source_latents
+
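+    # Shape sketch: source_latents and image_latents are each (b, f_latent, 16, h/8, w/8); the pipeline
+    # concatenates the image-condition latents with the noise latents along the channel dimension during
+    # denoising, matching the in_channels=32 the i2v transformer was configured with above.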
+ video = pipe(
+ prompt = text_prompt,
+ negative_prompt = negative_prompt,
+ video_condition = source_latents, # input to controlnet
+ video_condition2 = image_latents, # concat with latents
+ height = h,
+ width = w,
+ num_frames = f,
+ num_inference_steps = 50,
+ interval = 6,
+ guidance_scale = guidance_scale,
+        generator = torch.Generator(device="cuda:0").manual_seed(random_seed)
+ ).frames[0]
+
+ def transform_tensor_to_images(images):
+ images = images.cpu().detach().numpy()
+ images = np.uint8(images)
+ images2 = []
+ for image in images:
+ image = Image.fromarray(image)
+ images2.append(image)
+ return images2
+
+ source_images = transform_tensor_to_images(source_images[0])
+
+ os.makedirs(f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}", exist_ok=True)
+ export_to_video(video, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}.mp4", fps=8)
+ export_to_video(source_images, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}_org.mp4", fps=8)
+
+def read_video(video_path, h, w):
+ vr = VideoReader(video_path)
+ images = vr.get_batch(list(range(min(33, len(vr))))).asnumpy()
+ images2 = []
+ for image in images:
+        image = cv2.resize(image, (h, w))  # note: cv2.resize expects (width, height)
+ images2.append(image)
+ images2 = np.array(images2)
+ images = images2
+ del vr
+ images = torch.from_numpy(images)
+ return images
+
+def resize(images, h, w):
+ images = rearrange(images, "f w h c -> f c w h")
+ images = F.interpolate(images, (h, w), mode="bilinear")
+ images = rearrange(images, "f c w h -> f w h c")
+ images = images[None,...]
+ return images
+
+h = 448
+w = 768
+
+root_dir = 'additional_videos8'
+dataset = VideoDataset(root_dir)
+print(len(dataset))
+for step, sample in enumerate(dataset):
+ image = sample['image'] # w h c
+ images = sample['frames'] # f w h c
+ pos_prompt = sample['pos_prompt']
+ neg_prompt = sample['neg_prompt']
+ image_path = sample['image_path']
+ prefix = image_path.replace("/","_")
+
+ source_images = images[None,...]
+ target_images = image[None,None,...]
+
+ print(pos_prompt, neg_prompt)
+ print(source_images.shape, torch.min(source_images), torch.max(source_images))
+ print(target_images.shape, torch.min(target_images), torch.max(target_images))
+ target_path = f"demo_first_frame_controlnet_33_stride_2_new_videos_8/{prefix}/"
+ random_seeds = [args.random_seed]
+ for random_seed in random_seeds:
+ inference("", source_images, \
+ target_images, pos_prompt, \
+ neg_prompt, pipe, vae, \
+ args.training_steps, args.guidance_scale, \
+ target_path, "", \
+ h, w, random_seed)
diff --git a/test_demo_videos_controlnet_long_video2.py b/test_demo_videos_controlnet_long_video2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9447567f6457b26bd6b989562916b8a5bc50a7c4
--- /dev/null
+++ b/test_demo_videos_controlnet_long_video2.py
@@ -0,0 +1,265 @@
+import cv2
+import torch
+import argparse
+import numpy as np
+import os
+from control_cogvideox.cogvideox_transformer_3d import CogVideoXTransformer3DModel
+from control_cogvideox.controlnet_cogvideox_transformer_3d import ControlCogVideoXTransformer3DModel
+from pipeline_cogvideox_controlnet_5b_i2v_instruction import ControlCogVideoXPipeline
+from diffusers.utils import export_to_video
+from diffusers import AutoencoderKLCogVideoX
+from transformers import T5EncoderModel, T5Tokenizer
+from diffusers.schedulers import CogVideoXDDIMScheduler
+from safetensors.torch import load_file
+from omegaconf import OmegaConf
+from einops import rearrange
+from torchvision import transforms
+from decord import VideoReader
+import transformers
+from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer
+from PIL import Image
+import torch.nn.functional as F
+
+from dataset_demo_videos import VideoDataset
+
+def unwarp_model(state_dict):
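+    # strip the "module." prefix (typically added by DistributedDataParallel) so the
+    # weights load into an unwrapped model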
+ new_state_dict = {}
+ for key in state_dict:
+ new_state_dict[key.split('module.')[1]] = state_dict[key]
+ return new_state_dict
+
+"""
+def transform_tensor_to_images(images):
+ images = images.cpu().detach().numpy()
+ images = np.uint8(images)
+ images2 = []
+ for image in images:
+ image = Image.fromarray(image)
+ images2.append(image)
+ return images2
+"""
+parser = argparse.ArgumentParser()
+parser.add_argument("--pos_prompt", type=str, default="")
+parser.add_argument("--neg_prompt", type=str, default="")
+parser.add_argument("--training_steps", type=int, default=1601)
+parser.add_argument("--root_path", type=str, default="")
+parser.add_argument("--i2v", action="store_true")
+parser.add_argument("--guidance_scale", type=float, default=6.0)
+parser.add_argument("--random_seed", type=int, default=0)
+args = parser.parse_args()
+
+#-----------------------------------------------------------------
+prefix = args.root_path.replace("/","_").replace(".","_") + "_" + args.pos_prompt.replace(" ","_").replace(".","_")
+#-----------------------------------------------------------------
+
+
+if args.i2v:
+ key = "i2v"
+else:
+ key = "t2v"
+noise_scheduler = CogVideoXDDIMScheduler(
+ **OmegaConf.to_container(
+ OmegaConf.load(f"../cogvideox-5b-{key}/scheduler/scheduler_config.json")
+ )
+)
+
+text_encoder = T5EncoderModel.from_pretrained(f"../cogvideox-5b-{key}/", subfolder="text_encoder", torch_dtype=torch.float16)
+vae = AutoencoderKLCogVideoX.from_pretrained(f"../cogvideox-5b-{key}/", subfolder="vae", torch_dtype=torch.float16)
+tokenizer = T5Tokenizer.from_pretrained(f"../cogvideox-5b-{key}/tokenizer", torch_dtype=torch.float16)
+
+
+config = OmegaConf.to_container(
+ OmegaConf.load(f"../cogvideox-5b-{key}/transformer/config.json")
+)
+if args.i2v:
+ config["in_channels"] = 32
+else:
+ config["in_channels"] = 16
+transformer = CogVideoXTransformer3DModel(**config)
+
+control_config = OmegaConf.to_container(
+ OmegaConf.load(f"../cogvideox-5b-{key}/transformer/config.json")
+)
+if args.i2v:
+ control_config["in_channels"] = 32
+else:
+ control_config["in_channels"] = 16
+control_config['num_layers'] = 6
+control_config['control_in_channels'] = 16
+controlnet_transformer = ControlCogVideoXTransformer3DModel(**control_config)
+
+all_state_dicts = torch.load(f"{args.root_path}/checkpoints/checkpoint{args.training_steps}.ckpt", map_location="cpu")
+transformer_state_dict = unwarp_model(all_state_dicts["transformer_state_dict"])
+controlnet_transformer_state_dict = unwarp_model(all_state_dicts["controlnet_transformer_state_dict"])
+
+transformer.load_state_dict(transformer_state_dict, strict=True)
+controlnet_transformer.load_state_dict(controlnet_transformer_state_dict, strict=True)
+
+transformer = transformer.to(torch.float16)
+controlnet_transformer = controlnet_transformer.to(torch.float16)
+
+vae = vae.eval()
+text_encoder = text_encoder.eval()
+transformer = transformer.eval()
+controlnet_transformer = controlnet_transformer.eval()
+
+pipe = ControlCogVideoXPipeline(tokenizer,
+ text_encoder,
+ vae,
+ transformer,
+ noise_scheduler,
+ controlnet_transformer
+).to("cuda:0")
+
+pipe.vae.enable_slicing()
+pipe.vae.enable_tiling()
+
+def inference(prefix, source_images,
+              target_images,
+              text_prompt, negative_prompt,
+              pipe, vae,
+              step, guidance_scale,
+              target_path, video_dir,
+              h, w, random_seed):
+
+ source_pixel_values = source_images/127.5 - 1.0
+ source_pixel_values = source_pixel_values.to(torch.float16).to("cuda:0")
+ if target_images is not None:
+ target_pixel_values = target_images/127.5 - 1.0
+ target_pixel_values = target_pixel_values.to(torch.float16).to("cuda:0")
+    bsz, f, h, w, c = source_pixel_values.shape
+
+ with torch.no_grad():
+ source_pixel_values = rearrange(source_pixel_values, "b f w h c -> b c f w h")
+ source_latents = vae.encode(source_pixel_values).latent_dist.sample()
+ source_latents = source_latents.to(torch.float16)
+ source_latents = source_latents * vae.config.scaling_factor
+ source_latents = rearrange(source_latents, "b c f h w -> b f c h w")
+
+ if target_images is not None:
+ target_pixel_values = rearrange(target_pixel_values, "b f w h c -> b c f w h")
+ images = target_pixel_values[:,:,:1,...]
+ image_latents = vae.encode(images).latent_dist.sample()
+ image_latents = image_latents.to(torch.float16)
+ image_latents = image_latents * vae.config.scaling_factor
+ image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
+ image_latents = torch.cat([image_latents, torch.zeros_like(source_latents)[:,1:]],dim=1)
+ latents = torch.cat([image_latents, source_latents], dim=2)
+ else:
+ image_latents = None
+ latents = source_latents
+
+ video = pipe(
+ prompt = text_prompt,
+ negative_prompt = negative_prompt,
+ video_condition = source_latents, # input to controlnet
+ video_condition2 = image_latents, # concat with latents
+ height = h,
+ width = w,
+ num_frames = f,
+ num_inference_steps = 50,
+ interval = 6,
+ guidance_scale = guidance_scale,
+        generator = torch.Generator(device="cuda:0").manual_seed(random_seed)
+ ).frames[0]
+
+ def transform_tensor_to_images(images):
+ images = images.cpu().detach().numpy()
+ images = np.uint8(images)
+ images2 = []
+ for image in images:
+ image = Image.fromarray(image)
+ images2.append(image)
+ return images2
+
+ source_images = transform_tensor_to_images(source_images[0])
+
+ return video, source_images
+
+ #os.makedirs(f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}", exist_ok=True)
+ #export_to_video(video, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}.mp4", fps=8)
+ #export_to_video(source_images, f"./{target_path}/{step}_{prefix}_video_guidance_scale_{guidance_scale}/output_{random_seed}_org.mp4", fps=8)
+
+def read_video(video_path, h, w):
+ vr = VideoReader(video_path)
+ images = vr.get_batch(list(range(min(33, len(vr))))).asnumpy()
+ images2 = []
+ for image in images:
+        image = cv2.resize(image, (h, w))  # note: cv2.resize expects (width, height)
+ images2.append(image)
+ images2 = np.array(images2)
+ images = images2
+ del vr
+ images = torch.from_numpy(images)
+ return images
+
+def resize(images, h, w):
+ images = rearrange(images, "f w h c -> f c w h")
+ images = F.interpolate(images, (h, w), mode="bilinear")
+ images = rearrange(images, "f c w h -> f w h c")
+ images = images[None,...]
+ return images
+
+h = 448
+w = 768
+
+root_dir = 'long_videos'
+dataset = VideoDataset(root_dir)
+
+for step, sample in enumerate(dataset):
+ image = sample['image'] # w h c
+ images = sample['frames'] # f w h c
+ pos_prompt = sample['pos_prompt']
+ neg_prompt = sample['neg_prompt']
+ image_path = sample['image_path']
+ prefix = image_path.replace("/","_")
+
+ source_images = images[None,...]
+ target_images = image[None,None,...]
+
+ b,f,h,w,c = source_images.shape
+    source_images = source_images[:,:f//33*33,...]  # truncate the frame count to a multiple of 33
+ print(pos_prompt, neg_prompt)
+ print(source_images.shape, torch.min(source_images), torch.max(source_images))
+ print(target_images.shape, torch.min(target_images), torch.max(target_images))
+ target_path = f"demo_first_frame_controlnet_136_stride_1/{prefix}/"
+ random_seeds = [args.random_seed]
+ for random_seed in random_seeds:
+ editing_video = []
+ source_video = []
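+        # Autoregressive long-video scheme: process 33-frame windows overlapping by one frame
+        # (start_fid = i*33 - i). The last generated frame of each window becomes the first-frame
+        # condition for the next, and the duplicated boundary frame is dropped below when stitching.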
+ for i in range(source_images.shape[1]//33):
+ if i == 0:
+ start_fid = 0
+ end_fid = 33
+ else:
+ target_images = video[-1]
+ transform = transforms.Compose([
+ transforms.ToTensor()
+ ])
+ target_images = transform(target_images)
+ target_images = target_images*255.0
+ target_images = rearrange(target_images,"c w h -> w h c")
+ print(target_images.shape)
+ target_images = target_images[None, None, ...]
+ start_fid = i*33 - i
+ end_fid = min((i+1)*33 - i, source_images.shape[1])
+ if start_fid >= end_fid-16:
+ break
+ print(start_fid, end_fid)
+ print("target images", torch.min(target_images), torch.max(target_images))
+ video, source_images2 = inference("", source_images[:, start_fid: end_fid], \
+ target_images, pos_prompt, \
+ neg_prompt, pipe, vae, \
+ args.training_steps, args.guidance_scale, \
+ target_path, "", \
+ h, w, random_seed)
+ if i > 0:
+ video = video[1:]
+ source_images2 = source_images2[1:]
+ editing_video += video
+ source_video += source_images2
+
+ os.makedirs(f"./{target_path}/{args.training_steps}_{prefix}_video_guidance_scale_{args.guidance_scale}", exist_ok=True)
+ export_to_video(editing_video, f"./{target_path}/{args.training_steps}_{prefix}_video_guidance_scale_{args.guidance_scale}/output_{random_seed}.mp4", fps=8)
+ export_to_video(source_video, f"./{target_path}/{args.training_steps}_{prefix}_video_guidance_scale_{args.guidance_scale}/output_{random_seed}_org.mp4", fps=8)