|
|
import os

# Set to True when running as a Hugging Face Space (ZeroGPU); the `spaces`
# package provides the GPU-allocation decorator used below.
IS_SPACE = True

if IS_SPACE:
    import spaces

import sys
import warnings
import subprocess
from pathlib import Path
from typing import Any, Optional, Tuple, Dict

import torch
|
|
|
|
|
def space_context(duration: int):
    """Return the `spaces.GPU` decorator on Spaces, or a no-op decorator locally."""
    if IS_SPACE:
        return spaces.GPU(duration=duration)
    return lambda x: x
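# Usage sketch (assuming standard `spaces.GPU` semantics): wrap any function that
# needs CUDA so ZeroGPU can schedule it, reserving the GPU for at most `duration`
# seconds per call, e.g.
#
#     @space_context(duration=60)
#     def my_gpu_fn(): ...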
|
|
|
|
|
@space_context(duration=120)
def test_env():
    assert torch.cuda.is_available()


# flash-attn is pinned (presumably to match the Space's preinstalled torch/CUDA
# stack); --no-build-isolation builds against the torch already in the environment.
try:
    import flash_attn  # noqa: F401
except ImportError:
    print("Flash-attn not found, installing...")
    os.system("pip install flash-attn==2.7.3 --no-build-isolation")
else:
    print("Flash-attn found, skipping installation...")

test_env()
|
|
|
|
|
warnings.filterwarnings("ignore")

# Make the repo's local packages (e.g. `hyimage`) importable regardless of CWD.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
|
try:
    import gradio as gr
    from PIL import Image
    from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline
    from huggingface_hub import snapshot_download
    import modelscope  # noqa: F401 -- imported only to verify ModelScope is installed
except ImportError as e:
    print(f"Missing required dependencies: {e}")
    print("Please install with: pip install -r requirements_gradio.txt")
    print("For checkpoint downloads, also install: pip install -U 'huggingface_hub[cli]' modelscope")
    sys.exit(1)
|
|
|
|
|
|
|
|
BASE_DIR = os.environ.get('HUNYUANIMAGE_V2_1_MODEL_ROOT', './ckpts') |
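# Checkpoints land under ./ckpts unless overridden, e.g. (example invocation;
# the script name may differ):
#   HUNYUANIMAGE_V2_1_MODEL_ROOT=/data/ckpts python app.py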
|
|
|
|
|
class CheckpointDownloader:
    """Handles downloading of all required checkpoints for HunyuanImage."""

    def __init__(self, base_dir: str = BASE_DIR):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)
        print(f'Downloading checkpoints to: {self.base_dir}')

        self.checkpoints = {
            "main_model": {
                "repo_id": "tencent/HunyuanImage-2.1",
                "local_dir": self.base_dir,
            },
            "mllm_encoder": {
                "repo_id": "Qwen/Qwen2.5-VL-7B-Instruct",
                "local_dir": self.base_dir / "text_encoder" / "llm",
            },
            "byt5_encoder": {
                "repo_id": "google/byt5-small",
                "local_dir": self.base_dir / "text_encoder" / "byt5-small",
            },
            "glyph_encoder": {
                "repo_id": "AI-ModelScope/Glyph-SDXL-v2",
                "local_dir": self.base_dir / "text_encoder" / "Glyph-SDXL-v2",
                "use_modelscope": True,
            },
        }
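    # Resulting directory layout (sketch, inferred from the mapping above):
    #   <base_dir>/                             main HunyuanImage-2.1 snapshot
    #   <base_dir>/text_encoder/llm/            Qwen/Qwen2.5-VL-7B-Instruct
    #   <base_dir>/text_encoder/byt5-small/     google/byt5-small
    #   <base_dir>/text_encoder/Glyph-SDXL-v2/  AI-ModelScope/Glyph-SDXL-v2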
|
|
|
|
|
    def download_checkpoint(self, checkpoint_name: str, progress_callback=None) -> Tuple[bool, str]:
        """Download a specific checkpoint."""
        if checkpoint_name not in self.checkpoints:
            return False, f"Unknown checkpoint: {checkpoint_name}"

        config = self.checkpoints[checkpoint_name]
        local_dir = config["local_dir"]
        local_dir.mkdir(parents=True, exist_ok=True)

        try:
            if config.get("use_modelscope", False):
                return self._download_with_modelscope(config, progress_callback)
            else:
                return self._download_with_hf(config, progress_callback)
        except Exception as e:
            return False, f"Download failed: {str(e)}"
|
|
|
|
|
    def _download_with_hf(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
        """Download using huggingface_hub."""
        repo_id = config["repo_id"]
        local_dir = config["local_dir"]

        if progress_callback:
            progress_callback(f"Downloading {repo_id}...")

        try:
            snapshot_download(
                repo_id=repo_id,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
            )
            return True, f"Successfully downloaded {repo_id}"
        except Exception as e:
            return False, f"HF download failed: {str(e)}"
|
|
|
|
|
    def _download_with_modelscope(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
        """Download using the ModelScope CLI."""
        repo_id = config["repo_id"]
        local_dir = config["local_dir"]

        if progress_callback:
            progress_callback(f"Downloading {repo_id} via ModelScope...")
        print(f"Downloading {repo_id} via ModelScope...")

        try:
            cmd = [
                "modelscope", "download",
                "--model", repo_id,
                "--local_dir", str(local_dir),
            ]
            subprocess.run(cmd, capture_output=True, text=True, check=True)
            return True, f"Successfully downloaded {repo_id} via ModelScope"
        except subprocess.CalledProcessError as e:
            return False, f"ModelScope download failed: {e.stderr}"
        except FileNotFoundError:
            return False, "ModelScope CLI not found. Install with: pip install modelscope"
|
|
|
|
|
    def download_all_checkpoints(self, progress_callback=None) -> Tuple[bool, str, Dict[str, Any]]:
        """Download all checkpoints, stopping at the first failure."""
        results = {}
        for name in self.checkpoints:
            if progress_callback:
                progress_callback(f"Starting download of {name}...")

            success, message = self.download_checkpoint(name, progress_callback)
            results[name] = {"success": success, "message": message}

            if not success:
                return False, f"Failed to download {name}: {message}", results
        return True, "All checkpoints downloaded successfully", results
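    # One-off usage sketch (the app itself goes through HunyuanImageApp below):
    #   downloader = CheckpointDownloader()
    #   ok, msg, results = downloader.download_all_checkpoints(progress_callback=print)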
|
|
|
|
|
|
|
|
@space_context(duration=2000)
def load_pipeline(use_distilled: bool = False, device: str = "cuda"):
    """Load the HunyuanImage pipeline once; the refiner and reprompt models are accessed through it."""
    try:
        # The distilled variant is not supported in this app.
        assert not use_distilled

        print(f"Loading HunyuanImage pipeline (distilled={use_distilled})...")
        model_name = "hunyuanimage-v2.1-distilled" if use_distilled else "hunyuanimage-v2.1"
        pipeline = HunyuanImagePipeline.from_pretrained(
            model_name=model_name,
            device=device,
            enable_dit_offloading=True,
            enable_reprompt_model_offloading=True,
            enable_refiner_offloading=True,
        )
        pipeline.to('cpu')

        # Share the text encoder between the base and refiner pipelines to save
        # memory, and keep everything on CPU until a request needs the GPU.
        refiner_pipeline = pipeline.refiner_pipeline
        refiner_pipeline.text_encoder.model = pipeline.text_encoder.model
        refiner_pipeline.to('cpu')
        reprompt_model = pipeline.reprompt_model  # touch the property so the reprompt model is instantiated up front

        print("✅ Pipeline loaded successfully")
        return pipeline
    except Exception as e:
        error_msg = f"Error loading pipeline: {str(e)}"
        print(f"❌ {error_msg}")
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the shared pipeline once at import time so it is ready before requests arrive.
pipeline = load_pipeline(use_distilled=False, device="cuda")
|
|
class HunyuanImageApp:
    """Gradio front-end around the shared HunyuanImage pipeline."""

    @space_context(duration=290)
    def __init__(self, auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
        """Initialize the HunyuanImage Gradio app.

        auto_load/use_distilled/device are accepted for CLI compatibility; the
        shared module-level pipeline is always used.
        """
        global pipeline

        self.pipeline = pipeline
        self.current_use_distilled = None
        self.downloader = CheckpointDownloader(BASE_DIR)

        # Each aspect ratio maps to a fixed 2K (width, height) pair.
        self.aspect_ratio_mappings = {
            "16:9": (2560, 1536),
            "4:3": (2304, 1792),
            "1:1": (2048, 2048),
            "3:4": (1792, 2304),
            "9:16": (1536, 2560),
        }
|
|
|
|
|
|
|
|
    def print_peak_memory(self):
        """Report the peak CUDA memory allocated so far in this process."""
        stats = torch.cuda.memory_stats()
        peak_bytes_requirement = stats["allocated_bytes.all.peak"]
        print(f"Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
|
|
|
|
|
    def update_resolution(self, aspect_ratio_choice: str) -> Tuple[int, int]:
        """Update width and height based on the selected aspect ratio."""
        # The Radio passes the value (e.g. "16:9"); split defensively in case a
        # full label with its resolution suffix arrives instead.
        aspect_key = aspect_ratio_choice.split(" (")[0]
        if aspect_key in self.aspect_ratio_mappings:
            return self.aspect_ratio_mappings[aspect_key]
        # Fall back to square output for unrecognized choices.
        return self.aspect_ratio_mappings["1:1"]
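    # Example: update_resolution("16:9 (2560×1536)") -> (2560, 1536), the same
    # pair wired into the hidden width/height fields in create_interface below.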
|
|
|
|
|
    @space_context(duration=300)
    def generate_image(
        self,
        prompt: str,
        negative_prompt: str,
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        seed: int,
        use_reprompt: bool,
        use_refiner: bool,
    ) -> Tuple[Optional[Image.Image], str]:
        """Generate an image using the HunyuanImage pipeline."""
        try:
            torch.cuda.empty_cache()

            if self.pipeline is None:
                return None, "Pipeline not loaded. Please try again."

            # Keep the refiner on CPU while the base pipeline uses the GPU.
            if hasattr(self.pipeline, '_refiner_pipeline'):
                self.pipeline.refiner_pipeline.to('cpu')
            self.pipeline.to('cuda')

            if seed == -1:
                import random
                seed = random.randint(100000, 999999)

            image = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
                shift=5,
                use_reprompt=use_reprompt,
                use_refiner=use_refiner,
            )
            self.print_peak_memory()
            return image, "Image generated successfully!"

        except Exception as e:
            error_msg = f"Error generating image: {str(e)}"
            print(f"❌ {error_msg}")
            return None, error_msg
|
|
|
|
|
    @space_context(duration=300)
    def enhance_prompt(self, prompt: str) -> Tuple[str, str]:
        """Enhance a prompt using the reprompt model."""
        try:
            torch.cuda.empty_cache()

            if self.pipeline is None:
                return prompt, "Pipeline not loaded. Please try again."

            # Free the GPU for the reprompt model by parking the diffusion pipelines on CPU.
            self.pipeline.to('cpu')
            if hasattr(self.pipeline, '_refiner_pipeline'):
                self.pipeline.refiner_pipeline.to('cpu')

            enhanced_prompt = self.pipeline.reprompt_model.predict(prompt)
            self.print_peak_memory()
            return enhanced_prompt, "Prompt enhanced successfully!"

        except Exception as e:
            error_msg = f"Error enhancing prompt: {str(e)}"
            print(f"❌ {error_msg}")
            return prompt, error_msg
|
|
|
|
|
    @space_context(duration=300)
    def refine_image(
        self,
        image: Image.Image,
        prompt: str,
        width: int,
        height: int,
        num_inference_steps: int,
        guidance_scale: float,
        seed: int,
    ) -> Tuple[Optional[Image.Image], str]:
        """Refine an image using the refiner pipeline."""
        try:
            if image is None:
                return None, "Please upload an image to refine."

            if not prompt or prompt.strip() == "":
                return None, "Please enter a refinement prompt."

            torch.cuda.empty_cache()

            # The refiner expects the input at the target resolution.
            if image.size != (width, height):
                image = image.resize((width, height), Image.Resampling.LANCZOS)

            # Swap the base pipeline out for the refiner on the GPU.
            self.pipeline.to('cpu')
            self.pipeline.refiner_pipeline.to('cuda')

            if seed == -1:
                import random
                seed = random.randint(100000, 999999)

            refined_image = self.pipeline.refiner_pipeline(
                image=image,
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                seed=seed,
            )
            self.print_peak_memory()
            return refined_image, "Image refined successfully!"

        except Exception as e:
            error_msg = f"Error refining image: {str(e)}"
            print(f"❌ {error_msg}")
            return None, error_msg
|
|
|
|
|
|
|
|
    def download_single_checkpoint(self, checkpoint_name: str) -> Tuple[bool, str]:
        """Download a single checkpoint."""
        try:
            success, message = self.downloader.download_checkpoint(checkpoint_name)
            return success, message
        except Exception as e:
            return False, f"Download error: {str(e)}"

    def download_all_checkpoints(self) -> Tuple[bool, str, Dict[str, Any]]:
        """Download all missing checkpoints."""
        try:
            success, message, results = self.downloader.download_all_checkpoints()
            return success, message, results
        except Exception as e:
            return False, f"Download error: {str(e)}", {}
|
|
|
|
|
def create_interface(auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
    """Create the Gradio interface."""
    app = HunyuanImageApp(auto_load=auto_load, use_distilled=use_distilled, device=device)

    css = """
    .gradio-container {
        max-width: 1200px !important;
        margin: auto !important;
    }
    .tab-nav {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        border-radius: 10px;
        padding: 10px;
        margin-bottom: 20px;
    }
    .model-info {
        background: var(--background-fill-secondary);
        border: 1px solid var(--border-color-primary);
        border-radius: 8px;
        padding: 15px;
        margin-bottom: 20px;
        color: var(--body-text-color);
    }
    .model-info h1, .model-info h2, .model-info h3 {
        color: var(--body-text-color) !important;
    }
    .model-info p, .model-info li {
        color: var(--body-text-color) !important;
    }
    .model-info strong {
        color: var(--body-text-color) !important;
    }
    """
|
|
|
|
|
    with gr.Blocks(css=css, title="HunyuanImage Pipeline", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            """
            # 🎨 HunyuanImage 2.1 Pipeline
            **HunyuanImage-2.1: An Efficient Diffusion Model for High-Resolution (2K) Text-to-Image Generation**

            This app provides three main functionalities:
            1. **Text-to-Image Generation**: Generate high-quality images from text prompts
            2. **Prompt Enhancement**: Improve your prompts using MLLM reprompting
            3. **Image Refinement**: Enhance existing images with the refiner model
            """,
            elem_classes="model-info"
        )
|
|
|
|
|
        with gr.Tabs():
            with gr.Tab("🖼️ Text-to-Image Generation"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Generation Settings")
                        gr.Markdown("**Model**: HunyuanImage v2.1 (Non-distilled)")

                        # Only the non-distilled model is exposed in the UI.
                        use_distilled = False

                        prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="",
                            lines=3,
                            value="A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word “Tencent” on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
                        )

                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="",
                            lines=2,
                            value=""
                        )

                        # (label, value) pairs; the label shows the resolution behind each ratio.
                        aspect_ratios = [
                            ("16:9 (2560×1536)", "16:9"),
                            ("4:3 (2304×1792)", "4:3"),
                            ("1:1 (2048×2048)", "1:1"),
                            ("3:4 (1792×2304)", "3:4"),
                            ("9:16 (1536×2560)", "9:16")
                        ]

                        aspect_ratio = gr.Radio(
                            choices=aspect_ratios,
                            value="1:1",
                            label="Aspect Ratio",
                            info="Select the aspect ratio for image generation"
                        )

                        # Hidden fields kept in sync with the aspect ratio selection.
                        width = gr.Number(value=2048, visible=False)
                        height = gr.Number(value=2048, visible=False)

                        with gr.Row():
                            num_inference_steps = gr.Slider(
                                minimum=10, maximum=100, step=5, value=50,
                                label="Inference Steps", info="More steps = better quality, slower generation"
                            )
                            guidance_scale = gr.Slider(
                                minimum=1.0, maximum=10.0, step=0.1, value=3.5,
                                label="Guidance Scale", info="How closely to follow the prompt"
                            )

                        with gr.Row():
                            seed = gr.Number(
                                label="Seed", value=-1, precision=0,
                                info="Random seed for reproducibility (-1 for a random seed)"
                            )
                            use_reprompt = gr.Checkbox(
                                label="Use Reprompt", value=True,
                                info="Enhance prompt automatically"
                            )
                            use_refiner = gr.Checkbox(
                                label="Use Refiner", value=True,
                                info="Apply refiner after generation",
                                interactive=True
                            )

                        generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        gr.Markdown("### Generated Image")
                        generated_image = gr.Image(
                            label="Generated Image",
                            format="png",
                            show_download_button=True,
                            type="pil",
                            height=600
                        )
                        generation_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            value="Ready to generate"
                        )
|
|
|
|
|
|
|
|
            with gr.Tab("✨ Prompt Enhancement"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Prompt Enhancement Settings")
                        gr.Markdown("**Model**: HunyuanImage v2.1 Reprompt Model")

                        # Placeholder; only the non-distilled reprompt model is used.
                        enhance_use_distilled = False

                        original_prompt = gr.Textbox(
                            label="Original Prompt",
                            placeholder="A cat sitting on a table",
                            lines=4,
                            value="A cat sitting on a table"
                        )

                        enhance_btn = gr.Button("✨ Enhance Prompt", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        gr.Markdown("### Enhanced Prompt")
                        enhanced_prompt = gr.Textbox(
                            label="Enhanced Prompt",
                            lines=6,
                            interactive=False
                        )
                        enhancement_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            value="Ready to enhance"
                        )
|
|
|
|
|
|
|
|
            with gr.Tab("🔧 Image Refinement"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Refinement Settings")
                        gr.Markdown("**Model**: HunyuanImage v2.1 Refiner")

                        input_image = gr.Image(
                            label="Input Image",
                            type="pil",
                            height=300
                        )

                        refine_prompt = gr.Textbox(
                            label="Refinement Prompt",
                            placeholder="Image description",
                            info="This prompt should describe the image content.",
                            lines=2,
                            value=""
                        )

                        with gr.Row():
                            refine_width = gr.Slider(
                                minimum=512, maximum=2048, step=64, value=2048,
                                label="Width", info="Output width"
                            )
                            refine_height = gr.Slider(
                                minimum=512, maximum=2048, step=64, value=2048,
                                label="Height", info="Output height"
                            )

                        with gr.Row():
                            refine_steps = gr.Slider(
                                minimum=1, maximum=20, step=1, value=4,
                                label="Refinement Steps", info="More steps = more refinement"
                            )
                            refine_guidance = gr.Slider(
                                minimum=1.0, maximum=10.0, step=0.1, value=3.5,
                                label="Guidance Scale", info="How strongly to follow the prompt"
                            )

                        refine_seed = gr.Number(
                            label="Seed", value=-1, precision=0,
                            info="Random seed for reproducibility"
                        )

                        refine_btn = gr.Button("🔧 Refine Image", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        gr.Markdown("### Refined Image")
                        refined_image = gr.Image(
                            label="Refined Image",
                            type="pil",
                            format="png",
                            show_download_button=True,
                            height=600
                        )
                        refinement_status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            value="Ready to refine"
                        )
|
|
|
|
|
|
|
|
|
|
|
        # Wire up the event handlers.
        aspect_ratio.change(
            fn=app.update_resolution,
            inputs=[aspect_ratio],
            outputs=[width, height]
        )

        generate_btn.click(
            fn=app.generate_image,
            inputs=[
                prompt, negative_prompt, width, height, num_inference_steps,
                guidance_scale, seed, use_reprompt, use_refiner
            ],
            outputs=[generated_image, generation_status]
        )

        enhance_btn.click(
            fn=app.enhance_prompt,
            inputs=[original_prompt],
            outputs=[enhanced_prompt, enhancement_status]
        )

        refine_btn.click(
            fn=app.refine_image,
            inputs=[
                input_image, refine_prompt,
                refine_width, refine_height, refine_steps, refine_guidance, refine_seed
            ],
            outputs=[refined_image, refinement_status]
        )
|
|
|
|
|
|
|
|
        gr.Markdown(
            """
            ### 📝 Usage Tips

            **Text-to-Image Generation:**
            - Use descriptive prompts with specific details
            - Adjust guidance scale: higher values follow the prompt more closely
            - More inference steps generally produce better quality
            - Enable reprompt for automatic prompt enhancement
            - Enable refiner for additional quality improvement

            **Prompt Enhancement:**
            - Enter your basic prompt idea
            - The AI will enhance it with better structure and details
            - Enhanced prompts often produce better results

            **Image Refinement:**
            - Upload any image you want to improve
            - Describe what improvements you want in the refinement prompt
            - The refiner will enhance details and quality
            - Works best with images generated by HunyuanImage
            """,
            elem_classes="model-info"
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Launch HunyuanImage Gradio App")
    parser.add_argument("--no-auto-load", action="store_true", help="Disable auto-loading pipeline on startup")
    parser.add_argument("--use-distilled", action="store_true", help="Use distilled model")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda/cpu)")
    parser.add_argument("--port", type=int, default=None, help="Port to run the app on (default: Gradio's standard port)")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")

    args = parser.parse_args()

    auto_load = not args.no_auto_load
    demo = create_interface(auto_load=auto_load, use_distilled=args.use_distilled, device=args.device)

    print("🚀 Starting HunyuanImage Gradio App...")
    print(f"🔧 Auto-load pipeline: {'Yes' if auto_load else 'No'}")
    print(f"🎯 Model type: {'Distilled' if args.use_distilled else 'Non-distilled'}")
    print(f"💻 Device: {args.device}")
    print("⚠️ Make sure you have the required model checkpoints downloaded!")

    demo.launch(
        server_name=args.host,
        server_port=args.port,  # previously parsed but never used; None keeps Gradio's default
        share=False,
        show_error=True,
        quiet=False,
        max_threads=1,
    )
|
|
|