lichorosario's picture
Update app.py
7c466dd verified
raw
history blame
9.09 kB
import os
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
import copy
import random
import time
import re
import math
import numpy as np
import traceback
# Load LoRAs from JSON file
def load_loras_from_file(path="loras.json"):
    """Load the LoRA catalogue from a JSON file.

    The file must contain a JSON array; the app indexes the result by integer
    position (e.g. ``loras[selected_index]``), so each element is one LoRA entry.

    Args:
        path: Location of the JSON file. Defaults to ``"loras.json"``, resolved
            against the current working directory as before.

    Returns:
        The parsed list of LoRA entries, or an empty list when the file is
        missing, is not valid JSON, or does not hold a top-level array.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Warning: {path} file not found. Using empty list.")
        return []
    except json.JSONDecodeError as e:
        print(f"Error parsing {path}: {e}")
        return []
    # Callers index this positionally; a top-level object (dict) would only
    # fail later with a confusing KeyError, so reject it here.
    if not isinstance(data, list):
        print(f"Error: {path} must contain a JSON array, got {type(data).__name__}. Using empty list.")
        return []
    return data
# Load the LoRAs
# Module-level list of LoRA entries read from loras.json; empty if the file is
# missing or cannot be parsed (see load_loras_from_file).
loras = load_loras_from_file()
# Initialize the base model
# The pipeline is loaded in bfloat16 and placed on CUDA when available,
# otherwise on CPU.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "Qwen/Qwen-Image"
# Scheduler configuration from the Qwen-Image-Lightning repository
# NOTE(review): base_shift and max_shift are both log(3), so the dynamic
# (resolution-dependent) shift presumably resolves to the same value at every
# image size — confirm against the pipeline's shift computation.
scheduler_config = {
"base_image_seq_len": 256,
"base_shift": math.log(3),
"invert_sigmas": False,
"max_image_seq_len": 8192,
"max_shift": math.log(3),
"num_train_timesteps": 1000,
"shift": 1.0,
"shift_terminal": None,
"stochastic_sampling": False,
"time_shift_type": "exponential",
"use_beta_sigmas": False,
"use_dynamic_shifting": True,
"use_exponential_sigmas": False,
"use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# Loads the base pipeline at import time with the custom scheduler above and
# moves it to `device`.
pipe = DiffusionPipeline.from_pretrained(
base_model, scheduler=scheduler, torch_dtype=dtype
).to(device)
# Lightning LoRA info (no global state)
# Only the repo id and weight filenames live here; run_lora loads one of these
# as the "lightning" adapter when a speed mode is selected.
LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
# Weight file used for the "Speed (4 steps)" mode.
LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors"
# Weight file used for the "Speed (8 steps)" mode.
LIGHTNING8_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors"
# Largest signed 32-bit integer; upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Example::

        with calculateDuration("Generating image"):
            ...

    After the block exits, ``elapsed_time`` holds the duration in seconds.
    Exceptions raised inside the block are not suppressed.
    """

    def __init__(self, activity_name=""):
        # Label included in the printed message; empty means "no label".
        self.activity_name = activity_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the system clock, which can be adjusted (e.g. by NTP) mid-measurement
        # and yield wrong or even negative durations.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Note: the `traceback` parameter shadows the traceback module only
        # inside this method; the name is kept for interface compatibility.
        self.end_time = time.perf_counter()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # A falsy return lets any exception from the body propagate.
        return False
# Output size (width, height) in pixels for each supported aspect-ratio label.
_ASPECT_RATIO_SIZES = {
    "1:1": (1024, 1024),
    "2:1": (1280, 640),
    "16:9": (1152, 640),
    "9:16": (640, 1152),
    "4:3": (1024, 768),
    "3:4": (768, 1024),
    "3:2": (1024, 688),
    "2:3": (688, 1024),
    "3:1": (1920, 640),
}
# Size used for any label not in the table above.
_DEFAULT_IMAGE_SIZE = (1024, 1024)


def get_image_size(aspect_ratio):
    """Convert an aspect-ratio label such as ``"16:9"`` to a ``(width, height)`` tuple.

    Unknown labels (including ``None``) fall back to ``(1024, 1024)``.
    """
    return _ASPECT_RATIO_SIZES.get(aspect_ratio, _DEFAULT_IMAGE_SIZE)
def update_selection(evt: gr.SelectData, aspect_ratio):
    """Handle a click on a LoRA in the gallery.

    Args:
        evt: Gradio selection event; ``evt.index`` is the position of the
            clicked item in ``loras``.
        aspect_ratio: The currently selected aspect-ratio label.

    Returns:
        A 4-tuple for the Gradio outputs: an update giving the prompt box a new
        placeholder, the markdown "Selected" header, the selected index, and
        the aspect-ratio label (overridden if the LoRA entry specifies one).
    """
    selected_lora = loras[evt.index]
    new_placeholder = f"Type a prompt for {selected_lora['title']}"
    lora_repo = selected_lora["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨"

    # A LoRA entry may force an aspect ratio, either as a named orientation or
    # as a literal label (e.g. "3:2") that is passed through unchanged.
    named_orientations = {"portrait": "9:16", "landscape": "16:9", "square": "1:1"}
    if "aspect" in selected_lora:
        aspect = selected_lora["aspect"]
        aspect_ratio = named_orientations.get(aspect, aspect)

    return (
        gr.update(placeholder=new_placeholder),
        updated_text,
        evt.index,
        aspect_ratio,
    )
def handle_speed_mode(speed_mode):
    """Translate the speed/quality radio choice into UI updates.

    Returns a 3-tuple: an update for the status text, the suggested number of
    inference steps, and the suggested CFG scale. Anything other than the two
    "Speed" labels selects the quality preset.
    """
    # (label, status message, steps, cfg scale) for the Lightning speed modes.
    speed_presets = (
        ("Speed (4 steps)", "Speed mode selected - 4 steps with Lightning LoRA", 4, 1.0),
        ("Speed (8 steps)", "Speed mode selected - 8 steps with Lightning LoRA", 8, 1.0),
    )
    for label, message, step_count, guidance in speed_presets:
        if speed_mode == label:
            return gr.update(value=message), step_count, guidance
    return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
@spaces.GPU(duration=70)
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
    """Run the pipeline once and return the first generated image.

    LoRA adapters and their weights must already be activated on ``pipe`` by
    the caller; ``lora_scale`` is accepted for interface compatibility but is
    not used here.

    Args:
        prompt_mash: Final prompt text (trigger word already applied).
        steps: Number of inference steps; cast to ``int``.
        seed: Seed for the CUDA random generator; cast to ``int``.
        cfg_scale: Passed to the pipeline as ``true_cfg_scale``.
        width: Output width in pixels.
        height: Output height in pixels.
        lora_scale: Unused (see above).
        negative_prompt: Negative prompt text; empty by default.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Presumably the @spaces.GPU decorator makes a CUDA device available for
    # the duration of this call, hence the explicit "cuda" placement.
    pipe.to("cuda")
    # UI sliders may deliver floats; manual_seed rejects floats and the step
    # count must be an integer.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    with calculateDuration("Generating image"):
        image = pipe(
            prompt=prompt_mash,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            true_cfg_scale=cfg_scale,  # Use true_cfg_scale for Qwen-Image
            width=width,
            height=height,
            generator=generator,
        ).images[0]
    return image
# Lightning LoRA weight file required by each "speed" mode; any other mode
# (the quality preset) runs with the style LoRA alone.
_LIGHTNING_WEIGHT_BY_SPEED_MODE = {
    "Speed (4 steps)": LIGHTNING_LORA_WEIGHT,
    "Speed (8 steps)": LIGHTNING8_LORA_WEIGHT,
}


def _apply_trigger_word(prompt, selected_lora):
    """Return ``prompt`` with the LoRA entry's trigger word added, if it has one."""
    trigger_word = selected_lora.get("trigger_word")
    if not trigger_word:
        return prompt
    # Entries without "trigger_position" prepend; any value other than
    # "prepend" appends the trigger word instead.
    if selected_lora.get("trigger_position", "prepend") == "prepend":
        return f"{trigger_word} {prompt}"
    return f"{prompt} {trigger_word}"


def _load_lora_adapters(selected_lora, speed_mode, lora_scale):
    """Unload all adapters, then load and activate the ones ``speed_mode`` needs."""
    # Always unload any existing LoRAs first to avoid conflicts
    with calculateDuration("Unloading existing LoRAs"):
        pipe.unload_lora_weights()

    lora_path = selected_lora["repo"]
    weight_name = selected_lora.get("weights", None)
    lightning_weight = _LIGHTNING_WEIGHT_BY_SPEED_MODE.get(speed_mode)

    if lightning_weight is not None:
        with calculateDuration("Loading Lightning LoRA and style LoRA"):
            pipe.load_lora_weights(
                LIGHTNING_LORA_REPO,
                weight_name=lightning_weight,
                adapter_name="lightning",
            )
            pipe.load_lora_weights(
                lora_path,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                adapter_name="style",
            )
            # Lightning adapter at weight 1.0; style adapter at the user's scale.
            pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
    else:
        title = selected_lora.get("title", lora_path)
        with calculateDuration(f"Loading LoRA weights for {title}"):
            pipe.load_lora_weights(
                lora_path,
                weight_name=weight_name,
                low_cpu_mem_usage=True,
                adapter_name="style",
            )
            pipe.set_adapters(["style"], adapter_weights=[lora_scale])


@spaces.GPU(duration=70)
def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, aspect_ratio, lora_scale, speed_mode, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRA.

    Args:
        prompt: User prompt; the LoRA's trigger word is added if it has one.
        cfg_scale: Passed through to the pipeline as ``true_cfg_scale``.
        steps: Number of inference steps.
        selected_index: Position of the chosen entry in ``loras``.
        randomize_seed: When truthy, ``seed`` is replaced by a random value.
        seed: Seed used when not randomizing.
        aspect_ratio: Aspect-ratio label, converted via ``get_image_size``.
        lora_scale: Adapter weight applied to the style LoRA.
        speed_mode: "Speed (4 steps)", "Speed (8 steps)", or anything else
            for quality mode (style LoRA only).
        progress: Gradio progress tracker; not referenced directly
            (presumably it mirrors tqdm bars via ``track_tqdm``).

    Returns:
        ``(image, seed)`` — the generated image and the seed actually used.

    Raises:
        gr.Error: If no LoRA has been selected.
    """
    if selected_index is None:
        raise gr.Error("You must select a LoRA before proceeding.")
    selected_lora = loras[selected_index]
    prompt_mash = _apply_trigger_word(prompt, selected_lora)

    _load_lora_adapters(selected_lora, speed_mode, lora_scale)

    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    width, height = get_image_size(aspect_ratio)
    final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
    return final_image, seed
# (rest of the code: Gradio interface, etc.)