lichorosario committed
Commit 7c466dd · verified · 1 Parent(s): 5244c15

Update app.py

Files changed (1)
1. app.py +58 -164
app.py CHANGED
@@ -113,11 +113,13 @@ def update_selection(evt: gr.SelectData, aspect_ratio):
     lora_repo = selected_lora["repo"]
     updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

+    # Get model card examples
     examples_list = []
     try:
         model_card = ModelCard.load(lora_repo)
         widget_data = model_card.data.get("widget", [])
         if widget_data and len(widget_data) > 0:
+            # Get examples from widget data
             for example in widget_data[:4]:
                 if "output" in example and "url" in example["output"]:
                     image_url = f"https://huggingface.co/{lora_repo}/resolve/main/{example['output']['url']}"
@@ -126,6 +128,7 @@ def update_selection(evt: gr.SelectData, aspect_ratio):
     except Exception as e:
         print(f"Could not load model card for {lora_repo}: {e}")

+    # Update aspect ratio if specified in LoRA config
     if "aspect" in selected_lora:
         if selected_lora["aspect"] == "portrait":
             aspect_ratio = "9:16"
@@ -158,11 +161,12 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scal
     generator = torch.Generator(device="cuda").manual_seed(seed)

     with calculateDuration("Generating image"):
+        # Generate image
         image = pipe(
             prompt=prompt_mash,
             negative_prompt=negative_prompt,
             num_inference_steps=steps,
-            true_cfg_scale=cfg_scale,
+            true_cfg_scale=cfg_scale,  # Use true_cfg_scale for Qwen-Image
             width=width,
             height=height,
             generator=generator,
@@ -171,14 +175,15 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
     return image

 @spaces.GPU(duration=70)
-def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_count, progress=gr.Progress(track_tqdm=True)):
+def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
     if selected_index is None:
         raise gr.Error("You must select a LoRA before proceeding.")

     selected_lora = loras[selected_index]
     lora_path = selected_lora["repo"]
     trigger_word = selected_lora["trigger_word"]
-
+
+    # Prepare prompt with trigger word
     if trigger_word:
         if "trigger_position" in selected_lora:
             if selected_lora["trigger_position"] == "prepend":
@@ -190,174 +195,63 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
     else:
         prompt_mash = prompt

+    # Always unload any existing LoRAs first to avoid conflicts
     with calculateDuration("Unloading existing LoRAs"):
         pipe.unload_lora_weights()

+    # Load LoRAs based on speed mode
     if speed_mode == "Speed (4 steps)":
-        pipe.load_lora_weights(LIGHTNING_LORA_REPO, weight_name=LIGHTNING_LORA_WEIGHT, adapter_name="lightning")
-        weight_name = selected_lora.get("weights", None)
-        pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name="style")
-        pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
+        with calculateDuration("Loading Lightning LoRA and style LoRA"):
+            pipe.load_lora_weights(
+                LIGHTNING_LORA_REPO,
+                weight_name=LIGHTNING_LORA_WEIGHT,
+                adapter_name="lightning"
+            )
+            weight_name = selected_lora.get("weights", None)
+            pipe.load_lora_weights(
+                lora_path,
+                weight_name=weight_name,
+                low_cpu_mem_usage=True,
+                adapter_name="style"
+            )
+            pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
     elif speed_mode == "Speed (8 steps)":
-        pipe.load_lora_weights(LIGHTNING_LORA_REPO, weight_name=LIGHTNING8_LORA_WEIGHT, adapter_name="lightning")
-        weight_name = selected_lora.get("weights", None)
-        pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name="style")
-        pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
-    else:
-        weight_name = selected_lora.get("weights", None)
-        pipe.load_lora_weights(lora_path, weight_name=weight_name, adapter_name="style")
-        pipe.set_adapters(["style"], adapter_weights=[lora_scale])
-
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-
-    width, height = get_image_size(aspect_ratio)
-
-    # ✅ Safe parameter validation
-    num_images = int(image_count) if image_count and str(image_count).isdigit() else 1
-    seed_offsets = [i * 100 for i in range(num_images)]
-    images = []
-
-    for offset in seed_offsets:
-        current_seed = (seed + offset) % MAX_SEED
-        img = generate_image(prompt_mash, steps, current_seed, cfg_scale, width, height, lora_scale)
-        images.append(img)
-
-    if num_images == 1:
-        return images[0], seed
+        with calculateDuration("Loading Lightning LoRA and style LoRA"):
+            pipe.load_lora_weights(
+                LIGHTNING_LORA_REPO,
+                weight_name=LIGHTNING8_LORA_WEIGHT,
+                adapter_name="lightning"
+            )
+            weight_name = selected_lora.get("weights", None)
+            pipe.load_lora_weights(
+                lora_path,
+                weight_name=weight_name,
+                low_cpu_mem_usage=True,
+                adapter_name="style"
+            )
+            pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
     else:
-        grid_width = min(2, num_images)
-        grid_height = math.ceil(num_images / grid_width)
-        new_img = Image.new("RGB", (width * grid_width, height * grid_height))
-        for i, img in enumerate(images):
-            x = (i % grid_width) * width
-            y = (i // grid_width) * height
-            new_img.paste(img, (x, y))
-        return new_img, seed
-
-# --- UI ---
-css = '''
-#gen_btn{height: 100%}
-#gen_column{align-self: stretch}
-#title{text-align: center}
-#title h1{font-size: 3em; display:inline-flex; align-items:center}
-#title img{width: 100px; margin-right: 0.5em}
-#gallery .grid-wrap{height: 10vh}
-#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
-.card_internal{display: flex;height: 100px;margin-top: .5em}
-.card_internal img{margin-right: 1em}
-.styler{--form-gap-width: 0px !important}
-#speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
-'''
-
-with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
-    title = gr.HTML(
-        """<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png" alt="Qwen-Image" style="width: 280px; margin: 0 auto">
-        <h3 style="margin-top: -10px">LoRA🦜 ChoquinLabs Explorer</h3>""",
-        elem_id="title",
-    )
-
-    selected_index = gr.State(None)
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
-        with gr.Column(scale=1, elem_id="gen_column"):
-            with gr.Group():
-                generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
-                image_count = gr.Radio(
-                    label="Number of images",
-                    choices=["1", "2", "3", "4"],
-                    value="1",
-                    info="How many images to generate simultaneously"
-                )
-
-    with gr.Row():
-        with gr.Column():
-            selected_info = gr.Markdown("")
-            gallery = gr.Gallery(
-                [(item["image"], item["title"]) for item in loras],
-                label="LoRA Gallery",
-                allow_preview=False,
-                columns=3,
-                elem_id="gallery",
-                show_share_button=False
+        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+            weight_name = selected_lora.get("weights", None)
+            pipe.load_lora_weights(
+                lora_path,
+                weight_name=weight_name,
+                low_cpu_mem_usage=True,
+                adapter_name="style"
             )
-            with gr.Group():
-                custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
-                gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
-                custom_lora_info = gr.HTML(visible=False)
-                custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
-
-        with gr.Column():
-            result = gr.Image(label="Generated Image")
-
-    with gr.Row():
-        speed_mode = gr.Radio(
-            label="Generation Mode",
-            choices=["Speed (4 steps)", "Speed (8 steps)", "Quality (45 steps)"],
-            value="Speed (4 steps)",
-            info="Speed mode uses Lightning LoRA for faster generation"
-        )
-
-    speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")
-
-    with gr.Row():
-        with gr.Accordion("Advanced Settings", open=False):
-            with gr.Column():
-                with gr.Row():
-                    aspect_ratio = gr.Radio(
-                        label="Aspect Ratio",
-                        choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "3:1", "2:1"],
-                        value="16:9"
-                    )
-                with gr.Row():
-                    cfg_scale = gr.Slider(
-                        label="Guidance Scale (True CFG)",
-                        minimum=1.0,
-                        maximum=5.0,
-                        step=0.1,
-                        value=3.5,
-                        info="Lower for speed mode, higher for quality"
-                    )
-                    steps = gr.Slider(
-                        label="Steps",
-                        minimum=4,
-                        maximum=50,
-                        step=1,
-                        value=45,
-                        info="Automatically set by speed mode"
-                    )
-                with gr.Row():
-                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
-                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-                    lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)
-
-    # Event handlers
-    gallery.select(
-        update_selection,
-        inputs=[aspect_ratio],
-        outputs=[prompt, selected_info, selected_index, aspect_ratio]
-    )
-
-    speed_mode.change(
-        handle_speed_mode,
-        inputs=[speed_mode],
-        outputs=[speed_status, steps, cfg_scale]
-    )
-
-    gr.on(
-        triggers=[generate_button.click, prompt.submit],
-        fn=run_lora,
-        inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, image_count],
-        outputs=[result, seed]
-    )
-
-    app.load(
-        fn=handle_speed_mode,
-        inputs=[gr.State("Speed (4 steps)")],
-        outputs=[speed_status, steps, cfg_scale]
-    )
-
-app.queue()
-app.launch()
+            pipe.set_adapters(["style"], adapter_weights=[lora_scale])
+
+    # Set random seed for reproducibility
+    with calculateDuration("Randomizing seed"):
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+
+    # Get image dimensions from aspect ratio
+    width, height = get_image_size(aspect_ratio)
+
+    # Generate the image
+    final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
+
+    return final_image, seed
+
+# (rest of the code: Gradio interface, etc.)
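
For context on the refactor above: diffusers lets several LoRAs be loaded under named adapters and activated with independent weights, which is what the Speed modes do (a step-distilled Lightning LoRA stacked with the selected style LoRA). Below is a minimal sketch of that pattern, assuming the Qwen/Qwen-Image base model; the repo and file names are illustrative stand-ins for the LIGHTNING_LORA_REPO / LIGHTNING_LORA_WEIGHT constants defined earlier in app.py.

```python
import torch
from diffusers import DiffusionPipeline

# Illustrative stand-ins for constants defined near the top of app.py.
LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"                   # assumed repo
LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V1.0.safetensors"  # assumed file
STYLE_LORA_REPO = "username/qwen-image-custom-lora"                     # placeholder
lora_scale = 1.0

pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
).to("cuda")

# Start clean so adapters from a previous request cannot leak into this one
# (this is why run_lora always calls unload_lora_weights() first).
pipe.unload_lora_weights()

# Load the step-distillation LoRA and the style LoRA under distinct adapter
# names, then activate both with independent weights.
pipe.load_lora_weights(
    LIGHTNING_LORA_REPO, weight_name=LIGHTNING_LORA_WEIGHT, adapter_name="lightning"
)
pipe.load_lora_weights(STYLE_LORA_REPO, low_cpu_mem_usage=True, adapter_name="style")
pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
```

Keeping the Lightning adapter pinned at 1.0 while only the style weight tracks the UI's LoRA Scale slider matches how the commit wires lora_scale through set_adapters.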
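
The generation call is the other piece the diff touches: the Qwen-Image pipeline in diffusers exposes classifier-free guidance through the true_cfg_scale argument, which is why the commit restores that keyword in generate_image. A hedged sketch continuing from the pipeline above (prompt, resolution, and seed are illustrative, not values from the app):

```python
# A fixed-seed generator makes the output reproducible for a given seed,
# mirroring generate_image() above.
generator = torch.Generator(device="cuda").manual_seed(42)

image = pipe(
    prompt="a cozy cabin in a snowy forest, warm window light",
    negative_prompt="",
    num_inference_steps=4,   # few steps suffice once the Lightning adapter is active
    true_cfg_scale=1.0,      # true CFG; low values suit the distilled speed modes
    width=1664,              # illustrative 16:9 resolution
    height=928,
    generator=generator,
).images[0]
image.save("output.png")
```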