# SAM3-Demo / app.py
import gradio as gr
import numpy as np
import torch
from PIL import Image
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from transformers import Sam3Processor, Sam3Model
# --- Handle optional 'spaces' import for local compatibility ---
try:
import spaces
except ImportError:
class spaces:
@staticmethod
def GPU(duration=60):
def decorator(func):
return func
return decorator
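# The no-op decorator above lets the @spaces.GPU-decorated function below run unchanged
# outside Hugging Face Spaces (e.g. locally, where the 'spaces' package is unavailable).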
# --- Custom Theme Setup (Plum) ---
colors.plum = colors.Color(
name="plum",
c50="#FDF4FD",
c100="#F7E6F7",
c200="#ECD0EC",
c300="#DDA0DD", # Plum
c400="#C98BC9",
c500="#B060B0",
c600="#964B96",
c700="#7A3A7A",
c800="#602C60",
c900="#451E45",
c950="#2B122B",
)
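# Attaching the custom Color to gradio's colors module lets it be passed as primary_hue
# below; the *primary_N tokens used in set() then resolve to these hex values.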
class PlumTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.plum,
secondary_hue: colors.Color | str = colors.plum,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
self.set(
background_fill_primary="*primary_50",
background_fill_primary_dark="*primary_900",
body_background_fill="linear-gradient(135deg, *primary_100, *primary_50)",
body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
button_primary_text_color="white",
button_primary_text_color_hover="white",
button_primary_background_fill="linear-gradient(90deg, *primary_500, *primary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *primary_600, *primary_700)",
button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *primary_800)",
button_primary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
button_secondary_text_color="black",
button_secondary_text_color_hover="white",
button_secondary_background_fill="linear-gradient(90deg, *primary_200, *primary_200)",
button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
slider_color="*primary_500",
slider_color_dark="*primary_600",
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
button_primary_shadow="*shadow_drop_lg",
button_large_padding="11px",
color_accent_soft="*primary_100",
block_label_background_fill="*primary_200",
)
plum_theme = PlumTheme()
# --- Hardware Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# --- Model Loading ---
try:
print("Loading SAM3 Model and Processor...")
model = Sam3Model.from_pretrained("facebook/sam3").to(device)
processor = Sam3Processor.from_pretrained("facebook/sam3")
print("Model loaded successfully.")
except Exception as e:
print(f"Error loading model: {e}")
print("Ensure you have the correct libraries installed (transformers>=4.40.0) and access to the model.")
model = None
processor = None
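# If loading fails, model/processor stay None and process_sam3 raises a gr.Error at
# request time instead of crashing the app at import.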
# --- Helper Functions ---
def parse_boxes(box_str):
"""
Parses a string of coordinates into a list of lists.
Format expected: "x1,y1,x2,y2" or "x1,y1,x2,y2; x3,y3,x4,y4"
"""
try:
boxes = []
# Split by semicolon for multiple boxes
segments = box_str.split(';')
for seg in segments:
if not seg.strip():
continue
coords = [float(c.strip()) for c in seg.split(',')]
if len(coords) != 4:
raise ValueError(f"Expected 4 coordinates per box, got {len(coords)}")
boxes.append(coords)
return boxes
except Exception as e:
raise ValueError(f"Invalid box format: {e}")
@spaces.GPU(duration=60)
def process_sam3(input_image, task_type, text_prompt, box_input, threshold=0.5):
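    """
    Run SAM3 on the uploaded image using the prompt type selected in the UI
    (text, positive box(es), or text plus negative boxes).

    Returns an (image, annotations) tuple, where annotations is a list of
    (mask, label) pairs in the format expected by gr.AnnotatedImage.
    """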
if input_image is None:
raise gr.Error("Please upload an image.")
if model is None or processor is None:
raise gr.Error("Model not loaded correctly.")
image_pil = input_image.convert("RGB")
inputs = {}
# Logic branching based on Task Type
try:
if task_type == "Text Prompt":
if not text_prompt:
raise gr.Error("Please enter a text prompt.")
inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
display_label_prefix = text_prompt
elif task_type == "Single Bounding Box":
if not box_input:
raise gr.Error("Please enter box coordinates.")
boxes = parse_boxes(box_input)
if len(boxes) != 1:
raise gr.Error("Please provide exactly one box for this mode.")
input_boxes = [boxes] # [batch_size, num_boxes, 4]
input_boxes_labels = [[1]] # 1 = positive
inputs = processor(
images=image_pil,
input_boxes=input_boxes,
input_boxes_labels=input_boxes_labels,
return_tensors="pt"
).to(device)
display_label_prefix = "Box"
elif task_type == "Multiple Boxes (Positive)":
if not box_input:
raise gr.Error("Please enter box coordinates.")
boxes = parse_boxes(box_input) # Returns list of [x1,y1,x2,y2]
input_boxes = [boxes] # [batch, num_boxes, 4]
# All labels 1 (positive)
input_boxes_labels = [[1] * len(boxes)]
inputs = processor(
images=image_pil,
input_boxes=input_boxes,
input_boxes_labels=input_boxes_labels,
return_tensors="pt"
).to(device)
display_label_prefix = "Multi-Box"
elif task_type == "Text + Negative Box":
if not text_prompt or not box_input:
raise gr.Error("Please provide both Text Prompt and Box Coordinates.")
boxes = parse_boxes(box_input)
input_boxes = [boxes]
# Labels 0 (negative/exclude)
input_boxes_labels = [[0] * len(boxes)]
inputs = processor(
images=image_pil,
text=text_prompt,
input_boxes=input_boxes,
input_boxes_labels=input_boxes_labels,
return_tensors="pt"
).to(device)
display_label_prefix = f"{text_prompt} (Excl. Box)"
except ValueError as e:
raise gr.Error(str(e))
# Inference
with torch.no_grad():
outputs = model(**inputs)
# Post-processing
results = processor.post_process_instance_segmentation(
outputs,
threshold=threshold,
mask_threshold=0.5,
target_sizes=inputs.get("original_sizes").tolist()
)[0]
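    # The post-processed result holds per-instance 'masks' and 'scores'; target_sizes
    # (the processor's recorded original_sizes) maps masks back to the input resolution.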
masks = results['masks']
scores = results['scores']
# Prepare AnnotatedImage Output
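    # gr.AnnotatedImage expects (base_image, [(mask_or_bbox, label), ...]); a numpy mask
    # with the same height/width as the image is drawn as a colored overlay.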
annotations = []
masks_np = masks.cpu().numpy()
scores_np = scores.cpu().numpy()
for i, mask in enumerate(masks_np):
score_val = scores_np[i]
label = f"{display_label_prefix} ({score_val:.2f})"
annotations.append((mask, label))
return (image_pil, annotations)
# --- UI Logic ---
css="""
#col-container {
margin: 0 auto;
max-width: 1100px;
}
#main-title h1 {
font-size: 2.1em !important;
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
}
"""
with gr.Blocks(css=css, theme=plum_theme) as demo:
with gr.Column(elem_id="col-container"):
# Header with Logo
gr.Markdown(
"# **SAM3 Image Segmentation** <img src='https://huggingface.co/spaces/prithivMLmods/Qwen-Image-Edit-2509-LoRAs-Fast-Fusion/resolve/main/Lora%20Huggy.png' alt='Logo' width='35' height='35' style='display: inline-block; vertical-align: text-bottom; margin-left: 5px;'>",
elem_id="main-title"
)
gr.Markdown("Perform advanced segmentation using **SAM3** with Text, Boxes, or Combined prompts.")
with gr.Row():
# Left Column: Inputs
with gr.Column(scale=1):
input_image = gr.Image(label="Input Image", type="pil", height=350)
task_type = gr.Dropdown(
label="Task Type",
choices=[
"Text Prompt",
"Single Bounding Box",
"Multiple Boxes (Positive)",
"Text + Negative Box"
],
value="Text Prompt",
interactive=True
)
# Conditional Inputs
text_prompt_input = gr.Textbox(
label="Text Prompt",
placeholder="e.g., cat, ear, car wheel",
visible=True
)
box_input = gr.Textbox(
label="Box Coordinates (x1, y1, x2, y2)",
placeholder="e.g., 100, 150, 500, 450",
info="For multiple boxes, separate with semicolon ';'. E.g., 10,10,50,50; 60,60,100,100",
visible=False
)
threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
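                # Detections scoring below this threshold are discarded during post-processing.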
run_button = gr.Button("Segment Image", variant="primary")
# Right Column: Output
            with gr.Column(scale=2):  # scale must be an integer in Gradio layouts
output_image = gr.AnnotatedImage(label="Segmented Output", height=500)
# Logic to toggle visibility of inputs based on dropdown
def update_inputs(task):
if task == "Text Prompt":
return gr.update(visible=True), gr.update(visible=False)
elif task == "Single Bounding Box":
return gr.update(visible=False), gr.update(visible=True, label="Single Box (x1, y1, x2, y2)")
elif task == "Multiple Boxes (Positive)":
return gr.update(visible=False), gr.update(visible=True, label="Multiple Boxes (x1,y1,x2,y2; x1,y1,x2,y2)")
elif task == "Text + Negative Box":
return gr.update(visible=True), gr.update(visible=True, label="Negative Box to Exclude (x1, y1, x2, y2)")
return gr.update(visible=True), gr.update(visible=True)
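            # Fallback: show both inputs if an unexpected task value slips through.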
task_type.change(
fn=update_inputs,
inputs=[task_type],
outputs=[text_prompt_input, box_input]
)
# Examples
gr.Examples(
examples=[
["examples/cat.jpg", "Text Prompt", "cat", "", 0.5],
["examples/car.jpg", "Single Bounding Box", "", "100, 200, 400, 500", 0.5],
["examples/fruit.jpg", "Text + Negative Box", "apple", "50, 50, 100, 100", 0.4],
],
inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
outputs=[output_image],
fn=process_sam3,
cache_examples=False,
label="Examples (Ensure files exist and coordinates match images)"
)
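        # With cache_examples=False the examples run inference on demand rather than being
        # pre-computed at startup, avoiding GPU time while the Space builds.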
run_button.click(
fn=process_sam3,
inputs=[input_image, task_type, text_prompt_input, box_input, threshold],
outputs=[output_image]
)
if __name__ == "__main__":
demo.launch(ssr_mode=False, show_error=True)