Spaces: Running on Zero

Commit 0dcb062 · Parent(s): 866d9b3
Add support for pruned model and update generation modes in app.py

Files changed:
- .gitignore (+8, -0)
- app.py (+30, -13)
.gitignore (ADDED)

@@ -0,0 +1,8 @@
+# Virtual Environments
+venv/
+env/
+.venv/
+
+# Python cache
+__pycache__/
+*.pyc
app.py (CHANGED)
@@ -89,6 +89,12 @@ pipe = DiffusionPipeline.from_pretrained(
     base_model, scheduler=scheduler, torch_dtype=dtype
 ).to(device)
 
+# Pruned model
+pruned_model = "OPPOer/Qwen-Image-Pruning"
+pruned_pipe = DiffusionPipeline.from_pretrained(
+    pruned_model, scheduler=scheduler, torch_dtype=dtype
+).to(device)
+
 # Lightning LoRA info (no global state)
 LIGHTNING_LORA_REPO = "lightx2v/Qwen-Image-Lightning"
 LIGHTNING_LORA_WEIGHT = "Qwen-Image-Lightning-8steps-V1.0.safetensors"
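This hunk materializes a second pipeline for the pruned checkpoint alongside the base one, sharing the same scheduler and dtype, so a later mode switch is a variable assignment rather than a model reload. A minimal sketch of the pattern, assuming base_model, scheduler, dtype, and device are defined earlier in app.py as the context lines suggest:

from diffusers import DiffusionPipeline

# Base pipeline (already present before this commit).
pipe = DiffusionPipeline.from_pretrained(
    base_model, scheduler=scheduler, torch_dtype=dtype
).to(device)

# New: a second, smaller checkpoint kept resident next to the base one.
pruned_pipe = DiffusionPipeline.from_pretrained(
    "OPPOer/Qwen-Image-Pruning", scheduler=scheduler, torch_dtype=dtype
).to(device)

The trade-off is that both models occupy memory for the lifetime of the Space; the win is that switching modes never pays a model-load penalty mid-request.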
@@ -162,17 +168,19 @@ def handle_speed_mode(speed_mode):
     """Update UI based on speed/quality toggle."""
     if speed_mode == "Speed (8 steps)":
         return gr.update(value="Speed mode selected - 8 steps with Lightning LoRA"), 8, 1.0
+    elif speed_mode == "Prune (8 steps)":
+        return gr.update(value="Prune mode selected - 8 steps with Pruned Model"), 8, 1.0
     else:
         return gr.update(value="Quality mode selected - 45 steps for best quality"), 45, 3.5
 
 @spaces.GPU(duration=70)
-def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
-
+def generate_image(current_pipe, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, negative_prompt=""):
+    current_pipe.to("cuda")
     generator = torch.Generator(device="cuda").manual_seed(seed)
 
     with calculateDuration("Generating image"):
         # Generate image
-        image = pipe(
+        image = current_pipe(
             prompt=prompt_mash,
             negative_prompt=negative_prompt,
             num_inference_steps=steps,
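Two things change in generate_image: the pipeline becomes an explicit parameter instead of a module-level global, and it is moved to CUDA inside the @spaces.GPU-decorated body, which is the only place ZeroGPU actually provides a GPU. A hedged sketch of the seeding pattern; the call's remaining keyword arguments are truncated in the diff, so only the visible ones appear here:

import torch

def generate_image_sketch(current_pipe, prompt_mash, steps, seed, negative_prompt=""):
    current_pipe.to("cuda")  # ZeroGPU: the device exists only inside the decorated call
    # A seeded generator makes a given (pipeline, seed) pair reproducible.
    generator = torch.Generator(device="cuda").manual_seed(seed)
    result = current_pipe(
        prompt=prompt_mash,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        generator=generator,
        # width, height, and the cfg kwarg are cut off in the diff and omitted here
    )
    return result.images[0]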
@@ -205,15 +213,21 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
     else:
         prompt_mash = prompt
 
+    # Select the pipeline based on the mode
+    if speed_mode == "Prune (8 steps)":
+        current_pipe = pruned_pipe
+    else:
+        current_pipe = pipe
+
     # Always unload any existing LoRAs first to avoid conflicts
     with calculateDuration("Unloading existing LoRAs"):
-        pipe.unload_lora_weights()
+        current_pipe.unload_lora_weights()
 
     # Load LoRAs based on speed mode
     if speed_mode == "Speed (8 steps)":
         with calculateDuration("Loading Lightning LoRA and style LoRA"):
             # Load Lightning LoRA first
-            pipe.load_lora_weights(
+            current_pipe.load_lora_weights(
                 LIGHTNING_LORA_REPO,
                 weight_name=LIGHTNING_LORA_WEIGHT,
                 adapter_name="lightning"
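The unload-before-load step matters because run_lora serves successive requests against shared pipeline objects: whatever adapters the previous request attached are still present. A sketch of the lifecycle this hunk sets up, using the diffusers LoRA API (lora_path and weight_name come from the selected style LoRA elsewhere in run_lora):

# Route the request to the right pipeline, then reset its adapter state.
current_pipe = pruned_pipe if speed_mode == "Prune (8 steps)" else pipe
current_pipe.unload_lora_weights()  # drop adapters left by the previous request

# Named adapters can then be loaded fresh without colliding.
current_pipe.load_lora_weights(
    LIGHTNING_LORA_REPO,
    weight_name=LIGHTNING_LORA_WEIGHT,
    adapter_name="lightning",
)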
@@ -221,7 +235,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
 
             # Load the selected style LoRA
             weight_name = selected_lora.get("weights", None)
-            pipe.load_lora_weights(
+            current_pipe.load_lora_weights(
                 lora_path,
                 weight_name=weight_name,
                 low_cpu_mem_usage=True,
@@ -229,18 +243,21 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
             )
 
             # Set both adapters active with their weights
-            pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
+            current_pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
     else:
-        # Quality mode - only load the style LoRA
-        with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
+        # Quality or Prune mode - only load the style LoRA
+        log_message = f"Loading LoRA weights for {selected_lora['title']}"
+        if speed_mode == "Prune (8 steps)":
+            log_message += " on Pruned Model"
+        with calculateDuration(log_message):
             weight_name = selected_lora.get("weights", None)
-            pipe.load_lora_weights(
+            current_pipe.load_lora_weights(
                 lora_path,
                 weight_name=weight_name,
                 low_cpu_mem_usage=True,
                 adapter_name="style"
             )
-            pipe.set_adapters(["style"], adapter_weights=[lora_scale])
+            current_pipe.set_adapters(["style"], adapter_weights=[lora_scale])
 
     # Set random seed for reproducibility
     with calculateDuration("Randomizing seed"):
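The net effect of this branch is that each mode ends with a different active-adapter set, applied through diffusers' set_adapters. Condensed, with names and weights taken straight from the hunk:

if speed_mode == "Speed (8 steps)":
    # Lightning at full strength, the style LoRA at the UI's lora_scale.
    current_pipe.set_adapters(["lightning", "style"], adapter_weights=[1.0, lora_scale])
else:
    # Quality and Prune modes run the style LoRA alone.
    current_pipe.set_adapters(["style"], adapter_weights=[lora_scale])

Prune mode reaches its 8-step schedule through the smaller checkpoint rather than the Lightning LoRA, so the two fast modes never share adapters.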
@@ -251,7 +268,7 @@ def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, asp
     width, height = get_image_size(aspect_ratio)
 
     # Generate the image
-    final_image = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
+    final_image = generate_image(current_pipe, prompt_mash, steps, seed, cfg_scale, width, height, lora_scale)
 
     return final_image, seed
 
@@ -433,7 +450,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as app:
         with gr.Row():
             speed_mode = gr.Radio(
                 label="Generation Mode",
-                choices=["Speed (8 steps)", "Quality (45 steps)"],
+                choices=["Speed (8 steps)", "Quality (45 steps)", "Prune (8 steps)"],
                 value="Speed (8 steps)",
                 info="Speed mode uses Lightning LoRA for faster generation"
             )
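handle_speed_mode returns three values (a status update, a step count, and a CFG scale), so the Radio presumably drives three components; that wiring is outside this diff. A hypothetical sketch with invented component names:

# Hypothetical wiring, not shown in this diff: the Radio's change event feeds
# handle_speed_mode, whose returns update a status box and two sliders.
speed_mode.change(
    handle_speed_mode,
    inputs=[speed_mode],
    outputs=[status_markdown, steps_slider, cfg_slider],
)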