prithivMLmods committed
Commit b01e247 · verified
1 Parent(s): 1b33b1c

update app

Files changed (1)
  1. app.py +36 -157
app.py CHANGED
@@ -94,6 +94,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
 # --- Model Loading ---
+# Using the facebook/sam3 model as requested
 try:
     print("Loading SAM3 Model and Processor...")
     model = Sam3Model.from_pretrained("facebook/sam3").to(device)
@@ -101,113 +102,32 @@ try:
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
-    print("Ensure you have the correct libraries installed (transformers>=4.40.0) and access to the model.")
+    print("Ensure you have the correct libraries installed and access to the model.")
+    # Fallback/Placeholder for demonstration if model doesn't exist in environment yet
     model = None
     processor = None
 
-# --- Helper Functions ---
-
-def parse_boxes(box_str):
-    """
-    Parses a string of coordinates into a list of lists.
-    Format expected: "x1,y1,x2,y2" or "x1,y1,x2,y2; x3,y3,x4,y4"
-    """
-    try:
-        boxes = []
-        # Split by semicolon for multiple boxes
-        segments = box_str.split(';')
-        for seg in segments:
-            if not seg.strip():
-                continue
-            coords = [float(c.strip()) for c in seg.split(',')]
-            if len(coords) != 4:
-                raise ValueError(f"Expected 4 coordinates per box, got {len(coords)}")
-            boxes.append(coords)
-        return boxes
-    except Exception as e:
-        raise ValueError(f"Invalid box format: {e}")
-
 @spaces.GPU(duration=60)
-def process_sam3(input_image, task_type, text_prompt, box_input, threshold=0.5):
+def segment_image(input_image, text_prompt, threshold=0.5):
     if input_image is None:
         raise gr.Error("Please upload an image.")
+    if not text_prompt:
+        raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")
 
     if model is None or processor is None:
         raise gr.Error("Model not loaded correctly.")
 
+    # Convert image to RGB
     image_pil = input_image.convert("RGB")
-    inputs = {}
-
-    # Logic branching based on Task Type
-    try:
-        if task_type == "Text Prompt":
-            if not text_prompt:
-                raise gr.Error("Please enter a text prompt.")
-            inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
-            display_label_prefix = text_prompt
-
-        elif task_type == "Single Bounding Box":
-            if not box_input:
-                raise gr.Error("Please enter box coordinates.")
-            boxes = parse_boxes(box_input)
-            if len(boxes) != 1:
-                raise gr.Error("Please provide exactly one box for this mode.")
-
-            input_boxes = [boxes]  # [batch_size, num_boxes, 4]
-            input_boxes_labels = [[1]]  # 1 = positive
-
-            inputs = processor(
-                images=image_pil,
-                input_boxes=input_boxes,
-                input_boxes_labels=input_boxes_labels,
-                return_tensors="pt"
-            ).to(device)
-            display_label_prefix = "Box"
-
-        elif task_type == "Multiple Boxes (Positive)":
-            if not box_input:
-                raise gr.Error("Please enter box coordinates.")
-            boxes = parse_boxes(box_input)  # Returns list of [x1,y1,x2,y2]
-
-            input_boxes = [boxes]  # [batch, num_boxes, 4]
-            # All labels 1 (positive)
-            input_boxes_labels = [[1] * len(boxes)]
-
-            inputs = processor(
-                images=image_pil,
-                input_boxes=input_boxes,
-                input_boxes_labels=input_boxes_labels,
-                return_tensors="pt"
-            ).to(device)
-            display_label_prefix = "Multi-Box"
 
-        elif task_type == "Text + Negative Box":
-            if not text_prompt or not box_input:
-                raise gr.Error("Please provide both Text Prompt and Box Coordinates.")
-
-            boxes = parse_boxes(box_input)
-
-            input_boxes = [boxes]
-            # Labels 0 (negative/exclude)
-            input_boxes_labels = [[0] * len(boxes)]
-
-            inputs = processor(
-                images=image_pil,
-                text=text_prompt,
-                input_boxes=input_boxes,
-                input_boxes_labels=input_boxes_labels,
-                return_tensors="pt"
-            ).to(device)
-            display_label_prefix = f"{text_prompt} (Excl. Box)"
-
-    except ValueError as e:
-        raise gr.Error(str(e))
+    # Preprocess
+    inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
 
     # Inference
     with torch.no_grad():
         outputs = model(**inputs)
 
-    # Post-processing
+    # Post-process results
     results = processor.post_process_instance_segmentation(
         outputs,
         threshold=threshold,
@@ -215,120 +135,79 @@ def process_sam3(input_image, task_type, text_prompt, box_input, threshold=0.5):
         target_sizes=inputs.get("original_sizes").tolist()
     )[0]
 
-    masks = results['masks']
+    masks = results['masks']  # Boolean tensor [N, H, W]
     scores = results['scores']
 
-    # Prepare AnnotatedImage Output
+    # Prepare for Gradio AnnotatedImage
+    # Gradio expects (image, [(mask, label), ...])
+
     annotations = []
     masks_np = masks.cpu().numpy()
     scores_np = scores.cpu().numpy()
 
     for i, mask in enumerate(masks_np):
+        # mask is a boolean array (True/False).
+        # AnnotatedImage handles the coloring automatically.
+        # We just pass the mask and a label.
         score_val = scores_np[i]
-        label = f"{display_label_prefix} ({score_val:.2f})"
+        label = f"{text_prompt} ({score_val:.2f})"
         annotations.append((mask, label))
 
     return (image_pil, annotations)
 
-# --- UI Logic ---
 css="""
 #col-container {
     margin: 0 auto;
-    max-width: 1100px;
-}
-#main-title h1 {
-    font-size: 2.1em !important;
-    display: flex;
-    align-items: center;
-    justify-content: center;
-    gap: 10px;
+    max-width: 980px;
 }
+#main-title h1 {font-size: 2.1em !important;}
 """
 
 with gr.Blocks(css=css, theme=plum_theme) as demo:
     with gr.Column(elem_id="col-container"):
-        # Header with Logo
         gr.Markdown(
-            "# **SAM3 Image Segmentation** <img src='https://huggingface.co/spaces/prithivMLmods/Qwen-Image-Edit-2509-LoRAs-Fast-Fusion/resolve/main/Lora%20Huggy.png' alt='Logo' width='35' height='35' style='display: inline-block; vertical-align: text-bottom; margin-left: 5px;'>",
+            "# **SAM3 Image Segmentation**",
             elem_id="main-title"
         )
 
-        gr.Markdown("Perform advanced segmentation using **SAM3** with Text, Boxes, or Combined prompts.")
+        gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")
 
         with gr.Row():
             # Left Column: Inputs
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image", type="pil", height=350)
-
-                task_type = gr.Dropdown(
-                    label="Task Type",
-                    choices=[
-                        "Text Prompt",
-                        "Single Bounding Box",
-                        "Multiple Boxes (Positive)",
-                        "Text + Negative Box"
-                    ],
-                    value="Text Prompt",
-                    interactive=True
-                )
-
-                # Conditional Inputs
-                text_prompt_input = gr.Textbox(
+                text_prompt = gr.Textbox(
                     label="Text Prompt",
-                    placeholder="e.g., cat, ear, car wheel",
-                    visible=True
-                )
-
-                box_input = gr.Textbox(
-                    label="Box Coordinates (x1, y1, x2, y2)",
-                    placeholder="e.g., 100, 150, 500, 450",
-                    info="For multiple boxes, separate with semicolon ';'. E.g., 10,10,50,50; 60,60,100,100",
-                    visible=False
+                    placeholder="e.g., cat, ear, car wheel...",
+                    info="What do you want to segment?"
                 )
-
                 threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
 
-                run_button = gr.Button("Segment Image", variant="primary")
+                run_button = gr.Button("Segment", variant="primary")
 
             # Right Column: Output
             with gr.Column(scale=1.5):
+                # AnnotatedImage creates a nice overlay visualization
                 output_image = gr.AnnotatedImage(label="Segmented Output", height=500)
 
-        # Logic to toggle visibility of inputs based on dropdown
-        def update_inputs(task):
-            if task == "Text Prompt":
-                return gr.update(visible=True), gr.update(visible=False)
-            elif task == "Single Bounding Box":
-                return gr.update(visible=False), gr.update(visible=True, label="Single Box (x1, y1, x2, y2)")
-            elif task == "Multiple Boxes (Positive)":
-                return gr.update(visible=False), gr.update(visible=True, label="Multiple Boxes (x1,y1,x2,y2; x1,y1,x2,y2)")
-            elif task == "Text + Negative Box":
-                return gr.update(visible=True), gr.update(visible=True, label="Negative Box to Exclude (x1, y1, x2, y2)")
-            return gr.update(visible=True), gr.update(visible=True)
-
-        task_type.change(
-            fn=update_inputs,
-            inputs=[task_type],
-            outputs=[text_prompt_input, box_input]
-        )
-
         # Examples
         gr.Examples(
             examples=[
-                ["examples/cat.jpg", "Text Prompt", "cat", "", 0.5],
-                ["examples/car.jpg", "Single Bounding Box", "", "100, 200, 400, 500", 0.5],
-                ["examples/fruit.jpg", "Text + Negative Box", "apple", "50, 50, 100, 100", 0.4],
+                ["examples/cat.jpg", "cat", 0.5],
+                ["examples/car.jpg", "tire", 0.4],
+                ["examples/fruit.jpg", "apple", 0.5],
             ],
-            inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
+            inputs=[input_image, text_prompt, threshold],
             outputs=[output_image],
-            fn=process_sam3,
+            fn=segment_image,
             cache_examples=False,
-            label="Examples (Ensure files exist and coordinates match images)"
+            label="Examples"
         )
 
         run_button.click(
-            fn=process_sam3,
-            inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
+            fn=segment_image,
+            inputs=[input_image, text_prompt, threshold],
             outputs=[output_image]
         )
 
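For quick reference outside the Space, here is a minimal standalone sketch of the text-prompt path that the new `segment_image` follows, assembled from the code in this diff. It is untested here, and the `Sam3Processor` class name is an assumption: the diff shows `Sam3Model.from_pretrained("facebook/sam3")` but not the line that builds `processor`.

```python
# Sketch of the text-prompt segmentation flow used by the updated app.py.
# Assumption: the processor is a Sam3Processor counterpart of Sam3Model,
# loaded from the same "facebook/sam3" checkpoint (that line is not shown in the diff).
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

image = Image.open("examples/cat.jpg").convert("RGB")
inputs = processor(images=image, text="cat", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Same post-processing call as segment_image(): one result dict per image.
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.4,
    target_sizes=inputs.get("original_sizes").tolist(),
)[0]

for mask, score in zip(results["masks"], results["scores"]):
    print(mask.shape, float(score))  # boolean mask [H, W] and its confidence
```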
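For readers who still want the box-prompt behavior this commit removes, below is a hedged sketch of how the old `process_sam3` built box inputs, reconstructed from the deleted lines above; the `input_boxes` / `input_boxes_labels` keyword names and the 1 = positive / 0 = negative label convention come from that removed code and are not verified against the library elsewhere. It reuses `processor`, `image`, and `device` from the previous sketch.

```python
# Reconstruction of the removed box-prompt input handling (not part of the new app.py).
def parse_boxes(box_str):
    """Parse "x1,y1,x2,y2" or "x1,y1,x2,y2; x3,y3,x4,y4" into a list of 4-float lists."""
    boxes = []
    for seg in box_str.split(';'):
        if not seg.strip():
            continue
        coords = [float(c.strip()) for c in seg.split(',')]
        if len(coords) != 4:
            raise ValueError(f"Expected 4 coordinates per box, got {len(coords)}")
        boxes.append(coords)
    return boxes

boxes = parse_boxes("100, 200, 400, 500")
inputs = processor(
    images=image,
    input_boxes=[boxes],                    # [batch, num_boxes, 4]
    input_boxes_labels=[[1] * len(boxes)],  # 1 = positive (include); 0 excluded a region in the old app
    return_tensors="pt",
).to(device)
```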