lichorosario commited on
Commit
dcaef81
·
verified ·
1 Parent(s): 2e4c141

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -5
app.py CHANGED
@@ -105,6 +105,7 @@ pipe = DiffusionPipeline.from_pretrained(
105
  # Lightning LoRA info (no global state)
106
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
107
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
 
108
 
109
  MAX_SEED = np.iinfo(np.int32).max
110
 
@@ -175,7 +176,9 @@ def update_selection(evt: gr.SelectData, aspect_ratio):
175
 
176
  def handle_speed_mode(speed_mode):
177
  """Update UI based on speed/quality toggle."""
178
- if speed_mode == "Speed (8 steps)":
 
 
179
  return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
180
  else:
181
  return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
@@ -225,7 +228,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
225
  pipe.unload_lora_weights()
226
 
227
  # Load LoRAs based on speed mode
228
- if speed_mode == "Speed (8 steps)":
229
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
230
  # Load Lightning LoRA first
231
  pipe.load_lora_weights(
@@ -243,6 +246,26 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
243
  adapter_name="style"
244
  )
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  # Set both adapters active with their weights
247
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
248
  else:
@@ -448,8 +471,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
448
  with gr.Row():
449
  speed_mode = gr.Radio(
450
  label="Generation Mode",
451
- choices=["Speed (8 steps)", "Quality (45 steps)"],
452
- value="Speed (8 steps)",
453
  info="Speed mode uses Lightning LoRA for faster generation"
454
  )
455
 
@@ -521,7 +544,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
521
 
522
  app.load(
523
  fn=handle_speed_mode,
524
- inputs=[gr.State("Speed (8 steps)")],
525
  outputs=[speed_status, steps, cfg_scale]
526
  )
527
 
 
105
  # Lightning LoRA info (no global state)
106
  LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
107
  LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
108
+ LIGHTNING8_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors"
109
 
110
  MAX_SEED = np.iinfo(np.int32).max
111
 
 
176
 
177
  def handle_speed_mode(speed_mode):
178
  """Update UI based on speed/quality toggle."""
179
+ if speed_mode == "Speed (4 steps)":
180
+ return gr.update(value="Speed mode selected - 4 steps with Lightning LoRA"), 4, 1.0
181
+ elif speed_mode == "Speed (8 steps)":
182
  return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
183
  else:
184
  return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
 
228
  pipe.unload_lora_weights()
229
 
230
  # Load LoRAs based on speed mode
231
+ if speed_mode == "Speed (4 steps)":
232
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
233
  # Load Lightning LoRA first
234
  pipe.load_lora_weights(
 
246
  adapter_name="style"
247
  )
248
 
249
+ # Set both adapters active with their weights
250
+ pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
251
+ elif speed_mode == "Speed (8 steps)":
252
+ with calculateDuration("Loading Lightning LoRA and style LoRA"):
253
+ # Load Lightning LoRA first
254
+ pipe.load_lora_weights(
255
+ LIGHTNING_LORA_REPO,
256
+ weight_name=LIGHTNING8_LORA_WEIGHT,
257
+ adapter_name="lightning"
258
+ )
259
+
260
+ # Load the selected style LoRA
261
+ weight_name = selected_lora.get("weights", None)
262
+ pipe.load_lora_weights(
263
+ lora_path,
264
+ weight_name=weight_name,
265
+ low_cpu_mem_usage=True,
266
+ adapter_name="style"
267
+ )
268
+
269
  # Set both adapters active with their weights
270
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
271
  else:
 
471
  with gr.Row():
472
  speed_mode = gr.Radio(
473
  label="Generation Mode",
474
+ choices=["Speed (4 steps)", "Speed (8 steps)", "Quality (45 steps)"],
475
+ value="Speed (4 steps)",
476
  info="Speed mode uses Lightning LoRA for faster generation"
477
  )
478
 
 
544
 
545
  app.load(
546
  fn=handle_speed_mode,
547
+ inputs=[gr.State("Speed (4 steps)")],
548
  outputs=[speed_status, steps, cfg_scale]
549
  )
550