latostadaok committed on
Commit
0dcb062
·
1 Parent(s): 866d9b3

Add support for pruned model and update generation modes in app.py

Browse files
Files changed (2) hide show
  1. .gitignore +8 -0
  2. app.py +30 -13
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Virtual Environments
2
+ venv/
3
+ env/
4
+ .venv/
5
+
6
+ # Python cache
7
+ __pycache__/
8
+ *.pyc
app.py CHANGED
@@ -89,6 +89,12 @@ pipe = DiffusionPipeline.from_pretrained(
89
  base_model, scheduler=scheduler, torch_dtype=dtype
90
  ).to(device)
91
 
 
 
 
 
 
 
92
  # Lightning LoRA info (no global state)
93
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
94
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
@@ -162,17 +168,19 @@ def handle_speed_mode(speed_mode):
162
  """Update UI based on speed/quality toggle."""
163
  if speed_mode == "Speed (8 steps)":
164
  return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
 
 
165
  else:
166
  return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
167
 
168
  @spaces.GPU(duration=70)
169
- def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
170
- pipe.to("cuda")
171
  generator = torch.Generator(device="cuda").manual_seed(seed)
172
 
173
  with calculateDuration("Generating image"):
174
  # Generate image
175
- image = pipe(
176
  prompt=prompt_mash,
177
  negative_prompt=negative_prompt,
178
  num_inference_steps=steps,
@@ -205,15 +213,21 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
205
  else:
206
  prompt_mash = prompt
207
 
 
 
 
 
 
 
208
  # Always unload any existing LoRAs first to avoid conflicts
209
  with calculateDuration("Unloading existing LoRAs"):
210
- pipe.unload_lora_weights()
211
 
212
  # Load LoRAs based on speed mode
213
  if speed_mode == "Speed (8 steps)":
214
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
215
  # Load Lightning LoRA first
216
- pipe.load_lora_weights(
217
  LIGHTNING_LORA_REPO,
218
  weight_name=LIGHTNING_LORA_WEIGHT,
219
  adapter_name="lightning"
@@ -221,7 +235,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
221
 
222
  # Load the selected style LoRA
223
  weight_name = selected_lora.get("weights", None)
224
- pipe.load_lora_weights(
225
  lora_path,
226
  weight_name=weight_name,
227
  low_cpu_mem_usage=True,
@@ -229,18 +243,21 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
229
  )
230
 
231
  # Set both adapters active with their weights
232
- pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
233
  else:
234
- # Quality mode - only load the style LoRA
235
- with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
 
 
 
236
  weight_name = selected_lora.get("weights", None)
237
- pipe.load_lora_weights(
238
  lora_path,
239
  weight_name=weight_name,
240
  low_cpu_mem_usage=True,
241
  adapter_name="style"
242
  )
243
- pipe.set_adapters(["style"], adapter_weights=[lora_scale])
244
 
245
  # Set random seed for reproducibility
246
  with calculateDuration("Randomizing seed"):
@@ -251,7 +268,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
251
  width, height = get_image_size(aspect_ratio)
252
 
253
  # Generate the image
254
- final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
255
 
256
  return final_image, seed
257
 
@@ -433,7 +450,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
433
  with gr.Row():
434
  speed_mode = gr.Radio(
435
  label="Generation Mode",
436
- choices=["Speed (8 steps)", "Quality (45 steps)"],
437
  value="Speed (8 steps)",
438
  info="Speed mode uses Lightning LoRA for faster generation"
439
  )
 
89
  base_model, scheduler=scheduler, torch_dtype=dtype
90
  ).to(device)
91
 
92
+ # Pruned model
93
+ pruned_model = "OPPOer/Qwen-Image-Pruning"
94
+ pruned_pipe = DiffusionPipeline.from_pretrained(
95
+ pruned_model, scheduler=scheduler, torch_dtype=dtype
96
+ ).to(device)
97
+
98
  # Lightning LoRA info (no global state)
99
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
100
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
 
168
  """Update UI based on speed/quality toggle."""
169
  if speed_mode == "Speed (8 steps)":
170
  return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
171
+ elif speed_mode == "Prune (8 steps)":
172
+ return gr.update(value="Prune mode selected - 8 steps with Pruned Model"), 8, 1.0
173
  else:
174
  return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
175
 
176
  @spaces.GPU(duration=70)
177
+ def generate_image(current_pipe, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
178
+ current_pipe.to("cuda")
179
  generator = torch.Generator(device="cuda").manual_seed(seed)
180
 
181
  with calculateDuration("Generating image"):
182
  # Generate image
183
+ image = current_pipe(
184
  prompt=prompt_mash,
185
  negative_prompt=negative_prompt,
186
  num_inference_steps=steps,
 
213
  else:
214
  prompt_mash = prompt
215
 
216
+ # Select the pipeline based on the mode
217
+ if speed_mode == "Prune (8 steps)":
218
+ current_pipe = pruned_pipe
219
+ else:
220
+ current_pipe = pipe
221
+
222
  # Always unload any existing LoRAs first to avoid conflicts
223
  with calculateDuration("Unloading existing LoRAs"):
224
+ current_pipe.unload_lora_weights()
225
 
226
  # Load LoRAs based on speed mode
227
  if speed_mode == "Speed (8 steps)":
228
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
229
  # Load Lightning LoRA first
230
+ current_pipe.load_lora_weights(
231
  LIGHTNING_LORA_REPO,
232
  weight_name=LIGHTNING_LORA_WEIGHT,
233
  adapter_name="lightning"
 
235
 
236
  # Load the selected style LoRA
237
  weight_name = selected_lora.get("weights", None)
238
+ current_pipe.load_lora_weights(
239
  lora_path,
240
  weight_name=weight_name,
241
  low_cpu_mem_usage=True,
 
243
  )
244
 
245
  # Set both adapters active with their weights
246
+ current_pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
247
  else:
248
+ # Quality or Prune mode - only load the style LoRA
249
+ log_message = f"Loading LoRA weights for {selected_lora['title']}"
250
+ if speed_mode == "Prune (8 steps)":
251
+ log_message += " on Pruned Model"
252
+ with calculateDuration(log_message):
253
  weight_name = selected_lora.get("weights", None)
254
+ current_pipe.load_lora_weights(
255
  lora_path,
256
  weight_name=weight_name,
257
  low_cpu_mem_usage=True,
258
  adapter_name="style"
259
  )
260
+ current_pipe.set_adapters(["style"], adapter_weights=[lora_scale])
261
 
262
  # Set random seed for reproducibility
263
  with calculateDuration("Randomizing seed"):
 
268
  width, height = get_image_size(aspect_ratio)
269
 
270
  # Generate the image
271
+ final_image = generate_image(current_pipe, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
272
 
273
  return final_image, seed
274
 
 
450
  with gr.Row():
451
  speed_mode = gr.Radio(
452
  label="Generation Mode",
453
+ choices=["Speed (8 steps)", "Quality (45 steps)", "Prune (8 steps)"],
454
  value="Speed (8 steps)",
455
  info="Speed mode uses Lightning LoRA for faster generation"
456
  )