linoyts HF Staff committed on
Commit
169902b
·
verified ·
1 Parent(s): 0dd4b90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -4
app.py CHANGED
@@ -13,11 +13,126 @@ from qwenimage.pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeli
13
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
14
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
15
  import math
 
16
 
17
  from PIL import Image
18
 
19
  # Set environment variable for parallel loading
20
- os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
  MAX_IMAGE_SIZE = 2048
@@ -56,19 +171,31 @@ pipe.fuse_lora()
56
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
57
 
58
 
59
-
60
  # dummy_mask = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/mask_cat.png?raw=true")
61
 
62
  # # --- Ahead-of-time compilation ---
63
  # optimize_pipeline_(pipe, image=Image.new("RGB", (1328, 1328)), prompt="prompt", mask_image=dummy_mask)
64
 
65
  @spaces.GPU(duration=120)
66
- def infer(edit_images, prompt, negative_prompt="", seed=42, randomize_seed=False, strength=1.0, num_inference_steps=35, true_cfg_scale=4.0, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
 
67
  image = edit_images["background"]
68
  mask = edit_images["layers"][0]
69
 
70
  if randomize_seed:
71
  seed = random.randint(0, MAX_SEED)
 
 
 
 
72
 
73
  # Generate image using Qwen pipeline
74
  result_image = pipe(
@@ -164,6 +291,7 @@ with gr.Blocks(css=css) as demo:
164
 
165
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
166
 
 
167
  with gr.Row():
168
  strength = gr.Slider(
169
  label="Strength",
@@ -190,11 +318,15 @@ with gr.Blocks(css=css) as demo:
190
  step=1,
191
  value=8,
192
  )
 
 
 
 
193
 
194
  gr.on(
195
  triggers=[run_button.click, prompt.submit],
196
  fn = infer,
197
- inputs = [edit_image, prompt, negative_prompt, seed, randomize_seed, strength, num_inference_steps, true_cfg_scale],
198
  outputs = [result, seed]
199
  )
200
 
 
13
  from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
14
  from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
15
  import math
16
+ from huggingface_hub import InferenceClient
17
 
18
  from PIL import Image
19
 
20
  # Set environment variable for parallel loading
21
+ # os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
22
+
23
+ # --- Prompt Enhancement using Hugging Face InferenceClient ---
24
def polish_prompt_hf(original_prompt, system_prompt):
    """
    Rewrite a prompt using a chat model through the Hugging Face InferenceClient.

    ``system_prompt`` is sent as the system message and ``original_prompt`` as
    the user message. If the reply contains a ``{"Rewritten": ...}`` JSON
    object, the value of that key is used; otherwise the reply text itself is
    used. Newlines in the final text are replaced by spaces.

    Args:
        original_prompt: Text sent as the user message. It is also the value
            returned whenever enhancement cannot be performed.
        system_prompt: Text sent as the system message.

    Returns:
        The rewritten prompt as a single line, or ``original_prompt`` unchanged
        when ``HF_TOKEN`` is not set, the API call fails, or the model reply is
        empty.
    """
    # Imported here so the function does not rely on a module-level
    # ``import json`` being present elsewhere in the file.
    import json

    # Ensure HF_TOKEN is set
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("Warning: HF_TOKEN not set. Falling back to original prompt.")
        return original_prompt

    try:
        # Initialize the client
        client = InferenceClient(
            provider="cerebras",
            api_key=api_key,
        )

        # Format the messages for the chat completions API
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": original_prompt},
        ]

        # Call the API
        completion = client.chat.completions.create(
            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
            messages=messages,
        )

        # The message content may be None; treat that like an empty reply.
        result = completion.choices[0].message.content or ""

        polished_prompt = result
        # Try to extract JSON if present
        if '{"Rewritten"' in result:
            # Strip Markdown code fences the model may wrap around the JSON.
            cleaned = result.replace('```json', '').replace('```', '')
            polished_prompt = cleaned
            try:
                polished_prompt = json.loads(cleaned).get('Rewritten', cleaned)
            except (ValueError, AttributeError):
                # Invalid JSON (JSONDecodeError is a ValueError) or a JSON
                # value that is not an object: keep the cleaned reply text.
                pass

        polished_prompt = polished_prompt.strip().replace("\n", " ")
        if not polished_prompt:
            # An empty rewrite is useless to the caller; keep the input instead.
            print("Warning: empty response from prompt enhancement. Falling back to original prompt.")
            return original_prompt
        return polished_prompt

    except Exception as e:
        # Deliberate best-effort boundary: any failure (network, auth,
        # unexpected response shape) must not break the caller.
        print(f"Error during API call to Hugging Face: {e}")
        # Fallback to original prompt if enhancement fails
        return original_prompt
75
+
76
+
77
def polish_prompt(prompt, img):
    """
    Polish an image-editing instruction using HF inference.

    The rewriting rules are sent as the system message and the user's
    instruction as the user message of a single ``polish_prompt_hf`` call.

    Args:
        prompt: The user's raw edit instruction.
        img: The image being edited. Not used by this HF-based version; the
            parameter is kept so the interface stays consistent.

    Returns:
        The rewritten instruction, or ``prompt`` unchanged when
        ``polish_prompt_hf`` falls back (it returns its input as-is when the
        enhancement cannot be performed).
    """
    SYSTEM_PROMPT = '''
# Edit Instruction Rewriter
You are a professional edit instruction rewriter. Your task is to generate a precise, concise, and visually achievable professional-level edit instruction based on the user-provided instruction and the image to be edited.
Please strictly follow the rewriting rules below:
## 1. General Principles
- Keep the rewritten prompt **concise**. Avoid overly long sentences and reduce unnecessary descriptive language.
- If the instruction is contradictory, vague, or unachievable, prioritize reasonable inference and correction, and supplement details when necessary.
- Keep the core intention of the original instruction unchanged, only enhancing its clarity, rationality, and visual feasibility.
- All added objects or modifications must align with the logic and style of the edited input image's overall scene.
## 2. Task Type Handling Rules
### 1. Add, Delete, Replace Tasks
- If the instruction is clear (already includes task type, target entity, position, quantity, attributes), preserve the original intent and only refine the grammar.
- If the description is vague, supplement with minimal but sufficient details (category, color, size, orientation, position, etc.). For example:
> Original: "Add an animal"
> Rewritten: "Add a light-gray cat in the bottom-right corner, sitting and facing the camera"
- Remove meaningless instructions: e.g., "Add 0 objects" should be ignored or flagged as invalid.
- For replacement tasks, specify "Replace Y with X" and briefly describe the key visual features of X.
### 2. Text Editing Tasks
- All text content must be enclosed in English double quotes " ". Do not translate or alter the original language of the text, and do not change the capitalization.
- **For text replacement tasks, always use the fixed template:**
- Replace "xx" to "yy".
- Replace the xx bounding box to "yy".
- If the user does not specify text content, infer and add concise text based on the instruction and the input image's context. For example:
> Original: "Add a line of text" (poster)
> Rewritten: "Add text "LIMITED EDITION" at the top center with slight shadow"
- Specify text position, color, and layout in a concise way.
### 3. Human Editing Tasks
- Maintain the person's core visual consistency (ethnicity, gender, age, hairstyle, expression, outfit, etc.).
- If modifying appearance (e.g., clothes, hairstyle), ensure the new element is consistent with the original style.
- **For expression changes, they must be natural and subtle, never exaggerated.**
- If deletion is not specifically emphasized, the most important subject in the original image (e.g., a person, an animal) should be preserved.
- For background change tasks, emphasize maintaining subject consistency at first.
- Example:
> Original: "Change the person's hat"
> Rewritten: "Replace the man's hat with a dark brown beret; keep smile, short hair, and gray jacket unchanged"
### 4. Style Transformation or Enhancement Tasks
- If a style is specified, describe it concisely with key visual traits. For example:
> Original: "Disco style"
> Rewritten: "1970s disco: flashing lights, disco ball, mirrored walls, colorful tones"
- If the instruction says "use reference style" or "keep current style," analyze the input image, extract main features (color, composition, texture, lighting, art style), and integrate them concisely.
- **For coloring tasks, including restoring old photos, always use the fixed template:** "Restore old photograph, remove scratches, reduce noise, enhance details, high resolution, realistic, natural skin tones, clear facial features, no distortion, vintage photo restoration"
- If there are other changes, place the style description at the end.
## 3. Rationality and Logic Checks
- Resolve contradictory instructions: e.g., "Remove all trees but keep all trees" should be logically corrected.
- Add missing key information: if position is unspecified, choose a reasonable area based on composition (near subject, empty space, center/edges).
# Output Format
Return only the rewritten instruction text directly, without JSON formatting or any other wrapper.
'''

    # Note: We're not actually using the image in the HF version,
    # but keeping the interface consistent.
    #
    # The rules go out once, as the system message; the user message carries
    # only the framed instruction rather than a second copy of the rules.
    user_message = f"User Input: {prompt}\n\nRewritten Prompt:"

    rewritten = polish_prompt_hf(user_message, SYSTEM_PROMPT)

    # polish_prompt_hf returns its input unchanged when it cannot enhance
    # (missing token, API error). In that case hand back the user's raw
    # instruction, not the framed message, so the image pipeline never
    # receives wrapper text as its prompt.
    if rewritten == user_message:
        return prompt
    return rewritten
135
+
136
 
137
  MAX_SEED = np.iinfo(np.int32).max
138
  MAX_IMAGE_SIZE = 2048
 
171
  pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
172
 
173
 
 
174
  # dummy_mask = load_image("https://github.com/Trgtuan10/Image_storage/blob/main/mask_cat.png?raw=true")
175
 
176
  # # --- Ahead-of-time compilation ---
177
  # optimize_pipeline_(pipe, image=Image.new("RGB", (1328, 1328)), prompt="prompt", mask_image=dummy_mask)
178
 
179
  @spaces.GPU(duration=120)
180
+ def infer(edit_images,
181
+ prompt,
182
+ negative_prompt="",
183
+ seed=42,
184
+ randomize_seed=False,
185
+ strength=1.0,
186
+ num_inference_steps=8,
187
+ true_cfg_scale=1.0,
188
+ rewrite_prompt=True,
189
+ progress=gr.Progress(track_tqdm=True)):
190
  image = edit_images["background"]
191
  mask = edit_images["layers"][0]
192
 
193
  if randomize_seed:
194
  seed = random.randint(0, MAX_SEED)
195
+
196
+ if rewrite_prompt:
197
+ prompt = polish_prompt(prompt, image)
198
+ print(f"Rewritten Prompt: {prompt}")
199
 
200
  # Generate image using Qwen pipeline
201
  result_image = pipe(
 
291
 
292
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
293
 
294
+
295
  with gr.Row():
296
  strength = gr.Slider(
297
  label="Strength",
 
318
  step=1,
319
  value=8,
320
  )
321
+ rewrite_prompt = gr.Checkbox(
322
+ label="Enhance prompt (using HF Inference)",
323
+ value=True
324
+ )
325
 
326
  gr.on(
327
  triggers=[run_button.click, prompt.submit],
328
  fn = infer,
329
+ inputs = [edit_image, prompt, negative_prompt, seed, randomize_seed, strength, num_inference_steps, true_cfg_scale, rewrite_prompt],
330
  outputs = [result, seed]
331
  )
332