prithivMLmods committed
Commit 23c94ef · verified · 1 Parent(s): 4306537

Update app.py

Files changed (1): app.py (+197 -110)
app.py CHANGED
@@ -19,83 +19,54 @@ DTYPE = "auto"

CATEGORIES = ["Query", "Caption", "Point", "Detect"]
PLACEHOLDERS = {
-    "Query": "What is in this image?",
-    "Caption": "Select a caption length from the suggestions below.",
-    "Point": "Select an object from suggestions or enter a custom one.",
-    "Detect": "Select an object from suggestions or enter a custom one.",
+    "Query": "What's in this image?",
+    "Caption": "Select caption length: short, normal, or long",
+    "Point": "Select an object from suggestions or enter manually",
+    "Detect": "Select an object from suggestions or enter manually",
}

+# --- Model Loading ---
+# Load Qwen3-VL
qwen_model = Qwen3VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen3-VL-32B-Instruct",
-    torch_dtype=DTYPE,
+    "Qwen/Qwen3-VL-4B-Instruct",
+    dtype=DTYPE,
    device_map=DEVICE,
).eval()
qwen_processor = Qwen3VLProcessor.from_pretrained(
-    "Qwen/Qwen3-VL-32B-Instruct",
+    "Qwen/Qwen3-VL-4B-Instruct",
)
-print("Model loaded successfully.")


# --- Utility Functions ---
def safe_parse_json(text: str):
-    """Safely parse JSON or Python literal from a string, cleaning it first."""
-    # Find the JSON object within the text
-    match = re.search(r'\{.*\}', text, re.DOTALL)
-    if not match:
-        return {}
-    text = match.group(0)
+    """Safely parse a string that may be JSON or a Python literal."""
+    text = text.strip()
+    # Remove markdown code blocks
+    text = re.sub(r"^```(json)?", "", text)
+    text = re.sub(r"```$", "", text)
+    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
-        try:
-            # Fallback for Python dictionary literals
-            return ast.literal_eval(text)
-        except (ValueError, SyntaxError):
-            return {}
-
-
-def annotate_image(image: Image.Image, result: dict, category: str):
-    """Draws annotations on the image based on the model's output."""
-    if not isinstance(image, Image.Image) or not isinstance(result, dict):
-        return image
-
-    image_np = np.array(image.convert("RGB"))
-
-    # Handle Point annotations
-    if category == "Point" and "points" in result and result["points"]:
-        points_xy = np.array(result["points"])
-        if points_xy.size == 0:
-            return image
-
-        # Denormalize points from [0, 1] range to image dimensions
-        points_xy *= np.array([image.width, image.height])
-
-        key_points = sv.KeyPoints(xy=points_xy.reshape(1, -1, 2))
-        annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
-        annotated_image = annotator.annotate(scene=image_np.copy(), key_points=key_points)
-        return Image.fromarray(annotated_image)
-
-    # Handle Detection annotations
-    if category == "Detect" and "objects" in result and result["objects"]:
-        boxes_xyxy = np.array(result["objects"])
-        if boxes_xyxy.size == 0:
-            return image
-
-        # Denormalize boxes from [0, 1] range to image dimensions
-        boxes_xyxy *= np.array([image.width, image.height, image.width, image.height])
-
-        detections = sv.Detections(xyxy=boxes_xyxy)
-        annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
-        annotated_image = annotator.annotate(scene=image_np.copy(), detections=detections)
-        return Image.fromarray(annotated_image)
-
-    return image
-
+        pass
+    try:
+        # Fallback to literal_eval for Python-like dictionary/list strings
+        return ast.literal_eval(text)
+    except Exception:
+        return {}

# --- Inference Functions ---
def run_qwen_inference(image: Image.Image, prompt: str):
-    """Core function to run inference with the Qwen3-VL model."""
-    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
+    """Core function to run inference with the Qwen model."""
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": prompt},
+            ],
+        }
+    ]
    inputs = qwen_processor.apply_chat_template(
        messages,
        tokenize=True,
@@ -105,9 +76,15 @@ def run_qwen_inference(image: Image.Image, prompt: str):
    ).to(DEVICE)

    with torch.inference_mode():
-        generated_ids = qwen_model.generate(**inputs, max_new_tokens=512)
+        generated_ids = qwen_model.generate(
+            **inputs,
+            max_new_tokens=512,
+        )

-    generated_ids_trimmed = generated_ids[:, inputs.input_ids.shape[1]:]
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :]
+        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
@@ -118,88 +95,174 @@

@GPU
def get_suggested_objects(image: Image.Image):
-    """Get suggested objects in the image using Qwen3-VL to populate radio buttons."""
+    """Get suggested objects in the image using Qwen."""
    if image is None:
-        return gr.Radio(choices=[], visible=False)
-
+        return []
    try:
-        prompt = "List the 3 most prominent objects in this image as a Python list of strings. Example: ['car', 'tree', 'person']"
-        result_text = run_qwen_inference(image, prompt)
-
+        # Resize image for faster suggestion generation
+        suggest_image = image.copy()
+        suggest_image.thumbnail((512, 512))
+
+        prompt = "List the main objects in the image in a Python list format. For example: ['cat', 'dog', 'table']"
+        result_text = run_qwen_inference(suggest_image, prompt)
+
+        # Clean up the output to find the list
        match = re.search(r'\[.*?\]', result_text)
        if match:
-            suggestions = ast.literal_eval(match.group())
-            if isinstance(suggestions, list) and suggestions:
-                return gr.Radio(choices=suggestions, visible=True, interactive=True)
+            suggested_objects = ast.literal_eval(match.group())
+            if isinstance(suggested_objects, list):
+                # Return up to 3 suggestions
+                return suggested_objects[:3]
+        return []
    except Exception as e:
        print(f"Error getting suggestions with Qwen: {e}")
-
-    return gr.Radio(choices=[], visible=False)
+        return []
+
+
+def annotate_image(image: Image.Image, result: dict):
+    """Annotates the image with points or bounding boxes based on model output."""
+    if not isinstance(image, Image.Image) or not isinstance(result, dict):
+        return image
+
+    original_width, original_height = image.size
+    scene_np = np.array(image.copy())
+
+    # Handle Point annotations
+    if "points" in result and result["points"]:
+        points_list = []
+        for point in result.get("points", []):
+            x = int(point["x"] * original_width)
+            y = int(point["y"] * original_height)
+            points_list.append([x, y])
+
+        if not points_list:
+            return image
+
+        points_array = np.array(points_list).reshape(-1, 2)
+        key_points = sv.KeyPoints(xy=points_array)
+        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
+        annotated_image_np = vertex_annotator.annotate(
+            scene=scene_np, key_points=key_points
+        )
+        return Image.fromarray(annotated_image_np)
+
+    # Handle Detection annotations
+    if "objects" in result and result["objects"]:
+        boxes = []
+        for obj in result["objects"]:
+            x_min = obj["x_min"] * original_width
+            y_min = obj["y_min"] * original_height
+            x_max = obj["x_max"] * original_width
+            y_max = obj["y_max"] * original_height
+            boxes.append([x_min, y_min, x_max, y_max])
+
+        if not boxes:
+            return image
+
+        detections = sv.Detections(xyxy=np.array(boxes))
+        box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX, thickness=4)
+        label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+
+        annotated_image_np = box_annotator.annotate(
+            scene=scene_np, detections=detections
+        )
+        return Image.fromarray(annotated_image_np)
+
+    return image


@GPU
def process_qwen(image: Image.Image, category: str, prompt: str):
-    """Process inputs based on the selected category, returning text and data for annotation."""
+    """Processes the input based on the selected category using the Qwen model."""
    if category == "Query":
        return run_qwen_inference(image, prompt), {}
-
+
    elif category == "Caption":
        full_prompt = f"Provide a {prompt} length caption for the image."
        return run_qwen_inference(image, full_prompt), {}
-
+
    elif category == "Point":
        full_prompt = (
-            f"Provide 2D point coordinates for '{prompt}'. Respond ONLY with a JSON object like "
-            f"`{{\"points\": [[x1, y1], [x2, y2], ...]}}`. The coordinates must be normalized between 0.0 and 1.0."
+            f"Provide 2d point coordinates for {prompt}. Report in JSON format like "
+            '[{"point_2d": [x, y]}] where coordinates are from 0 to 1000.'
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
-        # Ensure the parsed data has the correct structure
-        if "points" not in parsed_json or not isinstance(parsed_json["points"], list):
-            return output_text, {}
-        return output_text, parsed_json
-
+        points_result = {"points": []}
+        if isinstance(parsed_json, list):
+            for item in parsed_json:
+                if "point_2d" in item and len(item["point_2d"]) == 2:
+                    x, y = item["point_2d"]
+                    points_result["points"].append({"x": x / 1000.0, "y": y / 1000.0})
+        return json.dumps(points_result, indent=2), points_result
+
    elif category == "Detect":
        full_prompt = (
-            f"Provide bounding box coordinates for '{prompt}'. Respond ONLY with a JSON object like "
-            f"`{{\"objects\": [[x_min, y_min, x_max, y_max], ...]}}`. The coordinates must be normalized between 0.0 and 1.0."
+            f"Provide bounding box coordinates for {prompt}. Report in JSON format like "
+            '[{"bbox_2d": [xmin, ymin, xmax, ymax]}] where coordinates are from 0 to 1000.'
        )
        output_text = run_qwen_inference(image, full_prompt)
        parsed_json = safe_parse_json(output_text)
-        if "objects" not in parsed_json or not isinstance(parsed_json["objects"], list):
-            return output_text, {}
-        return output_text, parsed_json
+        objects_result = {"objects": []}
+        if isinstance(parsed_json, list):
+            for item in parsed_json:
+                if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
+                    xmin, ymin, xmax, ymax = item["bbox_2d"]
+                    objects_result["objects"].append(
+                        {
+                            "x_min": xmin / 1000.0,
+                            "y_min": ymin / 1000.0,
+                            "x_max": xmax / 1000.0,
+                            "y_max": ymax / 1000.0,
+                        }
+                    )
+        return json.dumps(objects_result, indent=2), objects_result

    return "Invalid category", {}


# --- Gradio Interface Logic ---
def on_category_and_image_change(image, category):
-    """Handle UI changes when the image or category is updated."""
+    """Generate suggestions when category or image changes."""
    text_box = gr.Textbox(value="", placeholder=PLACEHOLDERS.get(category, ""), interactive=True)

    if category == "Caption":
-        return gr.Radio(choices=["short", "normal", "long"], value="normal", visible=True), text_box
-
+        return gr.Radio(choices=["short", "normal", "long"], label="Caption Length", value="normal", visible=True), text_box
+
    if image is None or category not in ["Point", "Detect"]:
        return gr.Radio(choices=[], visible=False), text_box

-    return get_suggested_objects(image), text_box
+    suggestions = get_suggested_objects(image)
+    if suggestions:
+        return gr.Radio(choices=suggestions, label="Suggestions", visible=True, interactive=True), text_box
+    else:
+        return gr.Radio(choices=[], visible=False), text_box
+
+
+def update_prompt_from_radio(selected_object):
+    """Update prompt textbox when a radio option is selected."""
+    if selected_object:
+        return gr.Textbox(value=selected_object)
+    return gr.Textbox(value="")


def process_inputs(image, category, prompt):
-    """Main function to handle the user's submission."""
+    """Main function to handle the user's request."""
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt and category not in ["Caption"]:
-        raise gr.Error("Please provide a prompt or select a suggestion.")
-    if category == "Caption" and not prompt:
-        prompt = "normal"  # Default caption length
+        # Caption can have an empty prompt if a length is selected
+        if category == "Caption" and not prompt:
+            prompt = "normal"  # default
+        else:
+            raise gr.Error("Please provide a prompt or select a suggestion.")

-    image.thumbnail((1024, 1024))  # Resize for faster inference
+    # Resize the image to make inference quicker
+    image.thumbnail((1024, 1024))

+    # Process with Qwen
    qwen_text, qwen_data = process_qwen(image, category, prompt)
-    qwen_annotated_image = annotate_image(image, qwen_data, category)
+    qwen_annotated_image = annotate_image(image, qwen_data)

    return qwen_annotated_image, qwen_text
@@ -207,38 +270,54 @@ def process_inputs(image, category, prompt):
# --- Gradio UI Layout ---
with gr.Blocks(theme=Ocean()) as demo:
    gr.Markdown("# 👓 Object Understanding with Qwen3-VL")
-    gr.Markdown("### Explore object detection, keypoint detection, and captioning using natural language prompts.")
+    gr.Markdown(
+        "### Explore object detection, visual grounding, and keypoint detection through natural language prompts."
+    )
+    gr.Markdown("""
+    *Powered by [Qwen/Qwen3-VL-4B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+    """)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Input Image")
            category_select = gr.Radio(
-                choices=CATEGORIES, value=CATEGORIES[0], label="Select Task", interactive=True
+                choices=CATEGORIES,
+                value=CATEGORIES[0],
+                label="Select Task Category",
+                interactive=True,
            )
            suggestions_radio = gr.Radio(
-                choices=[], label="Suggestions", visible=False, interactive=True
+                choices=[],
+                label="Suggestions",
+                visible=False,
+                interactive=True,
            )
            prompt_input = gr.Textbox(
-                placeholder=PLACEHOLDERS[CATEGORIES[0]], label="Prompt", lines=2
+                placeholder=PLACEHOLDERS[CATEGORIES[0]],
+                label="Prompt",
+                lines=2,
            )
-            submit_btn = gr.Button("Generate", variant="primary")
+            submit_btn = gr.Button("Process Image", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Qwen/Qwen3-VL-4B-Instruct Output")
            qwen_img_output = gr.Image(label="Annotated Image")
-            qwen_text_output = gr.Textbox(label="Text Output", lines=8, interactive=False, show_copy_button=True)
+            qwen_text_output = gr.Textbox(
+                label="Text Output", lines=10, interactive=False
+            )

    gr.Examples(
        examples=[
-            ["examples/cars.jpg", "Query", "How many cars are in the image?"],
-            ["examples/dog_beach.jpg", "Detect", "dog"],
-            ["examples/person_skiing.jpg", "Point", "the person's head"],
-            ["examples/dog_beach.jpg", "Caption", "short"],
+            ["examples/example_1.jpg", "Query", "How many cars are in the image?"],
+            ["examples/example_1.jpg", "Detect", "car"],
+            ["examples/example_2.JPG", "Point", "the person's face"],
+            ["examples/example_2.JPG", "Caption", "short"],
        ],
        inputs=[image_input, category_select, prompt_input],
    )

    # --- Event Listeners ---
+    # When image or category changes, update suggestions
    category_select.change(
        fn=on_category_and_image_change,
        inputs=[image_input, category_select],
@@ -249,7 +328,15 @@ with gr.Blocks(theme=Ocean()) as demo:
        inputs=[image_input, category_select],
        outputs=[suggestions_radio, prompt_input],
    )
-    suggestions_radio.change(fn=lambda x: x, inputs=suggestions_radio, outputs=prompt_input)
+
+    # When a suggestion is clicked, update the prompt box
+    suggestions_radio.change(
+        fn=update_prompt_from_radio,
+        inputs=[suggestions_radio],
+        outputs=[prompt_input],
+    )
+
+    # Main submission action
    submit_btn.click(
        fn=process_inputs,
        inputs=[image_input, category_select, prompt_input],
@@ -257,4 +344,4 @@ with gr.Blocks(theme=Ocean()) as demo:
    )

if __name__ == "__main__":
-    demo.launch()
+    demo.launch(debug=True)
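
For reference, a minimal sketch of the output convention this commit switches Detect to: the prompt now asks Qwen3-VL for "bbox_2d" values in the 0-1000 range, and process_qwen() rescales them to [0, 1] so annotate_image() can multiply by the image size. The model reply below is invented for illustration:

import json

# Invented model reply, in the format the new Detect prompt requests.
model_reply = '[{"bbox_2d": [120, 80, 560, 440]}, {"bbox_2d": [600, 120, 940, 480]}]'

objects_result = {"objects": []}
for item in json.loads(model_reply):
    if "bbox_2d" in item and len(item["bbox_2d"]) == 4:
        xmin, ymin, xmax, ymax = item["bbox_2d"]
        objects_result["objects"].append(
            {
                "x_min": xmin / 1000.0,  # rescale 0-1000 -> 0-1, as process_qwen() does
                "y_min": ymin / 1000.0,
                "x_max": xmax / 1000.0,
                "y_max": ymax / 1000.0,
            }
        )

print(objects_result)  # normalized boxes; annotate_image() maps them back to pixel coordinates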
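
Likewise, a minimal sketch of what the reworked safe_parse_json() is built to tolerate: replies wrapped in markdown code fences, plain JSON, or Python-style literals. The reply below is invented for illustration:

import ast
import json
import re

reply = '```json\n[{"point_2d": [512, 256]}]\n```'  # invented fenced reply

text = reply.strip()
text = re.sub(r"^```(json)?", "", text)  # strip an opening ``` or ```json fence
text = re.sub(r"```$", "", text)         # strip a closing fence
text = text.strip()

try:
    parsed = json.loads(text)
except json.JSONDecodeError:
    parsed = ast.literal_eval(text)      # fallback for Python-style literals

print(parsed)  # [{'point_2d': [512, 256]}]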