prithivMLmods committed
Commit 1fd4203 · verified · 1 Parent(s): 74593d4

update app

Files changed (1): app.py (+122, -56)
app.py CHANGED
@@ -2,8 +2,7 @@ import os
 import gradio as gr
 import numpy as np
 import torch
-import random
-from PIL import Image, ImageDraw
+from PIL import Image
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
@@ -94,7 +93,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
 # --- Model Loading ---
-# Using the facebook/sam3 model as requested
 try:
     print("Loading SAM3 Model and Processor...")
     model = Sam3Model.from_pretrained("facebook/sam3").to(device)
@@ -102,112 +100,180 @@ try:
     print("Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
-    print("Ensure you have the correct libraries installed and access to the model.")
-    # Fallback/Placeholder for demonstration if model doesn't exist in environment yet
     model = None
     processor = None
 
 @spaces.GPU(duration=60)
-def segment_image(input_image, text_prompt, threshold=0.5):
+def process_image(input_image, task_type, text_prompt, threshold=0.5):
     if input_image is None:
         raise gr.Error("Please upload an image.")
-    if not text_prompt:
-        raise gr.Error("Please enter a text prompt (e.g., 'cat', 'face').")
 
     if model is None or processor is None:
         raise gr.Error("Model not loaded correctly.")
 
     # Convert image to RGB
     image_pil = input_image.convert("RGB")
+    annotations = []
 
-    # Preprocess
-    inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
-
-    # Inference
     with torch.no_grad():
-        outputs = model(**inputs)
-
-    # Post-process results
-    results = processor.post_process_instance_segmentation(
-        outputs,
-        threshold=threshold,
-        mask_threshold=0.5,
-        target_sizes=inputs.get("original_sizes").tolist()
-    )[0]
-
-    masks = results['masks']  # Boolean tensor [N, H, W]
-    scores = results['scores']
-
-    # Prepare for Gradio AnnotatedImage
-    # Gradio expects (image, [(mask, label), ...])
-
-    annotations = []
-    masks_np = masks.cpu().numpy()
-    scores_np = scores.cpu().numpy()
-
-    for i, mask in enumerate(masks_np):
-        # mask is a boolean array (True/False).
-        # AnnotatedImage handles the coloring automatically.
-        # We just pass the mask and a label.
-        score_val = scores_np[i]
-        label = f"{text_prompt} ({score_val:.2f})"
-        annotations.append((mask, label))
-
-    # Return tuple format for AnnotatedImage
+        if task_type == "Instance Segmentation":
+            if not text_prompt:
+                raise gr.Error("Please enter a text prompt for Instance Segmentation.")
+
+            # 1. Instance Segmentation Flow (Text Prompt)
+            inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
+            outputs = model(**inputs)
+
+            # Post-process instance masks
+            results = processor.post_process_instance_segmentation(
+                outputs,
+                threshold=threshold,
+                mask_threshold=0.5,
+                target_sizes=inputs.get("original_sizes").tolist()
+            )[0]
+
+            masks_np = results['masks'].cpu().numpy()  # [N, H, W]
+            scores_np = results['scores'].cpu().numpy()
+
+            for i, mask in enumerate(masks_np):
+                score_val = scores_np[i]
+                label = f"{text_prompt} ({score_val:.2f})"
+                annotations.append((mask, label))
+
+        elif task_type == "Semantic Segmentation":
+            # 2. Semantic Segmentation Flow (No Prompt)
+            # Call the processor without text
+            inputs = processor(images=image_pil, return_tensors="pt").to(device)
+            outputs = model(**inputs)
+
+            # Extract the semantic segmentation map
+            # Shape: [batch, channels, height, width]
+            semantic_seg = outputs.semantic_seg
+
+            # Remove the batch dim -> [1, H, W] or [C, H, W]
+            seg_map = semantic_seg.squeeze(0)
+
+            if seg_map.shape[0] == 1:
+                # Single channel: treat the map as logits and threshold after a sigmoid
+                mask_tensor = torch.sigmoid(seg_map[0]) > threshold
+                mask_np = mask_tensor.cpu().numpy()
+
+                # outputs.semantic_seg may be at feature-map resolution; AnnotatedImage
+                # needs the mask to match the image size, so resize via PIL if necessary.
+                if mask_np.shape != (image_pil.height, image_pil.width):
+                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
+                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
+                    mask_np = np.array(mask_img) > 128
+
+                annotations.append((mask_np, "Semantic Region"))
+            else:
+                # Multiple channels (classes): take the argmax and visualize everything
+                # that is not background (assuming class 0 is background).
+                mask_idx = torch.argmax(seg_map, dim=0).cpu().numpy()
+                mask_np = mask_idx > 0
+
+                if mask_np.shape != (image_pil.height, image_pil.width):
+                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
+                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
+                    mask_np = np.array(mask_img) > 128
+
+                annotations.append((mask_np, "Segmented Objects"))
+
+    # Return tuple format for AnnotatedImage: (original_image, list_of_annotations)
     return (image_pil, annotations)
 
+# --- UI Logic ---
 css="""
 #col-container {
     margin: 0 auto;
-    max-width: 980px;
+    max-width: 1100px;
+}
+#main-title h1 {
+    font-size: 2.1em !important;
+    display: flex;
+    align-items: center;
+    justify-content: center;
+    gap: 10px;
 }
-#main-title h1 {font-size: 2.1em !important;}
 """
 
+def update_visibility(task):
+    if task == "Instance Segmentation":
+        return gr.update(visible=True)
+    else:
+        return gr.update(visible=False)
+
 with gr.Blocks(css=css, theme=plum_theme) as demo:
     with gr.Column(elem_id="col-container"):
+        # Header with Logo
         gr.Markdown(
-            "# **SAM3 Image Segmentation**",
+            "# **SAM3 Image Segmentation** <img src='https://huggingface.co/spaces/prithivMLmods/Qwen-Image-Edit-2509-LoRAs-Fast-Fusion/resolve/main/Lora%20Huggy.png' alt='Logo' width='35' height='35' style='display: inline-block; vertical-align: text-bottom; margin-left: 5px;'>",
             elem_id="main-title"
         )
 
-        gr.Markdown("Segment objects in images using **SAM3** (Segment Anything Model 3) with text prompts.")
+        gr.Markdown("Segment objects using **SAM3** (Segment Anything Model 3). Choose **Instance** for specific text prompts or **Semantic** for automatic segmentation.")
 
         with gr.Row():
             # Left Column: Inputs
             with gr.Column(scale=1):
                 input_image = gr.Image(label="Input Image", type="pil", height=350)
+
+                task_type = gr.Radio(
+                    choices=["Instance Segmentation", "Semantic Segmentation"],
+                    value="Instance Segmentation",
+                    label="Task Type",
+                    interactive=True
+                )
+
                 text_prompt = gr.Textbox(
                     label="Text Prompt",
                     placeholder="e.g., cat, ear, car wheel...",
-                    info="What do you want to segment?"
+                    info="Required for Instance Segmentation",
+                    visible=True
                 )
+
                 threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
 
-                run_button = gr.Button("Segment", variant="primary")
+                run_button = gr.Button("Run Segmentation", variant="primary")
 
             # Right Column: Output
             with gr.Column(scale=1.5):
-                # AnnotatedImage creates a nice overlay visualization
                 output_image = gr.AnnotatedImage(label="Segmented Output", height=500)
 
+        # Event: Hide text prompt when Semantic Segmentation is selected
+        task_type.change(
+            fn=update_visibility,
+            inputs=[task_type],
+            outputs=[text_prompt]
+        )
+
         # Examples
        gr.Examples(
             examples=[
-                ["examples/cat.jpg", "cat", 0.5],
-                ["examples/car.jpg", "tire", 0.4],
-                ["examples/fruit.jpg", "apple", 0.5],
+                ["examples/cat.jpg", "Instance Segmentation", "cat", 0.5],
+                ["examples/room.jpg", "Semantic Segmentation", "", 0.5],
+                ["examples/car.jpg", "Instance Segmentation", "tire", 0.4],
             ],
-            inputs=[input_image, text_prompt, threshold],
+            inputs=[input_image, task_type, text_prompt, threshold],
             outputs=[output_image],
-            fn=segment_image,
+            fn=process_image,
             cache_examples=False,
-            label="Examples (Ensure files exist in 'examples/' folder)"
         )
+            label="Examples"
+        )
 
         run_button.click(
-            fn=segment_image,
-            inputs=[input_image, text_prompt, threshold],
+            fn=process_image,
+            inputs=[input_image, task_type, text_prompt, threshold],
             outputs=[output_image]
         )
 
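For reference, below is a minimal sketch of how the text-prompted ("Instance Segmentation") flow added in this commit can be exercised outside the Gradio UI. It reuses only calls that appear in app.py above; the `from transformers import Sam3Model, Sam3Processor` line, the example file name, and the prompt are assumptions for illustration, since the import section of app.py is not part of this diff.

# Standalone sketch of the instance-segmentation flow used by process_image().
# Assumptions: Sam3Model/Sam3Processor are importable from transformers (app.py's
# own import lines are outside this diff), and "examples/cat.jpg" / "cat" are placeholders.
import torch
from PIL import Image
from transformers import Sam3Model, Sam3Processor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")

image = Image.open("examples/cat.jpg").convert("RGB")
inputs = processor(images=image, text="cat", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

# Same post-processing call the Space uses for "Instance Segmentation".
results = processor.post_process_instance_segmentation(
    outputs,
    threshold=0.4,
    mask_threshold=0.5,
    target_sizes=inputs.get("original_sizes").tolist(),
)[0]

for mask, score in zip(results["masks"], results["scores"]):
    print(f"mask of shape {tuple(mask.shape)} detected with confidence {score:.2f}")

Each returned mask is a boolean array at the original image resolution, which is the form gr.AnnotatedImage expects as the first element of every (mask, label) pair.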