rahul7star committed on
Commit
02c2928
·
verified ·
1 Parent(s): 2e591bf

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +47 -134
optimization.py CHANGED
@@ -42,140 +42,53 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
42
 
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
-
46
- # pipeline.load_lora_weights(
47
- # "Kijai/WanVideo_comfy",
48
- # weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
49
- # adapter_name="lightning"
50
- # )
51
- # kwargs_lora = {}
52
- # kwargs_lora["load_into_transformer_2"] = True
53
- # pipeline.load_lora_weights(
54
- # "Kijai/WanVideo_comfy",
55
- # weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
56
- # #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
57
- # adapter_name="lightning_2", **kwargs_lora
58
- # )
59
- # pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
60
-
61
-
62
- # pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
63
- # pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
64
- # pipeline.unload_lora_weights()
65
-
66
-
67
-
68
- pipeline.load_lora_weights(
69
- "Kijai/WanVideo_comfy",
70
- weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
71
- adapter_name="lightning"
72
- )
73
- kwargs_lora = {}
74
- kwargs_lora["load_into_transformer_2"] = True
75
- # pipeline.load_lora_weights(
76
- # #"drozbay/Wan2.2_A14B_lora_extract",
77
- # "Kijai/WanVideo_comfy",
78
- # #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
79
- # #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
80
- # #weight_name="Pusa/Wan21_PusaV1_LoRA_14B_rank512_bf16.safetensors",
81
- # weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
82
- # adapter_name="lightning_2", **kwargs_lora
83
- # )
84
-
85
-
86
-
87
- pipeline.load_lora_weights(
88
- #"drozbay/Wan2.2_A14B_lora_extract",
89
- "lightx2v/Wan2.2-Lightning",
90
- #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
91
- #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
92
- weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors",
93
- #weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
94
- adapter_name="lightning_2", **kwargs_lora
95
- )
96
-
97
- # pipeline.load_lora_weights(
98
- # #"drozbay/Wan2.2_A14B_lora_extract",
99
- # "ostris/wan22_i2v_14b_orbit_shot_lora",
100
- # #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
101
- # #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
102
- # weight_name="wan22_14b_i2v_orbit_low_noise.safetensors",
103
- # #weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
104
- # adapter_name="lightning_200", **kwargs_lora
105
- # )
106
-
107
-
108
-
109
- pipeline.load_lora_weights(
110
- #"drozbay/Wan2.2_A14B_lora_extract",
111
- "deadman44/Wan2.2_T2i_T2v_LoRA",
112
- #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
113
- #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
114
- weight_name="lora_wan2.2_myjd_Low_v01.safetensors",
115
- #weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
116
- adapter_name="lightning_22", **kwargs_lora
117
- )
118
-
119
-
120
-
121
-
122
-
123
-
124
-
125
- pipeline.set_adapters(["lightning", "lightning_2","lightning_22"], adapter_weights=[1., 1.,1.])
126
-
127
-
128
- pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
129
- pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
130
- pipeline.fuse_lora(adapter_names=["lightning_22"], lora_scale=1., components=["transformer_2"])
131
-
132
- pipeline.unload_lora_weights()
133
-
134
- with capture_component_call(pipeline, 'transformer') as call:
135
- pipeline(*args, **kwargs)
136
-
137
- dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
138
- dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
139
-
140
- quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
141
- quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
142
-
143
- hidden_states: torch.Tensor = call.kwargs['hidden_states']
144
- hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
145
- if hidden_states.shape[-1] > hidden_states.shape[-2]:
146
- hidden_states_landscape = hidden_states
147
- hidden_states_portrait = hidden_states_transposed
148
- else:
149
- hidden_states_landscape = hidden_states_transposed
150
- hidden_states_portrait = hidden_states
151
-
152
- exported_landscape_1 = torch.export.export(
153
- mod=pipeline.transformer,
154
- args=call.args,
155
- kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
156
- dynamic_shapes=dynamic_shapes,
157
- )
158
-
159
- exported_portrait_2 = torch.export.export(
160
- mod=pipeline.transformer_2,
161
- args=call.args,
162
- kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
163
- dynamic_shapes=dynamic_shapes,
164
- )
165
-
166
- compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
167
- compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
168
-
169
- compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
170
- compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
171
-
172
- return (
173
- compiled_landscape_1,
174
- compiled_landscape_2,
175
- compiled_portrait_1,
176
- compiled_portrait_2,
177
- )
178
-
179
  quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
180
  cl1, cl2, cp1, cp2 = compile_transformer()
181
 
 
42
 
43
  @spaces.GPU(duration=1500)
44
  def compile_transformer():
45
+ # --- LoRA 1: lightning (loads into default transformer) ---
46
+ pipeline.load_lora_weights(
47
+ "Kijai/WanVideo_comfy",
48
+ weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
49
+ adapter_name="lightning",
50
+ )
51
+
52
+ # --- LoRA 2: lightning_2 (loads into transformer_2) ---
53
+ kwargs_lora = {"load_into_transformer_2": True}
54
+ pipeline.load_lora_weights(
55
+ "lightx2v/Wan2.2-Lightning",
56
+ weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1/low_noise_model.safetensors",
57
+ adapter_name="lightning_2",
58
+ **kwargs_lora,
59
+ )
60
+
61
+ # --- LoRA 3: orbit_shot (the ostris repo you asked for) ---
62
+ # Load into transformer_2 as well (set load_into_transformer_2 True if this adapter targets transformer_2)
63
+ pipeline.load_lora_weights(
64
+ "ostris/wan22_i2v_14b_orbit_shot_lora",
65
+ weight_name="wan22_14b_i2v_orbit_low_noise.safetensors",
66
+ adapter_name="orbit_shot",
67
+ **kwargs_lora,
68
+ )
69
+
70
+ # Register adapters and their relative weights
71
+ # (adjust adapter_weights to taste; here each is weight 1.0)
72
+ pipeline.set_adapters(["lightning", "lightning_2", "orbit_shot"], adapter_weights=[1.0, 1.0, 1.0])
73
+
74
+ # Fuse each adapter into the correct component with chosen lora_scale:
75
+ # - lightning -> transformer (boosted by 3x in your original)
76
+ # - lightning_2 -> transformer_2
77
+ # - orbit_shot -> transformer_2 (or transformer depending on the LoRA design)
78
+ pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
79
+ pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
80
+ pipeline.fuse_lora(adapter_names=["orbit_shot"], lora_scale=1.0, components=["transformer_2"])
81
+
82
+ # After fusing, you can unload LoRA weights to free memory (fused weights remain applied)
83
+ pipeline.unload_lora_weights()
84
+
85
+ # --- then continue with capture_component_call / export / compile logic as you already have ---
86
+ with capture_component_call(pipeline, 'transformer') as call:
87
+ pipeline(*args, **kwargs)
88
+
89
+ # ... rest of your function unchanged ...
90
+
91
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
93
  cl1, cl2, cp1, cp2 = compile_transformer()
94