lichorosario commited on
Commit
7ffc60d
verified
1 Parent(s): 7c466dd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +297 -1
app.py CHANGED
@@ -202,11 +202,14 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
202
  # Load LoRAs based on speed mode
203
  if speed_mode == "Speed (4 steps)":
204
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
 
205
  pipe.load_lora_weights(
206
  LIGHTNING_LORA_REPO,
207
  weight_name=LIGHTNING_LORA_WEIGHT,
208
  adapter_name="lightning"
209
  )
 
 
210
  weight_name = selected_lora.get("weights", None)
211
  pipe.load_lora_weights(
212
  lora_path,
@@ -214,14 +217,19 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
214
  low_cpu_mem_usage=True,
215
  adapter_name="style"
216
  )
 
 
217
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
218
  elif speed_mode == "Speed (8 steps)":
219
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
 
220
  pipe.load_lora_weights(
221
  LIGHTNING_LORA_REPO,
222
  weight_name=LIGHTNING8_LORA_WEIGHT,
223
  adapter_name="lightning"
224
  )
 
 
225
  weight_name = selected_lora.get("weights", None)
226
  pipe.load_lora_weights(
227
  lora_path,
@@ -229,8 +237,11 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
229
  low_cpu_mem_usage=True,
230
  adapter_name="style"
231
  )
 
 
232
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
233
  else:
 
234
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
235
  weight_name = selected_lora.get("weights", None)
236
  pipe.load_lora_weights(
@@ -254,4 +265,289 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
254
 
255
  return final_image, seed
256
 
257
- # (resto del c贸digo con interfaz Gradio, etc.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  # Load LoRAs based on speed mode
203
  if speed_mode == "Speed (4 steps)":
204
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
205
+ # Load Lightning LoRA first
206
  pipe.load_lora_weights(
207
  LIGHTNING_LORA_REPO,
208
  weight_name=LIGHTNING_LORA_WEIGHT,
209
  adapter_name="lightning"
210
  )
211
+
212
+ # Load the selected style LoRA
213
  weight_name = selected_lora.get("weights", None)
214
  pipe.load_lora_weights(
215
  lora_path,
 
217
  low_cpu_mem_usage=True,
218
  adapter_name="style"
219
  )
220
+
221
+ # Set both adapters active with their weights
222
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
223
  elif speed_mode == "Speed (8 steps)":
224
  with calculateDuration("Loading Lightning LoRA and style LoRA"):
225
+ # Load Lightning LoRA first
226
  pipe.load_lora_weights(
227
  LIGHTNING_LORA_REPO,
228
  weight_name=LIGHTNING8_LORA_WEIGHT,
229
  adapter_name="lightning"
230
  )
231
+
232
+ # Load the selected style LoRA
233
  weight_name = selected_lora.get("weights", None)
234
  pipe.load_lora_weights(
235
  lora_path,
 
237
  low_cpu_mem_usage=True,
238
  adapter_name="style"
239
  )
240
+
241
+ # Set both adapters active with their weights
242
  pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
243
  else:
244
+ # Quality mode - only load the style LoRA
245
  with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
246
  weight_name = selected_lora.get("weights", None)
247
  pipe.load_lora_weights(
 
265
 
266
  return final_image, seed
267
 
268
+ def get_huggingface_safetensors(link):
269
+ split_link = link.split("/")
270
+ if len(split_link) != 2:
271
+ raise Exception("Invalid Hugging Face repository link format.")
272
+
273
+ print(f"Repository attempted: {split_link}")
274
+
275
+ # Load model card
276
+ model_card = ModelCard.load(link)
277
+ base_model = model_card.data.get("base_model")
278
+ print(f"Base model: {base_model}")
279
+
280
+ # Validate model type (for Qwen-Image)
281
+ acceptable_models = {"Qwen/Qwen-Image"}
282
+
283
+ models_to_check = base_model if isinstance(base_model, list) else [base_model]
284
+
285
+ if not any(model in acceptable_models for model in models_to_check):
286
+ raise Exception("Not a Qwen-Image LoRA!")
287
+
288
+ # Extract image and trigger word
289
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
290
+ trigger_word = model_card.data.get("instance_prompt", "")
291
+ image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
292
+
293
+ # Initialize Hugging Face file system
294
+ fs = HfFileSystem()
295
+ try:
296
+ list_of_files = fs.ls(link, detail=False)
297
+
298
+ # Find safetensors file
299
+ safetensors_name = None
300
+ for file in list_of_files:
301
+ filename = file.split("/")[-1]
302
+ if filename.endswith(".safetensors"):
303
+ safetensors_name = filename
304
+ break
305
+
306
+ if not safetensors_name:
307
+ raise Exception("No valid *.safetensors file found in the repository.")
308
+
309
+ except Exception as e:
310
+ print(e)
311
+ raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
312
+
313
+ return split_link[1], link, safetensors_name, trigger_word, image_url
314
+
315
+ def check_custom_model(link):
316
+ print(f"Checking a custom model on: {link}")
317
+
318
+ if link.endswith('.safetensors'):
319
+ if 'huggingface.co' in link:
320
+ parts = link.split('/')
321
+ try:
322
+ hf_index = parts.index('huggingface.co')
323
+ username = parts[hf_index + 1]
324
+ repo_name = parts[hf_index + 2]
325
+ repo = f"{username}/{repo_name}"
326
+
327
+ safetensors_name = parts[-1]
328
+
329
+ try:
330
+ model_card = ModelCard.load(repo)
331
+ trigger_word = model_card.data.get("instance_prompt", "")
332
+ image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
333
+ image_url = f"https://huggingface.co/{repo}/resolve/main/{image_path}" if image_path else None
334
+ except:
335
+ trigger_word = ""
336
+ image_url = None
337
+
338
+ return repo_name, repo, safetensors_name, trigger_word, image_url
339
+ except:
340
+ raise Exception("Invalid safetensors URL format")
341
+
342
+ if link.startswith("https://"):
343
+ if link.startswith("https://huggingface.co") or link.startswith("https://www.huggingface.co"):
344
+ link_split = link.split("huggingface.co/")
345
+ return get_huggingface_safetensors(link_split[1])
346
+ else:
347
+ return get_huggingface_safetensors(link)
348
+
349
+ def add_custom_lora(custom_lora):
350
+ global loras
351
+ if custom_lora:
352
+ try:
353
+ title, repo, path, trigger_word, image = check_custom_model(custom_lora)
354
+ print(f"Loaded custom LoRA: {repo}")
355
+
356
+ # Get model card examples for custom LoRA
357
+ model_card_examples = ""
358
+ try:
359
+ model_card = ModelCard.load(repo)
360
+ widget_data = model_card.data.get("widget", [])
361
+ if widget_data and len(widget_data) > 0:
362
+ examples_html = '<div style="margin-top: 10px;">'
363
+ examples_html += '<h4 style="margin-bottom: 8px; font-size: 0.9em;">Sample Images:</h4>'
364
+ examples_html += '<div style="display: grid; grid-template-columns: repeat(4, 1fr); gap: 8px;">'
365
+
366
+ for i, example in enumerate(widget_data[:4]):
367
+ if "output" in example and "url" in example["output"]:
368
+ image_url = f"https://huggingface.co/{repo}/resolve/main/{example['output']['url']}"
369
+ caption = example.get("text", f"Example {i+1}")
370
+ examples_html += f'''
371
+ <div style="text-align: center;">
372
+ <img src="{image_url}" style="width: 100%; height: auto; border-radius: 4px;" />
373
+ <p style="font-size: 0.7em; margin: 2px 0;">{caption[:30]}{'...' if len(caption) > 30 else ''}</p>
374
+ </div>
375
+ '''
376
+
377
+ examples_html += '</div></div>'
378
+ model_card_examples = examples_html
379
+ except Exception as e:
380
+ print(f"Could not load model card examples for custom LoRA: {e}")
381
+
382
+ card = f'''
383
+ <div class="custom_lora_card">
384
+ <span>Loaded custom LoRA:</span>
385
+ <div class="card_internal">
386
+ <img src="{image}" />
387
+ <div>
388
+ <h3>{title}</h3>
389
+ <small>{"Using: <code><b>"+trigger_word+"</code></b> as the trigger word" if trigger_word else "No trigger word found. If there's a trigger word, include it in your prompt"}<br></small>
390
+ </div>
391
+ </div>
392
+ {model_card_examples}
393
+ </div>
394
+ '''
395
+ existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo), None)
396
+ if existing_item_index is None:
397
+ new_item = {
398
+ "image": image,
399
+ "title": title,
400
+ "repo": repo,
401
+ "weights": path,
402
+ "trigger_word": trigger_word
403
+ }
404
+ print(new_item)
405
+ loras.append(new_item)
406
+ existing_item_index = len(loras) - 1 # Get the actual index after adding
407
+
408
+ return gr.update(visible=True, value=card), gr.update(visible=True), gr.Gallery(selected_index=None), f"Custom: {path}", existing_item_index, trigger_word
409
+ except Exception as e:
410
+ full_traceback = traceback.format_exc()
411
+ print(f"Full traceback:\n{full_traceback}")
412
+ gr.Warning(f"Invalid LoRA: either you entered an invalid link, or a non-Qwen-Image LoRA, this was the issue: {e}")
413
+ return gr.update(visible=True, value=f"Invalid LoRA: either you entered an invalid link, a non-Qwen-Image LoRA"), gr.update(visible=True), gr.update(), "", None, ""
414
+ else:
415
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
416
+
417
+ def remove_custom_lora():
418
+ return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""
419
+
420
+ run_lora.zerogpu = True
421
+
422
+ css = '''
423
+ #gen_btn{height: 100%}
424
+ #gen_column{align-self: stretch}
425
+ #title{text-align: center}
426
+ #title h1{font-size: 3em; display:inline-flex; align-items:center}
427
+ #title img{width: 100px; margin-right: 0.5em}
428
+ #gallery .grid-wrap{height: 10vh}
429
+ #lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
430
+ .card_internal{display: flex;height: 100px;margin-top: .5em}
431
+ .card_internal img{margin-right: 1em}
432
+ .styler{--form-gap-width: 0px !important}
433
+ #speed_status{padding: .5em; border-radius: 5px; margin: 1em 0}
434
+ '''
435
+
436
+ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
437
+ title = gr.HTML(
438
+ """<img src=\"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png\" alt=\"Qwen-Image\" style=\"width: 280px; margin: 0 auto\">
439
+ <h3 style=\"margin-top: -10px\">LoRA馃 ChoquinLabs Explorer</h3>""",
440
+ elem_id="title",
441
+ )
442
+
443
+ selected_index = gr.State(None)
444
+
445
+ with gr.Row():
446
+ with gr.Column(scale=3):
447
+ prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
448
+ with gr.Column(scale=1, elem_id="gen_column"):
449
+ generate_button = gr.Button("Generate", variant="primary", elem_id="gen_btn")
450
+
451
+ with gr.Row():
452
+ with gr.Column():
453
+ selected_info = gr.Markdown("")
454
+ examples_component = gr.Examples(examples=[], inputs=[prompt], label="Sample Prompts", visible=False)
455
+ gallery = gr.Gallery(
456
+ [(item["image"], item["title"]) for item in loras],
457
+ label="LoRA Gallery",
458
+ allow_preview=False,
459
+ columns=3,
460
+ elem_id="gallery",
461
+ show_share_button=False
462
+ )
463
+ with gr.Group():
464
+ custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path", placeholder="username/qwen-image-custom-lora")
465
+ gr.Markdown("[Check Qwen-Image LoRAs](https://huggingface.co/models?other=base_model:adapter:Qwen/Qwen-Image)", elem_id="lora_list")
466
+ custom_lora_info = gr.HTML(visible=False)
467
+ custom_lora_button = gr.Button("Remove custom LoRA", visible=False)
468
+
469
+ with gr.Column():
470
+ result = gr.Image(label="Generated Image")
471
+
472
+ with gr.Row():
473
+ speed_mode = gr.Radio(
474
+ label="Generation Mode",
475
+ choices=["Speed (4 steps)", "Speed (8 steps)", "Quality (45 steps)"],
476
+ value="Speed (4 steps)",
477
+ info="Speed mode uses Lightning LoRA for faster generation"
478
+ )
479
+
480
+ speed_status = gr.Markdown("Quality mode active", elem_id="speed_status")
481
+
482
+ with gr.Row():
483
+ with gr.Accordion("Advanced Settings", open=False):
484
+ with gr.Column():
485
+ with gr.Row():
486
+ aspect_ratio = gr.Radio(
487
+ label="Aspect Ratio",
488
+ choices=["1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3", "3:1", "2:1"],
489
+ value="16:9"
490
+ )
491
+
492
+ with gr.Row():
493
+ cfg_scale = gr.Slider(
494
+ label="Guidance Scale (True CFG)",
495
+ minimum=1.0,
496
+ maximum=5.0,
497
+ step=0.1,
498
+ value=3.5,
499
+ info="Lower for speed mode, higher for quality"
500
+ )
501
+ steps = gr.Slider(
502
+ label="Steps",
503
+ minimum=4,
504
+ maximum=50,
505
+ step=1,
506
+ value=45,
507
+ info="Automatically set by speed mode"
508
+ )
509
+
510
+ with gr.Row():
511
+ randomize_seed = gr.Checkbox(True, label="Randomize seed")
512
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
513
+ lora_scale = gr.Slider(label="LoRA Scale", minimum=0, maximum=3, step=0.01, value=1.0)
514
+
515
+ # Event handlers
516
+ gallery.select(
517
+ update_selection,
518
+ inputs=[aspect_ratio],
519
+ outputs=[prompt, selected_info, selected_index, aspect_ratio]
520
+ )
521
+
522
+ speed_mode.change(
523
+ handle_speed_mode,
524
+ inputs=[speed_mode],
525
+ outputs=[speed_status, steps, cfg_scale]
526
+ )
527
+
528
+ custom_lora.input(
529
+ add_custom_lora,
530
+ inputs=[custom_lora],
531
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, prompt]
532
+ )
533
+
534
+ custom_lora_button.click(
535
+ remove_custom_lora,
536
+ outputs=[custom_lora_info, custom_lora_button, gallery, selected_info, selected_index, custom_lora]
537
+ )
538
+
539
+ gr.on(
540
+ triggers=[generate_button.click, prompt.submit],
541
+ fn=run_lora,
542
+ inputs=[prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode],
543
+ outputs=[result, seed]
544
+ )
545
+
546
+ app.load(
547
+ fn=handle_speed_mode,
548
+ inputs=[gr.State("Speed (4 steps)")],
549
+ outputs=[speed_status, steps, cfg_scale]
550
+ )
551
+
552
+ app.queue()
553
+ app.launch()