import spaces
import gradio as gr
import torch
from PIL import Image
from diffusers import DiffusionPipeline, QwenImageEditPipeline, FlowMatchEulerDiscreteScheduler
import random
import uuid
import numpy as np
import time
import zipfile
import os
import requests
from urllib.parse import urlparse
import tempfile
import shutil
import math

# --- App Description ---
DESCRIPTION = """## Qwen Image Hpc/."""

# --- Helper Functions for Both Tabs ---
MAX_SEED = np.iinfo(np.int32).max
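# Capping seeds at the 32-bit signed-integer maximum is a common convention that keeps
# slider values safe for both numpy and torch.Generator.manual_seed.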

def save_image(img):
    """Saves a PIL image to a temporary file with a unique name."""
    unique_name = str(uuid.uuid4()) + ".png"
    img.save(unique_name)
    return unique_name

def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Returns a random seed if randomize_seed is True, otherwise returns the original seed."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Qwen-Image-Gen Model ---
pipe_qwen_gen = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",
    torch_dtype=dtype
).to(device)
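# Both pipelines (generation above, editing below) are built once at import time in
# bfloat16 and moved to the selected device, so a single copy is shared across requests.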

# --- Qwen-Image-Edit Model with Lightning LoRA ---
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
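# This flow-match Euler configuration (exponential time shift with dynamic shifting and
# base_shift = max_shift = log(3)) appears to follow the scheduler settings published
# with the Lightning few-step LoRAs; it is what makes the 8-step editing defaults below viable.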
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

pipe_qwen_edit = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",
    scheduler=scheduler,
    torch_dtype=dtype
).to(device)

try:
    pipe_qwen_edit.load_lora_weights(
        "lightx2v/Qwen-Image-Lightning",
        weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
    )
    pipe_qwen_edit.fuse_lora()
    print("Successfully loaded Lightning LoRA weights for Qwen-Image-Edit")
except Exception as e:
    print(f"Warning: Could not load Lightning LoRA weights for Qwen-Image-Edit: {e}")
    print("Continuing with the base Qwen-Image-Edit model...")

# --- Qwen-Image-Gen Functions ---
aspect_ratios = {
    "1:1": (1328, 1328),
    "16:9": (1664, 928),
    "9:16": (928, 1664),
    "4:3": (1472, 1140),
    "3:4": (1140, 1472)
}
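# These width/height pairs keep the pixel count close to the 1328x1328 default while matching
# common aspect ratios; they are assumed to approximate Qwen-Image's native resolution buckets
# rather than being the only valid sizes.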

def load_lora_opt(pipe, lora_input):
    """Loads a LoRA from a local path, Hugging Face repo, or URL."""
    lora_input = lora_input.strip()
    if not lora_input:
        return
    if "/" in lora_input and not lora_input.startswith("http"):
        pipe.load_lora_weights(lora_input, adapter_name="default")
        return
    if lora_input.startswith("http"):
        url = lora_input
        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
            repo_id = urlparse(url).path.strip("/")
            pipe.load_lora_weights(repo_id, adapter_name="default")
            return
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        tmp_dir = tempfile.mkdtemp()
        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
        try:
            print(f"Downloading LoRA from {url}...")
            resp = requests.get(url, stream=True)
            resp.raise_for_status()
            with open(local_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"Saved LoRA to {local_path}")
            pipe.load_lora_weights(local_path, adapter_name="default")
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)
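# Illustrative inputs accepted by load_lora_opt above: a Hub repo ID such as
# "author/some-lora" (hypothetical name), a plain huggingface.co repo URL, or a
# /blob/ or /resolve/ URL pointing at a .safetensors file, which is first downloaded
# to a temporary directory.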

# ZeroGPU: @spaces.GPU acquires a GPU for the duration of the call (this Space runs on
# ZeroGPU, and the `spaces` import above is otherwise unused).
@spaces.GPU
def generate_qwen(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 4.0,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    num_images: int = 1,
    zip_images: bool = False,
    lora_input: str = "",
    lora_scale: float = 1.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Main generation function for Qwen-Image-Gen."""
    seed = randomize_seed_fn(seed, randomize_seed)
    generator = torch.Generator(device).manual_seed(seed)
    start_time = time.time()

    # get_list_adapters() returns {component_name: [adapter_name, ...]}; collect the adapter
    # names themselves so any stale "default" adapter from an earlier request is actually
    # removed before a new LoRA is loaded.
    current_adapters = pipe_qwen_gen.get_list_adapters()
    for adapter in {name for names in current_adapters.values() for name in names}:
        pipe_qwen_gen.delete_adapters(adapter)
    pipe_qwen_gen.disable_lora()

    if lora_input and lora_input.strip() != "":
        load_lora_opt(pipe_qwen_gen, lora_input)
        pipe_qwen_gen.set_adapters(["default"], adapter_weights=[lora_scale])

    images = pipe_qwen_gen(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else " ",
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images
    end_time = time.time()
    duration = end_time - start_time

    image_paths = [save_image(img) for img in images]
    zip_path = None
    if zip_images and len(image_paths) > 0:
        zip_name = str(uuid.uuid4()) + ".zip"
        with zipfile.ZipFile(zip_name, 'w') as zipf:
            for i, img_path in enumerate(image_paths):
                zipf.write(img_path, arcname=f"Img_{i}.png")
        zip_path = zip_name

    # Clean up any adapter loaded for this request so the next request starts fresh.
    current_adapters = pipe_qwen_gen.get_list_adapters()
    for adapter in {name for names in current_adapters.values() for name in names}:
        pipe_qwen_gen.delete_adapters(adapter)
    pipe_qwen_gen.disable_lora()

    return image_paths, seed, f"{duration:.2f}", zip_path
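
# Illustrative direct call, bypassing the UI (argument names follow the signature above):
#   paths, used_seed, secs, _ = generate_qwen("a red bicycle on a beach", num_inference_steps=28)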

def generate(
    prompt: str,
    negative_prompt: str,
    use_negative_prompt: bool,
    seed: int,
    width: int,
    height: int,
    guidance_scale: float,
    randomize_seed: bool,
    num_inference_steps: int,
    num_images: int,
    zip_images: bool,
    lora_input: str,
    lora_scale: float,
    progress=gr.Progress(track_tqdm=True),
):
    """UI wrapper for the Qwen-Image-Gen generation function."""
    final_negative_prompt = negative_prompt if use_negative_prompt else ""
    return generate_qwen(
        prompt=prompt,
        negative_prompt=final_negative_prompt,
        seed=seed,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        randomize_seed=randomize_seed,
        num_inference_steps=num_inference_steps,
        num_images=num_images,
        zip_images=zip_images,
        lora_input=lora_input,
        lora_scale=lora_scale,
        progress=progress,
    )

# --- Qwen-Image-Edit Functions ---
@spaces.GPU  # ZeroGPU: acquire a GPU for this call as well
def infer_edit(
    image,
    prompt,
    seed=42,
    randomize_seed=False,
    true_guidance_scale=1.0,
    num_inference_steps=8,
    progress=gr.Progress(track_tqdm=True),
):
    """Main inference function for Qwen-Image-Edit."""
    if image is None:
        raise gr.Error("Please upload an image to edit.")
    negative_prompt = " "
    seed = randomize_seed_fn(seed, randomize_seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    print(f"Original prompt: '{prompt}'")
    print(f"Negative Prompt: '{negative_prompt}'")
    print(f"Seed: {seed}, Steps: {num_inference_steps}, Guidance: {true_guidance_scale}")
    try:
        images = pipe_qwen_edit(
            image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            generator=generator,
            true_cfg_scale=true_guidance_scale,
            num_images_per_prompt=1
        ).images
        return images[0], seed
    except Exception as e:
        print(f"Error during inference: {e}")
        raise gr.Error(f"An error occurred during image editing: {e}")
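
# Note: with the default true_cfg_scale of 1.0, classifier-free guidance is effectively off
# and the blank negative prompt should have little influence; the fused Lightning LoRA is
# distilled for exactly this kind of few-step, low-guidance sampling.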

# --- Gradio UI ---
css = '''
.gradio-container {
    max-width: 800px !important;
    margin: 0 auto !important;
}
h1 {
    text-align: center;
}
footer {
    visibility: hidden;
}
'''

with gr.Blocks(css=css, theme="bethecloud/storj_theme", delete_cache=(240, 240)) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.TabItem("Qwen-Image-Gen"):
            with gr.Column():
                with gr.Row():
                    prompt_gen = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="✦︎ Enter your prompt for generation",
                        container=False,
                    )
                    run_button_gen = gr.Button("Generate", scale=0, variant="primary")
                result_gen = gr.Gallery(label="Result", columns=2, show_label=False, preview=True, height="auto")
                with gr.Row():
                    aspect_ratio_gen = gr.Dropdown(
                        label="Aspect Ratio",
                        choices=list(aspect_ratios.keys()),
                        value="1:1",
                    )
                    lora_gen = gr.Textbox(label="Optional LoRA", placeholder="Enter Hugging Face repo ID or URL...")
                with gr.Accordion("Additional Options", open=False):
                    use_negative_prompt_gen = gr.Checkbox(label="Use negative prompt", value=True)
                    negative_prompt_gen = gr.Text(
                        label="Negative prompt",
                        max_lines=1,
                        placeholder="Enter a negative prompt",
                        value="text, watermark, copyright, blurry, low resolution",
                    )
                    seed_gen = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed_gen = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width_gen = gr.Slider(label="Width", minimum=512, maximum=2048, step=64, value=1328)
                        height_gen = gr.Slider(label="Height", minimum=512, maximum=2048, step=64, value=1328)
                    guidance_scale_gen = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=4.0)
                    num_inference_steps_gen = gr.Slider(label="Number of inference steps", minimum=1, maximum=100, step=1, value=50)
                    num_images_gen = gr.Slider(label="Number of images", minimum=1, maximum=5, step=1, value=1)
                    zip_images_gen = gr.Checkbox(label="Zip generated images", value=False)
                    with gr.Row():
                        lora_scale_gen = gr.Slider(label="LoRA Scale", minimum=0, maximum=2, step=0.01, value=1)
                gr.Markdown("### Output Information")
                seed_display_gen = gr.Textbox(label="Seed used", interactive=False)
                generation_time_gen = gr.Textbox(label="Generation time (seconds)", interactive=False)
                zip_file_gen = gr.File(label="Download ZIP")

            # --- Gen Tab Logic ---
            def set_dimensions(ar):
                w, h = aspect_ratios[ar]
                return gr.update(value=w), gr.update(value=h)

            aspect_ratio_gen.change(fn=set_dimensions, inputs=aspect_ratio_gen, outputs=[width_gen, height_gen])
            use_negative_prompt_gen.change(fn=lambda x: gr.update(visible=x), inputs=use_negative_prompt_gen, outputs=negative_prompt_gen)

            gen_inputs = [
                prompt_gen, negative_prompt_gen, use_negative_prompt_gen, seed_gen, width_gen, height_gen,
                guidance_scale_gen, randomize_seed_gen, num_inference_steps_gen, num_images_gen,
                zip_images_gen, lora_gen, lora_scale_gen
            ]
            gen_outputs = [result_gen, seed_display_gen, generation_time_gen, zip_file_gen]
            gr.on(triggers=[prompt_gen.submit, run_button_gen.click], fn=generate, inputs=gen_inputs, outputs=gen_outputs)

            gen_examples = [
                "A decadent slice of layered chocolate cake on a ceramic plate with a drizzle of chocolate syrup and powdered sugar dusted on top.",
                "A young girl wearing school uniform stands in a classroom, writing on a chalkboard. The text 'Introducing Qwen-Image' appears in neat white chalk.",
                "一幅精致细腻的工笔画,画面中心是一株蓬勃生长的红色牡丹,花朵繁茂。",  # "An exquisite, finely detailed gongbi-style painting centered on a flourishing red peony in full bloom."
| "Realistic still life photography style: A single, fresh apple, resting on a clean, soft-textured surface.", | |
| ] | |
| gr.Examples(examples=gen_examples, inputs=prompt_gen, outputs=gen_outputs, fn=generate, cache_examples=False) | |
| with gr.TabItem("Qwen-Image-Edit"): | |
| with gr.Column(): | |
| with gr.Row(): | |
| input_image_edit = gr.Image(label="Input Image", type="pil", height=400) | |
| result_edit = gr.Image(label="Result", type="pil", height=400) | |
| with gr.Row(): | |
| prompt_edit = gr.Text( | |
| label="Edit Instruction", | |
| show_label=False, | |
| placeholder="Describe the edit you want to make", | |
| container=False, | |
| ) | |
| run_button_edit = gr.Button("Edit", variant="primary") | |
| with gr.Accordion("Advanced Settings", open=False): | |
| seed_edit = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42) | |
| randomize_seed_edit = gr.Checkbox(label="Randomize seed", value=True) | |
| with gr.Row(): | |
| true_guidance_scale_edit = gr.Slider( | |
| label="True guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0 | |
| ) | |
| num_inference_steps_edit = gr.Slider( | |
| label="Inference steps (Lightning LoRA)", minimum=4, maximum=28, step=1, value=8 | |
| ) | |
| # --- Edit Tab Logic --- | |
| edit_inputs = [ | |
| input_image_edit, prompt_edit, seed_edit, randomize_seed_edit, | |
| true_guidance_scale_edit, num_inference_steps_edit | |
| ] | |
| edit_outputs = [result_edit, seed_edit] | |
| gr.on(triggers=[prompt_edit.submit, run_button_edit.click], fn=infer_edit, inputs=edit_inputs, outputs=edit_outputs) | |
| edit_examples = [ | |
| ["image-edit/cat.png", "make the cat wear sunglasses"], | |
| ["image-edit/girl.png", "change her hair to blonde"], | |
| ] | |
| gr.Examples(examples=edit_examples, inputs=[input_image_edit, prompt_edit], outputs=edit_outputs, fn=infer_edit, cache_examples=True) | |
| if __name__ == "__main__": | |
| demo.queue(max_size=50).launch(share=False, debug=True) |
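
# Local usage sketch (assumptions: this file is saved as app.py, a CUDA GPU with enough VRAM
# for both pipelines is available, and the dependencies implied by the imports are installed,
# e.g. gradio, spaces, torch, diffusers, transformers, accelerate, peft, requests, pillow):
#   python app.py
# then open the local URL that Gradio prints.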