Baigfft committed on
Commit
1d8d7e0
·
verified ·
1 Parent(s): 6c67805

Upload 2 files

Files changed (2)
  1. app (1).py +698 -0
  2. requirements (1).txt +17 -0
app (1).py ADDED
@@ -0,0 +1,698 @@
+ import os
+ IS_SPACE = True
+
+ if IS_SPACE:
+     import spaces
+
+
+ import sys
+ import warnings
+ import subprocess
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Tuple
+ import torch
+
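+ # On Hugging Face Spaces, `space_context` wraps a function with `spaces.GPU` so it
+ # runs inside a ZeroGPU allocation; off-Spaces the decorator is a no-op.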
+ def space_context(duration: int):
+     if IS_SPACE:
+         return spaces.GPU(duration=duration)
+     return lambda x: x
+
+ @space_context(duration=120)
+ def test_env():
+     assert torch.cuda.is_available()
+
+ try:
+     import flash_attn
+ except ImportError:
+     print("Flash-attn not found, installing...")
+     os.system("pip install flash-attn==2.7.3 --no-build-isolation")
+ else:
+     print("Flash-attn found, skipping installation...")
+ test_env()
+
+ warnings.filterwarnings("ignore")
+
+ # Add the current directory to the Python path
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+ try:
+     import gradio as gr
+     from PIL import Image
+     from hyimage.diffusion.pipelines.hunyuanimage_pipeline import HunyuanImagePipeline
+     from huggingface_hub import snapshot_download
+     import modelscope
+ except ImportError as e:
+     print(f"Missing required dependencies: {e}")
+     print("Please install with: pip install -r requirements_gradio.txt")
+     print("For checkpoint downloads, also install: pip install -U 'huggingface_hub[cli]' modelscope")
+     sys.exit(1)
+
+
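+ # Checkpoint root; override via the HUNYUANIMAGE_V2_1_MODEL_ROOT environment variable.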
+ BASE_DIR = os.environ.get('HUNYUANIMAGE_V2_1_MODEL_ROOT', './ckpts')
+
+ class CheckpointDownloader:
+     """Handles downloading of all required checkpoints for HunyuanImage."""
+
+     def __init__(self, base_dir: str = BASE_DIR):
+         self.base_dir = Path(base_dir)
+         self.base_dir.mkdir(exist_ok=True)
+         print(f'Downloading checkpoints to: {self.base_dir}')
+
+         # Define all required checkpoints
+         self.checkpoints = {
+             "main_model": {
+                 "repo_id": "tencent/HunyuanImage-2.1",
+                 "local_dir": self.base_dir,
+             },
+             "mllm_encoder": {
+                 "repo_id": "Qwen/Qwen2.5-VL-7B-Instruct",
+                 "local_dir": self.base_dir / "text_encoder" / "llm",
+             },
+             "byt5_encoder": {
+                 "repo_id": "google/byt5-small",
+                 "local_dir": self.base_dir / "text_encoder" / "byt5-small",
+             },
+             "glyph_encoder": {
+                 "repo_id": "AI-ModelScope/Glyph-SDXL-v2",
+                 "local_dir": self.base_dir / "text_encoder" / "Glyph-SDXL-v2",
+                 "use_modelscope": True,
+             },
+         }
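+         # Glyph-SDXL-v2 is fetched from ModelScope (see `use_modelscope` above);
+         # the other repos come from the Hugging Face Hub.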
+
+     def download_checkpoint(self, checkpoint_name: str, progress_callback=None) -> Tuple[bool, str]:
+         """Download a specific checkpoint."""
+         if checkpoint_name not in self.checkpoints:
+             return False, f"Unknown checkpoint: {checkpoint_name}"
+
+         config = self.checkpoints[checkpoint_name]
+         local_dir = config["local_dir"]
+         local_dir.mkdir(parents=True, exist_ok=True)
+
+         try:
+             if config.get("use_modelscope", False):
+                 # Use modelscope for models hosted on ModelScope
+                 return self._download_with_modelscope(config, progress_callback)
+             else:
+                 # Use huggingface_hub for other models
+                 return self._download_with_hf(config, progress_callback)
+         except Exception as e:
+             return False, f"Download failed: {str(e)}"
+
+     def _download_with_hf(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
+         """Download using huggingface_hub."""
+         repo_id = config["repo_id"]
+         local_dir = config["local_dir"]
+
+         if progress_callback:
+             progress_callback(f"Downloading {repo_id}...")
+
+         try:
+             snapshot_download(
+                 repo_id=repo_id,
+                 local_dir=str(local_dir),
+                 local_dir_use_symlinks=False,
+                 resume_download=True
+             )
+             return True, f"Successfully downloaded {repo_id}"
+         except Exception as e:
+             return False, f"HF download failed: {str(e)}"
+
+     def _download_with_modelscope(self, config: Dict, progress_callback=None) -> Tuple[bool, str]:
+         """Download using modelscope."""
+         repo_id = config["repo_id"]
+         local_dir = config["local_dir"]
+
+         if progress_callback:
+             progress_callback(f"Downloading {repo_id} via ModelScope...")
+         print(f"Downloading {repo_id} via ModelScope...")
+
+         try:
+             # Use subprocess to call the modelscope CLI
+             cmd = [
+                 "modelscope", "download",
+                 "--model", repo_id,
+                 "--local_dir", str(local_dir)
+             ]
+             subprocess.run(cmd, capture_output=True, text=True, check=True)
+             return True, f"Successfully downloaded {repo_id} via ModelScope"
+         except subprocess.CalledProcessError as e:
+             return False, f"ModelScope download failed: {e.stderr}"
+         except FileNotFoundError:
+             return False, "ModelScope CLI not found. Install with: pip install modelscope"
+
+     def download_all_checkpoints(self, progress_callback=None) -> Tuple[bool, str, Dict[str, Any]]:
+         """Download all checkpoints."""
+         results = {}
+         for name in self.checkpoints:
+             if progress_callback:
+                 progress_callback(f"Starting download of {name}...")
+
+             success, message = self.download_checkpoint(name, progress_callback)
+             results[name] = {"success": success, "message": message}
+
+             if not success:
+                 return False, f"Failed to download {name}: {message}", results
+         return True, "All checkpoints downloaded successfully", results
+
+
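+ # The base pipeline, its refiner, and the reprompt model are loaded once and kept
+ # on CPU between requests; each handler moves only what it needs onto the GPU.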
+ @space_context(duration=2000)
+ def load_pipeline(use_distilled: bool = False, device: str = "cuda"):
+     """Load the HunyuanImage pipeline once; the refiner and reprompt model are accessed through it."""
+     try:
+         assert not use_distilled  # use_distilled is a placeholder for the future
+
+         print(f"Loading HunyuanImage pipeline (distilled={use_distilled})...")
+         model_name = "hunyuanimage-v2.1-distilled" if use_distilled else "hunyuanimage-v2.1"
+         pipeline = HunyuanImagePipeline.from_pretrained(
+             model_name=model_name,
+             device=device,
+             enable_dit_offloading=True,
+             enable_reprompt_model_offloading=True,
+             enable_refiner_offloading=True
+         )
+         pipeline.to('cpu')
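+         # Share the base pipeline's text encoder with the refiner so its weights
+         # are held in memory only once.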
+         refiner_pipeline = pipeline.refiner_pipeline
+         refiner_pipeline.text_encoder.model = pipeline.text_encoder.model
+         refiner_pipeline.to('cpu')
+         reprompt_model = pipeline.reprompt_model
+
+         print("✓ Pipeline loaded successfully")
+         return pipeline
+     except Exception as e:
+         error_msg = f"Error loading pipeline: {str(e)}"
+         print(f"✗ {error_msg}")
+         raise
+
+
+ # if IS_SPACE:
+ #     downloader = CheckpointDownloader()
+ #     downloader.download_all_checkpoints()
+
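+ # Load the pipeline at import time so the (CPU-resident) weights are shared by
+ # every request handled by this process.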
+ pipeline = load_pipeline(use_distilled=False, device="cuda")
+
+
+ class HunyuanImageApp:
+
+     @space_context(duration=290)
+     def __init__(self, auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
+         """Initialize the HunyuanImage Gradio app."""
+         global pipeline
+
+         self.pipeline = pipeline
+         self.current_use_distilled = None
+         # Downloader used by the download helper methods below
+         self.downloader = CheckpointDownloader()
+
+         # Define aspect ratio mappings
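+         # Every preset targets roughly the same ~4 MP (2K) pixel budget.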
+         self.aspect_ratio_mappings = {
+             "16:9": (2560, 1536),
+             "4:3": (2304, 1792),
+             "1:1": (2048, 2048),
+             "3:4": (1792, 2304),
+             "9:16": (1536, 2560)
+         }
+
+     def print_peak_memory(self):
+         stats = torch.cuda.memory_stats()
+         peak_bytes_requirement = stats["allocated_bytes.all.peak"]
+         print(f"Peak memory requirement: {peak_bytes_requirement / 1024 ** 3:.2f} GB")
+
+     def update_resolution(self, aspect_ratio_choice: str) -> Tuple[int, int]:
+         """Update width and height based on the selected aspect ratio."""
+         # Extract the aspect ratio key from the choice (e.g., "16:9" from "16:9 (2560×1536)")
+         aspect_key = aspect_ratio_choice.split(" (")[0]
+         if aspect_key in self.aspect_ratio_mappings:
+             return self.aspect_ratio_mappings[aspect_key]
+         else:
+             # Default to 1:1 if not found
+             return self.aspect_ratio_mappings["1:1"]
+
+     @space_context(duration=300)
+     def generate_image(self,
+                        prompt: str,
+                        negative_prompt: str,
+                        width: int,
+                        height: int,
+                        num_inference_steps: int,
+                        guidance_scale: float,
+                        seed: int,
+                        use_reprompt: bool,
+                        use_refiner: bool,
+                        # use_distilled: bool
+                        ) -> Tuple[Optional[Image.Image], str]:
+         """Generate an image using the HunyuanImage pipeline."""
+         try:
+             torch.cuda.empty_cache()
+
+             if self.pipeline is None:
+                 return None, "Pipeline not loaded. Please try again."
+
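+             # Park the refiner (if it has been instantiated) on CPU and move the
+             # base DiT to the GPU for denoising.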
+             if hasattr(self.pipeline, '_refiner_pipeline'):
+                 self.pipeline.refiner_pipeline.to('cpu')
+             self.pipeline.to('cuda')
+             if seed == -1:
+                 import random
+                 seed = random.randint(100000, 999999)
+
+             # Generate image
+             image = self.pipeline(
+                 prompt=prompt,
+                 negative_prompt=negative_prompt,
+                 width=width,
+                 height=height,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 seed=seed,
+                 shift=5,
+                 use_reprompt=use_reprompt,
+                 use_refiner=use_refiner
+             )
+             self.print_peak_memory()
+             return image, "Image generated successfully!"
+
+         except Exception as e:
+             error_msg = f"Error generating image: {str(e)}"
+             print(f"✗ {error_msg}")
+             return None, error_msg
+
+     @space_context(duration=300)
+     def enhance_prompt(self, prompt: str,  # use_distilled: bool
+                        ) -> Tuple[str, str]:
+         """Enhance a prompt using the reprompt model."""
+         try:
+             torch.cuda.empty_cache()
+
+             # Bail out if the pipeline is not available
+             if self.pipeline is None:
+                 return prompt, "Pipeline not loaded. Please try again."
+
+             self.pipeline.to('cpu')
+             if hasattr(self.pipeline, '_refiner_pipeline'):
+                 self.pipeline.refiner_pipeline.to('cpu')
+
+             # Use the reprompt model from the main pipeline
+             enhanced_prompt = self.pipeline.reprompt_model.predict(prompt)
+             self.print_peak_memory()
+             return enhanced_prompt, "Prompt enhanced successfully!"
+
+         except Exception as e:
+             error_msg = f"Error enhancing prompt: {str(e)}"
+             print(f"✗ {error_msg}")
+             return prompt, error_msg
+
+     @space_context(duration=300)
+     def refine_image(self,
+                      image: Image.Image,
+                      prompt: str,
+                      width: int,
+                      height: int,
+                      num_inference_steps: int,
+                      guidance_scale: float,
+                      seed: int) -> Tuple[Optional[Image.Image], str]:
+         """Refine an image using the refiner pipeline."""
+         try:
+             if image is None:
+                 return None, "Please upload an image to refine."
+
+             if not prompt or prompt.strip() == "":
+                 return None, "Please enter a refinement prompt."
+
+             torch.cuda.empty_cache()
+
+             # Resize the image to the target dimensions if needed
+             if image.size != (width, height):
+                 image = image.resize((width, height), Image.Resampling.LANCZOS)
+
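+             # Swap devices: the base pipeline parks on CPU while the refiner runs on GPU.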
+             self.pipeline.to('cpu')
+             self.pipeline.refiner_pipeline.to('cuda')
+             if seed == -1:
+                 import random
+                 seed = random.randint(100000, 999999)
+
+             # Use the refiner from the main pipeline
+             refined_image = self.pipeline.refiner_pipeline(
+                 image=image,
+                 prompt=prompt,
+                 width=width,
+                 height=height,
+                 num_inference_steps=num_inference_steps,
+                 guidance_scale=guidance_scale,
+                 seed=seed
+             )
+             self.print_peak_memory()
+             return refined_image, "Image refined successfully!"
+
+         except Exception as e:
+             error_msg = f"Error refining image: {str(e)}"
+             print(f"✗ {error_msg}")
+             return None, error_msg
+
+     def download_single_checkpoint(self, checkpoint_name: str) -> Tuple[bool, str]:
+         """Download a single checkpoint."""
+         try:
+             success, message = self.downloader.download_checkpoint(checkpoint_name)
+             return success, message
+         except Exception as e:
+             return False, f"Download error: {str(e)}"
+
+     def download_all_checkpoints(self) -> Tuple[bool, str, Dict[str, Any]]:
+         """Download all missing checkpoints."""
+         try:
+             success, message, results = self.downloader.download_all_checkpoints()
+             return success, message, results
+         except Exception as e:
+             return False, f"Download error: {str(e)}", {}
+
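+
+ # Build the three-tab Gradio UI (generation / prompt enhancement / refinement)
+ # around a single shared HunyuanImageApp instance.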
+ def create_interface(auto_load: bool = True, use_distilled: bool = False, device: str = "cuda"):
+     """Create the Gradio interface."""
+     app = HunyuanImageApp(auto_load=auto_load, use_distilled=use_distilled, device=device)
+
+     # Custom CSS for better styling with dark mode support
+     css = """
+     .gradio-container {
+         max-width: 1200px !important;
+         margin: auto !important;
+     }
+     .tab-nav {
+         background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
+         border-radius: 10px;
+         padding: 10px;
+         margin-bottom: 20px;
+     }
+     .model-info {
+         background: var(--background-fill-secondary);
+         border: 1px solid var(--border-color-primary);
+         border-radius: 8px;
+         padding: 15px;
+         margin-bottom: 20px;
+         color: var(--body-text-color);
+     }
+     .model-info h1, .model-info h2, .model-info h3 {
+         color: var(--body-text-color) !important;
+     }
+     .model-info p, .model-info li {
+         color: var(--body-text-color) !important;
+     }
+     .model-info strong {
+         color: var(--body-text-color) !important;
+     }
+     """
+
+     with gr.Blocks(css=css, title="HunyuanImage Pipeline", theme=gr.themes.Soft()) as demo:
+         gr.Markdown(
+             """
+             # 🎨 HunyuanImage 2.1 Pipeline
+             **HunyuanImage-2.1: An Efficient Diffusion Model for High-Resolution (2K) Text-to-Image Generation**
+
+             This app provides three main functionalities:
+             1. **Text-to-Image Generation**: Generate high-quality images from text prompts
+             2. **Prompt Enhancement**: Improve your prompts using MLLM reprompting
+             3. **Image Refinement**: Enhance existing images with the refiner model
+             """,
+             elem_classes="model-info"
+         )
+
+         with gr.Tabs():
+             # Tab 1: Text-to-Image Generation
+             with gr.Tab("🖼️ Text-to-Image Generation"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Generation Settings")
+                         gr.Markdown("**Model**: HunyuanImage v2.1 (Non-distilled)")
+
+                         # use_distilled = gr.Checkbox(
+                         #     label="Use Distilled Model",
+                         #     value=False,
+                         #     info="Faster generation with slightly lower quality"
+                         # )
+                         use_distilled = False
+
+                         prompt = gr.Textbox(
+                             label="Prompt",
+                             placeholder="",
+                             lines=3,
+                             value="A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word “Tencent” on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
+                         )
+
+                         negative_prompt = gr.Textbox(
+                             label="Negative Prompt",
+                             placeholder="",
+                             lines=2,
+                             value=""
+                         )
+
+                         # Predefined aspect ratios
+                         aspect_ratios = [
+                             ("16:9 (2560×1536)", "16:9"),
+                             ("4:3 (2304×1792)", "4:3"),
+                             ("1:1 (2048×2048)", "1:1"),
+                             ("3:4 (1792×2304)", "3:4"),
+                             ("9:16 (1536×2560)", "9:16")
+                         ]
+
+                         aspect_ratio = gr.Radio(
+                             choices=aspect_ratios,
+                             value="1:1",
+                             label="Aspect Ratio",
+                             info="Select the aspect ratio for image generation"
+                         )
+
+                         # Hidden width and height inputs that get updated based on the aspect ratio
+                         width = gr.Number(value=2048, visible=False)
+                         height = gr.Number(value=2048, visible=False)
+
+                         with gr.Row():
+                             num_inference_steps = gr.Slider(
+                                 minimum=10, maximum=100, step=5, value=50,
+                                 label="Inference Steps", info="More steps = better quality, slower generation"
+                             )
+                             guidance_scale = gr.Slider(
+                                 minimum=1.0, maximum=10.0, step=0.1, value=3.5,
+                                 label="Guidance Scale", info="How closely to follow the prompt"
+                             )
+
+                         with gr.Row():
+                             seed = gr.Number(
+                                 label="Seed", value=-1, precision=0,
+                                 info="Random seed for reproducibility (-1 for a random seed)"
+                             )
+                             use_reprompt = gr.Checkbox(
+                                 label="Use Reprompt", value=True,
+                                 info="Enhance prompt automatically"
+                             )
+                             use_refiner = gr.Checkbox(
+                                 label="Use Refiner", value=True,
+                                 info="Apply refiner after generation",
+                                 interactive=True
+                             )
+
+                         generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")
+
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Generated Image")
+                         generated_image = gr.Image(
+                             label="Generated Image",
+                             format="png",
+                             show_download_button=True,
+                             type="pil",
+                             height=600
+                         )
+                         generation_status = gr.Textbox(
+                             label="Status",
+                             interactive=False,
+                             value="Ready to generate"
+                         )
+
+             # Tab 2: Prompt Enhancement
+             with gr.Tab("✨ Prompt Enhancement"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Prompt Enhancement Settings")
+                         gr.Markdown("**Model**: HunyuanImage v2.1 Reprompt Model")
+
+                         # enhance_use_distilled = gr.Checkbox(
+                         #     label="Use Distilled Model",
+                         #     value=False,
+                         #     info="For loading the reprompt model"
+                         # )
+                         enhance_use_distilled = False
+
+                         original_prompt = gr.Textbox(
+                             label="Original Prompt",
+                             placeholder="A cat sitting on a table",
+                             lines=4,
+                             value="A cat sitting on a table"
+                         )
+
+                         enhance_btn = gr.Button("✨ Enhance Prompt", variant="primary", size="lg")
+
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Enhanced Prompt")
+                         enhanced_prompt = gr.Textbox(
+                             label="Enhanced Prompt",
+                             lines=6,
+                             interactive=False
+                         )
+                         enhancement_status = gr.Textbox(
+                             label="Status",
+                             interactive=False,
+                             value="Ready to enhance"
+                         )
+
+             # Tab 3: Image Refinement
+             with gr.Tab("🔧 Image Refinement"):
+                 with gr.Row():
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Refinement Settings")
+                         gr.Markdown("**Model**: HunyuanImage v2.1 Refiner")
+
+                         input_image = gr.Image(
+                             label="Input Image",
+                             type="pil",
+                             height=300
+                         )
+
+                         refine_prompt = gr.Textbox(
+                             label="Refinement Prompt",
+                             placeholder="Image description",
+                             info="This prompt should describe the image content.",
+                             lines=2,
+                             value=""
+                         )
+
+                         with gr.Row():
+                             refine_width = gr.Slider(
+                                 minimum=512, maximum=2048, step=64, value=2048,
+                                 label="Width", info="Output width"
+                             )
+                             refine_height = gr.Slider(
+                                 minimum=512, maximum=2048, step=64, value=2048,
+                                 label="Height", info="Output height"
+                             )
+
+                         with gr.Row():
+                             refine_steps = gr.Slider(
+                                 minimum=1, maximum=20, step=1, value=4,
+                                 label="Refinement Steps", info="More steps = more refinement"
+                             )
+                             refine_guidance = gr.Slider(
+                                 minimum=1.0, maximum=10.0, step=0.1, value=3.5,
+                                 label="Guidance Scale", info="How strongly to follow the prompt"
+                             )
+
+                         refine_seed = gr.Number(
+                             label="Seed", value=-1, precision=0,
+                             info="Random seed for reproducibility"
+                         )
+
+                         refine_btn = gr.Button("🔧 Refine Image", variant="primary", size="lg")
+
+                     with gr.Column(scale=1):
+                         gr.Markdown("### Refined Image")
+                         refined_image = gr.Image(
+                             label="Refined Image",
+                             type="pil",
+                             format="png",
+                             show_download_button=True,
+                             height=600
+                         )
+                         refinement_status = gr.Textbox(
+                             label="Status",
+                             interactive=False,
+                             value="Ready to refine"
+                         )
+
+         # Event handlers
+         # Update width and height when the aspect ratio changes
+         aspect_ratio.change(
+             fn=app.update_resolution,
+             inputs=[aspect_ratio],
+             outputs=[width, height]
+         )
+
+         generate_btn.click(
+             fn=app.generate_image,
+             inputs=[
+                 prompt, negative_prompt, width, height, num_inference_steps,
+                 guidance_scale, seed, use_reprompt, use_refiner  # , use_distilled
+             ],
+             outputs=[generated_image, generation_status]
+         )
+
+         enhance_btn.click(
+             fn=app.enhance_prompt,
+             inputs=[original_prompt],
+             outputs=[enhanced_prompt, enhancement_status]
+         )
+
+         refine_btn.click(
+             fn=app.refine_image,
+             inputs=[
+                 input_image, refine_prompt,
+                 refine_width, refine_height, refine_steps, refine_guidance, refine_seed
+             ],
+             outputs=[refined_image, refinement_status]
+         )
+
+         # Additional info
+         gr.Markdown(
+             """
+             ### 📝 Usage Tips
+
+             **Text-to-Image Generation:**
+             - Use descriptive prompts with specific details
+             - Adjust guidance scale: higher values follow the prompt more closely
+             - More inference steps generally produce better quality
+             - Enable reprompt for automatic prompt enhancement
+             - Enable refiner for additional quality improvement
+
+             **Prompt Enhancement:**
+             - Enter your basic prompt idea
+             - The AI will enhance it with better structure and details
+             - Enhanced prompts often produce better results
+
+             **Image Refinement:**
+             - Upload any image you want to improve
+             - Describe the image content in the refinement prompt
+             - The refiner will enhance details and quality
+             - Works best with images generated by HunyuanImage
+             """,
+             elem_classes="model-info"
+         )
+
+     return demo
+
+ if __name__ == "__main__":
+     import argparse
+
+     # Parse command line arguments
+     parser = argparse.ArgumentParser(description="Launch HunyuanImage Gradio App")
+     parser.add_argument("--no-auto-load", action="store_true", help="Disable auto-loading the pipeline on startup")
+     parser.add_argument("--use-distilled", action="store_true", help="Use the distilled model")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda/cpu)")
+     parser.add_argument("--port", type=int, default=8081, help="Port to run the app on")
+     parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
+
+     args = parser.parse_args()
+
+     # Create and launch the interface
+     auto_load = not args.no_auto_load
+     demo = create_interface(auto_load=auto_load, use_distilled=args.use_distilled, device=args.device)
+
+     print("🚀 Starting HunyuanImage Gradio App...")
+     print(f"🔧 Auto-load pipeline: {'Yes' if auto_load else 'No'}")
+     print(f"🎯 Model type: {'Distilled' if args.use_distilled else 'Non-distilled'}")
+     print(f"💻 Device: {args.device}")
+     print("⚠️ Make sure you have the required model checkpoints downloaded!")
+
+     demo.launch(
+         server_name=args.host,
+         # server_port=args.port,
+         share=False,
+         show_error=True,
+         quiet=False,
+         max_threads=1,  # Default: sequential processing (recommended for GPU apps)
+         # max_threads=4,  # Enable parallel processing (requires more GPU memory)
+     )
requirements (1).txt ADDED
@@ -0,0 +1,17 @@
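+ # torchvision 0.21.0 and torchaudio 2.6.0 pair with torch 2.6.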
+ tqdm==4.67.1
+ torch>=2.6.0
+ einops==0.8.0
+ loguru==0.7.3
+ numpy==1.26.4
+ pillow==11.3.0
+ omegaconf>=2.3.0
+ torchaudio==2.6.0
+ diffusers>=0.32.0
+ safetensors==0.4.5
+ torchvision==0.21.0
+ huggingface-hub==0.34.0
+ transformers[accelerate,tiktoken]==4.56.0
+ wheel
+ setuptools
+ modelscope
+ huggingface_hub[cli]