prithivMLmods committed on
Commit 1b33b1c · verified · 1 Parent(s): dd45b2a

update app

Files changed (1)
  1. app.py +149 -94
app.py CHANGED
@@ -2,7 +2,8 @@ import os
 import gradio as gr
 import numpy as np
 import torch
-from PIL import Image
+import random
+from PIL import Image, ImageDraw
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
@@ -100,96 +101,133 @@ try:
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
+    print("Ensure you have the correct libraries installed (transformers>=4.40.0) and access to the model.")
     model = None
     processor = None

+# --- Helper Functions ---
+
+def parse_boxes(box_str):
+    """
+    Parses a string of coordinates into a list of lists.
+    Format expected: "x1,y1,x2,y2" or "x1,y1,x2,y2; x3,y3,x4,y4"
+    """
+    try:
+        boxes = []
+        # Split by semicolon for multiple boxes
+        segments = box_str.split(';')
+        for seg in segments:
+            if not seg.strip():
+                continue
+            coords = [float(c.strip()) for c in seg.split(',')]
+            if len(coords) != 4:
+                raise ValueError(f"Expected 4 coordinates per box, got {len(coords)}")
+            boxes.append(coords)
+        return boxes
+    except Exception as e:
+        raise ValueError(f"Invalid box format: {e}")
+
 @spaces.GPU(duration=60)
-def process_image(input_image, task_type, text_prompt, threshold=0.5):
+def process_sam3(input_image, task_type, text_prompt, box_input, threshold=0.5):
     if input_image is None:
         raise gr.Error("Please upload an image.")

     if model is None or processor is None:
         raise gr.Error("Model not loaded correctly.")

-    # Convert image to RGB
     image_pil = input_image.convert("RGB")
-    annotations = []
-
-    with torch.no_grad():
-        if task_type == "Instance Segmentation":
+    inputs = {}
+
+    # Logic branching based on Task Type
+    try:
+        if task_type == "Text Prompt":
             if not text_prompt:
-                raise gr.Error("Please enter a text prompt for Instance Segmentation.")
-
-            # 1. Instance Segmentation Flow (Text Prompt)
+                raise gr.Error("Please enter a text prompt.")
             inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
-            outputs = model(**inputs)
-
-            # Post-process instance masks
-            results = processor.post_process_instance_segmentation(
-                outputs,
-                threshold=threshold,
-                mask_threshold=0.5,
-                target_sizes=inputs.get("original_sizes").tolist()
-            )[0]
+            display_label_prefix = text_prompt

-            masks_np = results['masks'].cpu().numpy() # [N, H, W]
-            scores_np = results['scores'].cpu().numpy()
+        elif task_type == "Single Bounding Box":
+            if not box_input:
+                raise gr.Error("Please enter box coordinates.")
+            boxes = parse_boxes(box_input)
+            if len(boxes) != 1:
+                raise gr.Error("Please provide exactly one box for this mode.")
+
+            input_boxes = [boxes]  # [batch_size, num_boxes, 4]
+            input_boxes_labels = [[1]]  # 1 = positive

-            for i, mask in enumerate(masks_np):
-                score_val = scores_np[i]
-                label = f"{text_prompt} ({score_val:.2f})"
-                annotations.append((mask, label))
+            inputs = processor(
+                images=image_pil,
+                input_boxes=input_boxes,
+                input_boxes_labels=input_boxes_labels,
+                return_tensors="pt"
+            ).to(device)
+            display_label_prefix = "Box"

-        elif task_type == "Semantic Segmentation":
-            # 2. Semantic Segmentation Flow (No Prompt)
-            # Call processor without text
-            inputs = processor(images=image_pil, return_tensors="pt").to(device)
-            outputs = model(**inputs)
+        elif task_type == "Multiple Boxes (Positive)":
+            if not box_input:
+                raise gr.Error("Please enter box coordinates.")
+            boxes = parse_boxes(box_input)  # Returns list of [x1,y1,x2,y2]

-            # Extract semantic segmentation map
-            # Shape: [batch, channels, height, width]
-            semantic_seg = outputs.semantic_seg
+            input_boxes = [boxes]  # [batch, num_boxes, 4]
+            # All labels 1 (positive)
+            input_boxes_labels = [[1] * len(boxes)]

-            # Process for visualization:
-            # Assuming semantic_seg is a dense map (e.g., saliency or class probabilities).
-            # Since the snippet implies a single channel [batch, 1, H, W], we threshold it.
+            inputs = processor(
+                images=image_pil,
+                input_boxes=input_boxes,
+                input_boxes_labels=input_boxes_labels,
+                return_tensors="pt"
+            ).to(device)
+            display_label_prefix = "Multi-Box"
+
+        elif task_type == "Text + Negative Box":
+            if not text_prompt or not box_input:
+                raise gr.Error("Please provide both Text Prompt and Box Coordinates.")

-            # Remove batch dim -> [1, H, W] or [C, H, W]
-            seg_map = semantic_seg.squeeze(0)
+            boxes = parse_boxes(box_input)

-            # If 1 channel, create binary mask based on threshold/sigmoid
-            if seg_map.shape[0] == 1:
-                # Apply sigmoid if logits, or just threshold if probs
-                # Assuming logits for general safety in torch models
-                mask_tensor = torch.sigmoid(seg_map[0]) > threshold
-                mask_np = mask_tensor.cpu().numpy()
-
-                # Resize mask to original image size if needed
-                # (Note: outputs.semantic_seg is usually feature map size, might need upscaling)
-                # For simplicity in this snippet, we assume processor/output aligns or AnnotatedImage handles resizing (it usually requires matching sizes).
-                # If size mismatch occurs, we convert mask to PIL, resize, then back to numpy.
-
-                if mask_np.shape != (image_pil.height, image_pil.width):
-                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
-                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
-                    mask_np = np.array(mask_img) > 128
-
-                annotations.append((mask_np, "Semantic Region"))
-            else:
-                # If multiple channels (classes), take argmax
-                # This logic depends on specific SAM3 output structure
-                mask_idx = torch.argmax(seg_map, dim=0).cpu().numpy()
-                # Just visualize non-background (assuming 0 is background)
-                mask_np = mask_idx > 0
-
-                if mask_np.shape != (image_pil.height, image_pil.width):
-                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
-                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
-                    mask_np = np.array(mask_img) > 128
+            input_boxes = [boxes]
+            # Labels 0 (negative/exclude)
+            input_boxes_labels = [[0] * len(boxes)]
+
+            inputs = processor(
+                images=image_pil,
+                text=text_prompt,
+                input_boxes=input_boxes,
+                input_boxes_labels=input_boxes_labels,
+                return_tensors="pt"
+            ).to(device)
+            display_label_prefix = f"{text_prompt} (Excl. Box)"

-                annotations.append((mask_np, "Segmented Objects"))
+    except ValueError as e:
+        raise gr.Error(str(e))
+
+    # Inference
+    with torch.no_grad():
+        outputs = model(**inputs)

-    # Return tuple format for AnnotatedImage: (original_image, list_of_annotations)
+    # Post-processing
+    results = processor.post_process_instance_segmentation(
+        outputs,
+        threshold=threshold,
+        mask_threshold=0.5,
+        target_sizes=inputs.get("original_sizes").tolist()
+    )[0]
+
+    masks = results['masks']
+    scores = results['scores']
+
+    # Prepare AnnotatedImage Output
+    annotations = []
+    masks_np = masks.cpu().numpy()
+    scores_np = scores.cpu().numpy()
+
+    for i, mask in enumerate(masks_np):
+        score_val = scores_np[i]
+        label = f"{display_label_prefix} ({score_val:.2f})"
+        annotations.append((mask, label))
+
     return (image_pil, annotations)

 # --- UI Logic ---
@@ -207,12 +245,6 @@ css="""
 }
 """

-def update_visibility(task):
-    if task == "Instance Segmentation":
-        return gr.update(visible=True)
-    else:
-        return gr.update(visible=False)
-
 with gr.Blocks(css=css, theme=plum_theme) as demo:
     with gr.Column(elem_id="col-container"):
         # Header with Logo
@@ -221,59 +253,82 @@ with gr.Blocks(css=css, theme=plum_theme) as demo:
             elem_id="main-title"
         )

-        gr.Markdown("Segment objects using **SAM3** (Segment Anything Model 3). Choose **Instance** for specific text prompts or **Semantic** for automatic segmentation.")
+        gr.Markdown("Perform advanced segmentation using **SAM3** with Text, Boxes, or Combined prompts.")

         with gr.Row():
             # Left Column: Inputs
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image", type="pil", height=350)

-                task_type = gr.Radio(
-                    choices=["Instance Segmentation", "Semantic Segmentation"],
-                    value="Instance Segmentation",
+                task_type = gr.Dropdown(
                     label="Task Type",
+                    choices=[
+                        "Text Prompt",
+                        "Single Bounding Box",
+                        "Multiple Boxes (Positive)",
+                        "Text + Negative Box"
+                    ],
+                    value="Text Prompt",
                     interactive=True
                 )

-                text_prompt = gr.Textbox(
+                # Conditional Inputs
+                text_prompt_input = gr.Textbox(
                     label="Text Prompt",
-                    placeholder="e.g., cat, ear, car wheel...",
-                    info="Required for Instance Segmentation",
+                    placeholder="e.g., cat, ear, car wheel",
                     visible=True
                 )

+                box_input = gr.Textbox(
+                    label="Box Coordinates (x1, y1, x2, y2)",
+                    placeholder="e.g., 100, 150, 500, 450",
+                    info="For multiple boxes, separate with semicolon ';'. E.g., 10,10,50,50; 60,60,100,100",
+                    visible=False
+                )
+
                 threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)

-                run_button = gr.Button("Run Segmentation", variant="primary")
+                run_button = gr.Button("Segment Image", variant="primary")

             # Right Column: Output
             with gr.Column(scale=1.5):
                 output_image = gr.AnnotatedImage(label="Segmented Output", height=500)

-        # Event: Hide text prompt when Semantic Segmentation is selected
+        # Logic to toggle visibility of inputs based on dropdown
+        def update_inputs(task):
+            if task == "Text Prompt":
+                return gr.update(visible=True), gr.update(visible=False)
+            elif task == "Single Bounding Box":
+                return gr.update(visible=False), gr.update(visible=True, label="Single Box (x1, y1, x2, y2)")
+            elif task == "Multiple Boxes (Positive)":
+                return gr.update(visible=False), gr.update(visible=True, label="Multiple Boxes (x1,y1,x2,y2; x1,y1,x2,y2)")
+            elif task == "Text + Negative Box":
+                return gr.update(visible=True), gr.update(visible=True, label="Negative Box to Exclude (x1, y1, x2, y2)")
+            return gr.update(visible=True), gr.update(visible=True)
+
         task_type.change(
-            fn=update_visibility,
+            fn=update_inputs,
             inputs=[task_type],
-            outputs=[text_prompt]
+            outputs=[text_prompt_input, box_input]
         )

         # Examples
         gr.Examples(
             examples=[
-                ["examples/cat.jpg", "Instance Segmentation", "cat", 0.5],
-                ["examples/room.jpg", "Semantic Segmentation", "", 0.5],
-                ["examples/car.jpg", "Instance Segmentation", "tire", 0.4],
+                ["examples/cat.jpg", "Text Prompt", "cat", "", 0.5],
+                ["examples/car.jpg", "Single Bounding Box", "", "100, 200, 400, 500", 0.5],
+                ["examples/fruit.jpg", "Text + Negative Box", "apple", "50, 50, 100, 100", 0.4],
             ],
-            inputs=[input_image, task_type, text_prompt, threshold],
+            inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
             outputs=[output_image],
-            fn=process_image,
+            fn=process_sam3,
             cache_examples=False,
-            label="Examples"
+            label="Examples (Ensure files exist and coordinates match images)"
         )

         run_button.click(
-            fn=process_image,
-            inputs=[input_image, task_type, text_prompt, threshold],
+            fn=process_sam3,
+            inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
             outputs=[output_image]
         )
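
A minimal usage sketch (not part of the commit) of the box format the new prompt modes expect, reusing the parse_boxes helper and the input_boxes / input_boxes_labels conventions from the diff above; the coordinate values below are made up for illustration:

    # String as typed into the "Box Coordinates" textbox (semicolon separates boxes)
    boxes = parse_boxes("10,10,50,50; 60,60,100,100")
    # boxes == [[10.0, 10.0, 50.0, 50.0], [60.0, 60.0, 100.0, 100.0]]

    # Shapes passed to the processor in process_sam3
    input_boxes = [boxes]                     # [batch_size, num_boxes, 4]
    input_boxes_labels = [[1] * len(boxes)]   # 1 = positive box; the "Text + Negative Box" mode uses 0 to exclude a region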