multimodalart HF Staff commited on
Commit
a3f5a50
·
verified ·
1 Parent(s): 1c1110e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +478 -0
app.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import spaces
6
+
7
+ from PIL import Image
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from optimization import optimize_pipeline_
10
+ from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
11
+ from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
12
+ from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
13
+
14
+ import math
15
+ from huggingface_hub import hf_hub_download
16
+ from safetensors.torch import load_file
17
+
18
+ from PIL import Image
19
+ import os
20
+ import gradio as gr
21
+ from gradio_client import Client, handle_file
22
+ import tempfile
23
+ from huggingface_hub import InferenceClient
24
+
25
+
26
+ # --- Model Loading ---
27
+ dtype = torch.bfloat16
28
+ device = "cuda" if torch.cuda.is_available() else "cpu"
29
+
30
+ pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509",
31
+ transformer= QwenImageTransformer2DModel.from_pretrained("linoyts/Qwen-Image-Edit-Rapid-AIO",
32
+ subfolder='transformer',
33
+ torch_dtype=dtype,
34
+ device_map='cuda'),torch_dtype=dtype).to(device)
35
+
36
+ # Load the relight LoRA
37
+ pipe.load_lora_weights(
38
+ "dx8152/Qwen-Image-Edit-2509-Relight",
39
+ weight_name="Qwen-Edit-Relight.safetensors", adapter_name="relight"
40
+ )
41
+
42
+ pipe.set_adapters(["relight"], adapter_weights=[1.])
43
+ pipe.fuse_lora(adapter_names=["relight"], lora_scale=1.25)
44
+ pipe.unload_lora_weights()
45
+
46
+ pipe.transformer.__class__ = QwenImageTransformer2DModel
47
+ pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
48
+
49
+ optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")
50
+
51
+
52
+ MAX_SEED = np.iinfo(np.int32).max
53
+
54
+ # Initialize translation client
55
+ translation_client = InferenceClient(
56
+ api_key=os.environ.get("HF_TOKEN"),
57
+ )
58
+
59
+ def translate_to_chinese(text: str) -> str:
60
+ """Translate any language text to Chinese using Qwen API."""
61
+ if not text or not text.strip():
62
+ return ""
63
+
64
+ # Check if text is already primarily Chinese
65
+ chinese_chars = sum(1 for char in text if '\u4e00' <= char <= '\u9fff')
66
+ if chinese_chars / max(len(text), 1) > 0.5:
67
+ # Already mostly Chinese, return as is
68
+ return text
69
+
70
+ try:
71
+ completion = translation_client.chat.completions.create(
72
+ model="Qwen/Qwen3-Next-80B-A3B-Instruct:novita",
73
+ messages=[
74
+ {
75
+ "role": "system",
76
+ "content": "You are a professional translator. Translate the user's text to Chinese. Only output the translated text, nothing else."
77
+ },
78
+ {
79
+ "role": "user",
80
+ "content": f"Translate this to Chinese: {text}"
81
+ }
82
+ ],
83
+ max_tokens=500,
84
+ )
85
+
86
+ translated = completion.choices[0].message.content.strip()
87
+ print(f"Translated '{text}' to '{translated}'")
88
+ return translated
89
+ except Exception as e:
90
+ print(f"Translation error: {e}")
91
+ # Fallback to original text if translation fails
92
+ return text
93
+
94
+ def _generate_video_segment(input_image_path: str, output_image_path: str, prompt: str, request: gr.Request) -> str:
95
+ """Generates a single video segment using the external service."""
96
+ x_ip_token = request.headers['x-ip-token']
97
+ video_client = Client("multimodalart/wan-2-2-first-last-frame", headers={"x-ip-token": x_ip_token})
98
+ result = video_client.predict(
99
+ start_image_pil=handle_file(input_image_path),
100
+ end_image_pil=handle_file(output_image_path),
101
+ prompt=prompt, api_name="/generate_video",
102
+ )
103
+ return result[0]["video"]
104
+
105
+ def build_relight_prompt(light_type, light_direction, light_intensity, custom_prompt, user_prompt):
106
+ """Build the relighting prompt based on user selections."""
107
+
108
+ # Priority 1: User's own prompt (translated to Chinese)
109
+ if user_prompt and user_prompt.strip():
110
+ translated = translate_to_chinese(user_prompt)
111
+ # Add trigger word if not already present
112
+ if "重新照明" not in translated:
113
+ return f"重新照明,{translated}"
114
+ return translated
115
+
116
+ # Priority 2: Custom prompt field
117
+ if custom_prompt and custom_prompt.strip():
118
+ return f"重新照明,{custom_prompt}"
119
+
120
+ # Priority 3: Build from controls
121
+ prompt_parts = ["重新照明"]
122
+
123
+ # Light type descriptions
124
+ light_descriptions = {
125
+ "soft_window": "使用窗帘透光(柔和漫射)的光线", # Soft diffuse light from curtains
126
+ "golden_hour": "使用金色黄昏的温暖光线", # Warm golden hour light
127
+ "studio": "使用专业摄影棚的均匀光线", # Professional studio lighting
128
+ "dramatic": "使用戏剧性的高对比度光线", # Dramatic high-contrast lighting
129
+ "natural": "使用自然日光", # Natural daylight
130
+ "neon": "使用霓虹灯光效果", # Neon lighting effect
131
+ "candlelight": "使用烛光的温暖氛围", # Warm candlelight ambiance
132
+ "moonlight": "使用月光的冷色调", # Cool-toned moonlight
133
+ }
134
+
135
+ # Direction descriptions
136
+ direction_descriptions = {
137
+ "front": "从正面照射", # From the front
138
+ "side": "从侧面照射", # From the side
139
+ "back": "从背后照射", # From behind (backlight)
140
+ "top": "从上方照射", # From above
141
+ "bottom": "从下方照射", # From below
142
+ }
143
+
144
+ # Intensity descriptions
145
+ intensity_descriptions = {
146
+ "soft": "柔和强度", # Soft intensity
147
+ "medium": "中等强度", # Medium intensity
148
+ "strong": "强烈强度", # Strong intensity
149
+ }
150
+
151
+ # Build the prompt
152
+ if light_type != "none":
153
+ prompt_parts.append(light_descriptions.get(light_type, ""))
154
+
155
+ if light_direction != "none":
156
+ prompt_parts.append(direction_descriptions.get(light_direction, ""))
157
+
158
+ if light_intensity != "none":
159
+ prompt_parts.append(intensity_descriptions.get(light_intensity, ""))
160
+
161
+ final_prompt = ",".join([p for p in prompt_parts if p])
162
+
163
+ # Add instruction if we have settings
164
+ if len(prompt_parts) > 1:
165
+ final_prompt += "对图片进行重新照明" # Relight the image
166
+
167
+ return final_prompt if len(prompt_parts) > 1 else "重新照明,使用自然光线对图片进行重新照明"
168
+
169
+
170
+ @spaces.GPU
171
+ def infer_relight(
172
+ image,
173
+ light_type,
174
+ light_direction,
175
+ light_intensity,
176
+ custom_prompt,
177
+ user_prompt,
178
+ seed,
179
+ randomize_seed,
180
+ true_guidance_scale,
181
+ num_inference_steps,
182
+ height,
183
+ width,
184
+ prev_output = None,
185
+ progress=gr.Progress(track_tqdm=True)
186
+ ):
187
+ prompt = build_relight_prompt(light_type, light_direction, light_intensity, custom_prompt, user_prompt)
188
+ print(f"Generated Prompt: {prompt}")
189
+
190
+ if randomize_seed:
191
+ seed = random.randint(0, MAX_SEED)
192
+ generator = torch.Generator(device=device).manual_seed(seed)
193
+
194
+ # Choose input image (prefer uploaded, else last output)
195
+ pil_images = []
196
+ if image is not None:
197
+ if isinstance(image, Image.Image):
198
+ pil_images.append(image.convert("RGB"))
199
+ elif hasattr(image, "name"):
200
+ pil_images.append(Image.open(image.name).convert("RGB"))
201
+ elif prev_output:
202
+ pil_images.append(prev_output.convert("RGB"))
203
+
204
+ if len(pil_images) == 0:
205
+ raise gr.Error("Please upload an image first.")
206
+
207
+ result = pipe(
208
+ image=pil_images,
209
+ prompt=prompt,
210
+ height=height if height != 0 else None,
211
+ width=width if width != 0 else None,
212
+ num_inference_steps=num_inference_steps,
213
+ generator=generator,
214
+ true_cfg_scale=true_guidance_scale,
215
+ num_images_per_prompt=1,
216
+ ).images[0]
217
+
218
+ return result, seed, prompt
219
+
220
+ def create_video_between_images(input_image, output_image, prompt: str, request: gr.Request) -> str:
221
+ """Create a video between the input and output images."""
222
+ if input_image is None or output_image is None:
223
+ raise gr.Error("Both input and output images are required to create a video.")
224
+
225
+ try:
226
+
227
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
228
+ input_image.save(tmp.name)
229
+ input_image_path = tmp.name
230
+
231
+ output_pil = Image.fromarray(output_image.astype('uint8'))
232
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
233
+ output_pil.save(tmp.name)
234
+ output_image_path = tmp.name
235
+
236
+ video_path = _generate_video_segment(
237
+ input_image_path,
238
+ output_image_path,
239
+ prompt if prompt else "Relighting transformation",
240
+ request
241
+ )
242
+ return video_path
243
+ except Exception as e:
244
+ raise gr.Error(f"Video generation failed: {e}")
245
+
246
+
247
+ # --- UI ---
248
+ css = '''#col-container { max-width: 800px; margin: 0 auto; }
249
+ .dark .progress-text{color: white !important}
250
+ #examples{max-width: 800px; margin: 0 auto; }'''
251
+
252
+ def reset_all():
253
+ return ["none", "none", "none", "", "", False, True]
254
+
255
+ def end_reset():
256
+ return False
257
+
258
+ def update_dimensions_on_upload(image):
259
+ if image is None:
260
+ return 1024, 1024
261
+
262
+ original_width, original_height = image.size
263
+
264
+ if original_width > original_height:
265
+ new_width = 1024
266
+ aspect_ratio = original_height / original_width
267
+ new_height = int(new_width * aspect_ratio)
268
+ else:
269
+ new_height = 1024
270
+ aspect_ratio = original_width / original_height
271
+ new_width = int(new_height * aspect_ratio)
272
+
273
+ # Ensure dimensions are multiples of 8
274
+ new_width = (new_width // 8) * 8
275
+ new_height = (new_height // 8) * 8
276
+
277
+ return new_width, new_height
278
+
279
+
280
+ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
281
+ with gr.Column(elem_id="col-container"):
282
+ gr.Markdown("## 💡 Qwen Image Edit — Relighting Control")
283
+ gr.Markdown("""
284
+ Qwen Image Edit 2509 for Image Relighting ✨
285
+ Using [dx8152's Qwen-Image-Edit-2509-Relight LoRA](https://huggingface.co/dx8152/Qwen-Image-Edit-2509-Relight) and [linoyts/Qwen-Image-Edit-Rapid-AIO](https://huggingface.co/linoyts/Qwen-Image-Edit-Rapid-AIO) for 4-step inference 💨
286
+
287
+ **Three ways to use:**
288
+ 1. 🌟 **Write your own prompt** in any language (automatically translated to Chinese)
289
+ 2. Use the preset lighting controls
290
+ 3. Write a custom Chinese prompt with the trigger word "重新照明"
291
+
292
+ Example: `Add dramatic sunset lighting from the left` or `使用窗帘透光(柔和漫射)的光线对图片进行重新照明`
293
+ """
294
+ )
295
+
296
+ with gr.Row():
297
+ with gr.Column():
298
+ image = gr.Image(label="Input Image", type="pil")
299
+ prev_output = gr.Image(value=None, visible=False)
300
+ is_reset = gr.Checkbox(value=False, visible=False)
301
+
302
+ # User's own prompt (highest priority)
303
+ with gr.Group():
304
+ gr.Markdown("### 🌟 Your Prompt (Any Language)")
305
+ user_prompt = gr.Textbox(
306
+ label="Describe the lighting you want",
307
+ placeholder="Example: 'Add warm sunset lighting from the right' or 'Make it look like it's lit by neon signs' or 'Add dramatic spotlight from above'",
308
+ lines=2,
309
+ info="Write in any language! It will be automatically translated to Chinese for the model."
310
+ )
311
+
312
+ with gr.Tab("Lighting Controls"):
313
+ light_type = gr.Dropdown(
314
+ label="Light Type",
315
+ choices=[
316
+ ("None", "none"),
317
+ ("Soft Window Light (柔和窗光)", "soft_window"),
318
+ ("Golden Hour (金色黄昏)", "golden_hour"),
319
+ ("Studio Lighting (摄影棚灯光)", "studio"),
320
+ ("Dramatic (戏剧性)", "dramatic"),
321
+ ("Natural Daylight (自然日光)", "natural"),
322
+ ("Neon (霓虹灯)", "neon"),
323
+ ("Candlelight (烛光)", "candlelight"),
324
+ ("Moonlight (月光)", "moonlight"),
325
+ ],
326
+ value="none"
327
+ )
328
+
329
+ light_direction = gr.Dropdown(
330
+ label="Light Direction",
331
+ choices=[
332
+ ("None", "none"),
333
+ ("Front (正面)", "front"),
334
+ ("Side (侧面)", "side"),
335
+ ("Back (背光)", "back"),
336
+ ("Top (上方)", "top"),
337
+ ("Bottom (下方)", "bottom"),
338
+ ],
339
+ value="none"
340
+ )
341
+
342
+ light_intensity = gr.Dropdown(
343
+ label="Light Intensity",
344
+ choices=[
345
+ ("None", "none"),
346
+ ("Soft (柔和)", "soft"),
347
+ ("Medium (中等)", "medium"),
348
+ ("Strong (强烈)", "strong"),
349
+ ],
350
+ value="none"
351
+ )
352
+
353
+ with gr.Tab("Custom Prompt"):
354
+ custom_prompt = gr.Textbox(
355
+ label="Custom Chinese Relighting Prompt (Optional)",
356
+ placeholder="Example: 使用窗帘透光(柔和漫射)的光线对图片进行重新照明\nLeave empty to use controls or user prompt above",
357
+ lines=3
358
+ )
359
+ gr.Markdown("*Note: This field is for Chinese prompts. The trigger word '重新照明' will be added automatically. If you entered text in 'Your Prompt' above, it takes priority.*")
360
+
361
+ with gr.Row():
362
+ reset_btn = gr.Button("Reset")
363
+ run_btn = gr.Button("Generate", variant="primary")
364
+
365
+ with gr.Accordion("Advanced Settings", open=False):
366
+ seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
367
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
368
+ true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
369
+ num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)
370
+ height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
371
+ width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
372
+
373
+ with gr.Column():
374
+ result = gr.Image(label="Output Image", interactive=False)
375
+ prompt_preview = gr.Textbox(label="Processed Prompt", interactive=False)
376
+ create_video_button = gr.Button("🎥 Create Video Between Images", variant="secondary", visible=False)
377
+ with gr.Group(visible=False) as video_group:
378
+ video_output = gr.Video(label="Generated Video", show_download_button=True, autoplay=True)
379
+
380
+ inputs = [
381
+ image, light_type, light_direction, light_intensity, custom_prompt, user_prompt,
382
+ seed, randomize_seed, true_guidance_scale, num_inference_steps, height, width, prev_output
383
+ ]
384
+ outputs = [result, seed, prompt_preview]
385
+
386
+ # Reset behavior
387
+ reset_btn.click(
388
+ fn=reset_all,
389
+ inputs=None,
390
+ outputs=[light_type, light_direction, light_intensity, custom_prompt, user_prompt, is_reset],
391
+ queue=False
392
+ ).then(fn=end_reset, inputs=None, outputs=[is_reset], queue=False)
393
+
394
+ # Manual generation with video button visibility control
395
+ def infer_and_show_video_button(*args):
396
+ result_img, result_seed, result_prompt = infer_relight(*args)
397
+ # Show video button if we have both input and output images
398
+ show_button = args[0] is not None and result_img is not None
399
+ return result_img, result_seed, result_prompt, gr.update(visible=show_button)
400
+
401
+ run_event = run_btn.click(
402
+ fn=infer_and_show_video_button,
403
+ inputs=inputs,
404
+ outputs=outputs + [create_video_button]
405
+ )
406
+
407
+ # Video creation
408
+ create_video_button.click(
409
+ fn=lambda: gr.update(visible=True),
410
+ outputs=[video_group],
411
+ api_name=False
412
+ ).then(
413
+ fn=create_video_between_images,
414
+ inputs=[image, result, prompt_preview],
415
+ outputs=[video_output],
416
+ api_name=False
417
+ )
418
+
419
+ # Examples - You'll need to add your own example images
420
+ gr.Examples(
421
+ examples=[
422
+ [None, "soft_window", "side", "soft", "", "", 0, True, 1.0, 4, 1024, 1024],
423
+ [None, "golden_hour", "front", "medium", "", "", 0, True, 1.0, 4, 1024, 1024],
424
+ [None, "dramatic", "side", "strong", "", "", 0, True, 1.0, 4, 1024, 1024],
425
+ [None, "neon", "front", "medium", "", "", 0, True, 1.0, 4, 1024, 1024],
426
+ [None, "candlelight", "front", "soft", "", "", 0, True, 1.0, 4, 1024, 1024],
427
+ ],
428
+ inputs=[image, light_type, light_direction, light_intensity, custom_prompt, user_prompt,
429
+ seed, randomize_seed, true_guidance_scale, num_inference_steps, height, width],
430
+ outputs=outputs,
431
+ fn=infer_relight,
432
+ cache_examples="lazy",
433
+ elem_id="examples"
434
+ )
435
+
436
+ # Image upload triggers dimension update and control reset
437
+ image.upload(
438
+ fn=update_dimensions_on_upload,
439
+ inputs=[image],
440
+ outputs=[width, height]
441
+ ).then(
442
+ fn=reset_all,
443
+ inputs=None,
444
+ outputs=[light_type, light_direction, light_intensity, custom_prompt, user_prompt, is_reset],
445
+ queue=False
446
+ ).then(
447
+ fn=end_reset,
448
+ inputs=None,
449
+ outputs=[is_reset],
450
+ queue=False
451
+ )
452
+
453
+
454
+ # Live updates
455
+ def maybe_infer(is_reset, progress=gr.Progress(track_tqdm=True), *args):
456
+ if is_reset:
457
+ return gr.update(), gr.update(), gr.update(), gr.update()
458
+ else:
459
+ result_img, result_seed, result_prompt = infer_relight(*args)
460
+ # Show video button if we have both input and output
461
+ show_button = args[0] is not None and result_img is not None
462
+ return result_img, result_seed, result_prompt, gr.update(visible=show_button)
463
+
464
+ control_inputs = [
465
+ image, light_type, light_direction, light_intensity, custom_prompt, user_prompt,
466
+ seed, randomize_seed, true_guidance_scale, num_inference_steps, height, width, prev_output
467
+ ]
468
+ control_inputs_with_flag = [is_reset] + control_inputs
469
+
470
+ for control in [light_type, light_direction, light_intensity]:
471
+ control.change(fn=maybe_infer, inputs=control_inputs_with_flag, outputs=outputs + [create_video_button])
472
+
473
+ custom_prompt.change(fn=maybe_infer, inputs=control_inputs_with_flag, outputs=outputs + [create_video_button])
474
+ user_prompt.change(fn=maybe_infer, inputs=control_inputs_with_flag, outputs=outputs + [create_video_button])
475
+
476
+ run_event.then(lambda img, *_: img, inputs=[result], outputs=[prev_output])
477
+
478
+ demo.launch()