rahul7star committed on
Commit
949010c
·
verified ·
1 Parent(s): c7f0892

Update optimization.py

Browse files
Files changed (1) hide show
  1. optimization.py +122 -51
optimization.py CHANGED
@@ -40,57 +40,128 @@ INDUCTOR_CONFIGS = {
40
 
41
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
42
 
43
- @spaces.GPU(duration=1500)
44
- def compile_transformer():
45
- # --- LoRA 1: lightning (loads into default transformer) ---
46
- pipeline.load_lora_weights(
47
- "Kijai/WanVideo_comfy",
48
- weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
49
- adapter_name="lightning",
50
- )
51
-
52
- # --- LoRA 2: lightning_2 (loads into transformer_2) ---
53
- kwargs_lora = {"load_into_transformer_2": True}
54
- pipeline.load_lora_weights(
55
- "deadman44/Wan2.2_T2i_T2v_LoRA",
56
- weight_name="lora_wan2.2_myjd_Low_v01.safetensors",
57
- adapter_name="lightning_2",
58
- **kwargs_lora,
59
- )
60
-
61
- # --- LoRA 3: orbit_shot (the ostris repo you asked for) ---
62
- # Load into transformer_2 as well (set load_into_transformer_2 True if this adapter targets transformer_2)
63
- # pipeline.load_lora_weights(
64
- # "ostris/wan22_i2v_14b_orbit_shot_lora",
65
- # weight_name="wan22_14b_i2v_orbit_low_noise.safetensors",
66
- # adapter_name="orbit_shot",
67
- # **kwargs_lora,
68
- # )
69
-
70
- # Register adapters and their relative weights
71
- # (adjust adapter_weights to taste; here each is weight 1.0)
72
- #pipeline.set_adapters(["lightning", "lightning_2", "orbit_shot"], adapter_weights=[1.0, 1.0, 1.0])
73
- pipeline.set_adapters(["lightning", "lightning_2", "orbit_shot"], adapter_weights=[1., 1.])
74
-
75
-
76
- # Fuse each adapter into the correct component with chosen lora_scale:
77
- # - lightning -> transformer (boosted by 3x in your original)
78
- # - lightning_2 -> transformer_2
79
- # - orbit_shot -> transformer_2 (or transformer depending on the LoRA design)
80
- pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3.0, components=["transformer"])
81
- pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1.0, components=["transformer_2"])
82
- #pipeline.fuse_lora(adapter_names=["orbit_shot"], lora_scale=1.0, components=["transformer_2"])
83
-
84
- # After fusing, you can unload LoRA weights to free memory (fused weights remain applied)
85
- pipeline.unload_lora_weights()
86
-
87
- # --- then continue with capture_component_call / export / compile logic as you already have ---
88
- with capture_component_call(pipeline, 'transformer') as call:
89
- pipeline(*args, **kwargs)
90
-
91
- # ... rest of your function unchanged ...
92
-
93
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
95
  cl1, cl2, cp1, cp2 = compile_transformer()
96
 
 
40
 
41
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
42
 
43
+ @spaces.GPU(duration=1500)
44
+ def compile_transformer():
45
+
46
+ # pipeline.load_lora_weights(
47
+ # "Kijai/WanVideo_comfy",
48
+ # weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
49
+ # adapter_name="lightning"
50
+ # )
51
+ # kwargs_lora = {}
52
+ # kwargs_lora["load_into_transformer_2"] = True
53
+ # pipeline.load_lora_weights(
54
+ # "Kijai/WanVideo_comfy",
55
+ # weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
56
+ # #weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
57
+ # adapter_name="lightning_2", **kwargs_lora
58
+ # )
59
+ # pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
60
+
61
+
62
+ # pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
63
+ # pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
64
+ # pipeline.unload_lora_weights()
65
+
66
+
67
+
68
+ pipeline.load_lora_weights(
69
+ "Kijai/WanVideo_comfy",
70
+ weight_name="Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank128_bf16.safetensors",
71
+ adapter_name="lightning"
72
+ )
73
+ kwargs_lora = {}
74
+ kwargs_lora["load_into_transformer_2"] = True
75
+ # pipeline.load_lora_weights(
76
+ # #"drozbay/Wan2.2_A14B_lora_extract",
77
+ # "Kijai/WanVideo_comfy",
78
+ # #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
79
+ # #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
80
+ # #weight_name="Pusa/Wan21_PusaV1_LoRA_14B_rank512_bf16.safetensors",
81
+ # weight_name="Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors",
82
+ # adapter_name="lightning_2", **kwargs_lora
83
+ # )
84
+
85
+
86
+
87
+ pipeline.load_lora_weights(
88
+ #"drozbay/Wan2.2_A14B_lora_extract",
89
+ "deadman44/Wan2.2_T2i_T2v_LoRA",
90
+ #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
91
+ #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
92
+ weight_name="lora_wan2.2_myjd_Low_v01.safetensors",
93
+ #weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
94
+ adapter_name="lightning_2", **kwargs_lora
95
+ )
96
+
97
+ # pipeline.load_lora_weights(
98
+ # #"drozbay/Wan2.2_A14B_lora_extract",
99
+ # "ostris/wan22_i2v_14b_orbit_shot_lora",
100
+ # #weight_name="MTVCrafter/Wan2_1_MTV-Crafter_motion_adapter_bf16.safetensors",
101
+ # #weight_name="Skyreels/Wan2_1_Skyreels-v2-T2V-720P_LoRA_rank_64_fp16.safetensors",
102
+ # weight_name="wan22_14b_i2v_orbit_low_noise.safetensors",
103
+ # #weight_name="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors",
104
+ # adapter_name="lightning_200", **kwargs_lora
105
+ # )
106
+
107
+
108
+
109
+
110
+
111
+
112
+ pipeline.set_adapters(["lightning", "lightning_2"], adapter_weights=[1., 1.])
113
+
114
+
115
+ pipeline.fuse_lora(adapter_names=["lightning"], lora_scale=3., components=["transformer"])
116
+ pipeline.fuse_lora(adapter_names=["lightning_2"], lora_scale=1., components=["transformer_2"])
117
+
118
+ pipeline.unload_lora_weights()
119
+
120
+ with capture_component_call(pipeline, 'transformer') as call:
121
+ pipeline(*args, **kwargs)
122
+
123
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
124
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
125
+
126
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
127
+ quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
128
+
129
+ hidden_states: torch.Tensor = call.kwargs['hidden_states']
130
+ hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
131
+ if hidden_states.shape[-1] > hidden_states.shape[-2]:
132
+ hidden_states_landscape = hidden_states
133
+ hidden_states_portrait = hidden_states_transposed
134
+ else:
135
+ hidden_states_landscape = hidden_states_transposed
136
+ hidden_states_portrait = hidden_states
137
+
138
+ exported_landscape_1 = torch.export.export(
139
+ mod=pipeline.transformer,
140
+ args=call.args,
141
+ kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
142
+ dynamic_shapes=dynamic_shapes,
143
+ )
144
+
145
+ exported_portrait_2 = torch.export.export(
146
+ mod=pipeline.transformer_2,
147
+ args=call.args,
148
+ kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
149
+ dynamic_shapes=dynamic_shapes,
150
+ )
151
+
152
+ compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
153
+ compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
154
+
155
+ compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
156
+ compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
157
+
158
+ return (
159
+ compiled_landscape_1,
160
+ compiled_landscape_2,
161
+ compiled_portrait_1,
162
+ compiled_portrait_2,
163
+ )
164
+
165
  quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
166
  cl1, cl2, cp1, cp2 = compile_transformer()
167