import argparse
import warnings

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

warnings.filterwarnings("ignore")


MODEL_IDS = {
    "MM Grounding DINO Large": "rziga/mm_grounding_dino_large_all",
    "MM Grounding DINO Base": "rziga/mm_grounding_dino_base_all",
}
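
# Hugging Face Hub ids for transformers-format MM Grounding DINO checkpoints;
# both are loaded through the AutoProcessor / zero-shot detection auto classes below.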

device = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily loaded model state; only the most recently selected checkpoint is kept.
loaded_model_name = None
processor = None
model = None


@spaces.GPU
def run_grounding(input_image, grounding_caption, model_choice, box_threshold, text_threshold):
    global loaded_model_name, processor, model
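
    # Pipeline: lazily load the selected checkpoint, run zero-shot detection over the
    # prompt phrases, then annotate the image and build a text report of the boxes.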

    # Load the requested checkpoint on first use or when the selection changes.
    if loaded_model_name != model_choice:
        model_id = MODEL_IDS[model_choice]
        processor = AutoProcessor.from_pretrained(model_id)
        model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        loaded_model_name = model_choice

    # Gradio delivers numpy images in RGB channel order, so wrap them directly;
    # a BGR-to-RGB swap here would scramble the channels.
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)

    init_image = input_image.convert("RGB")

    # Split the dot-separated prompt into the nested list-of-phrases format the
    # zero-shot detection processor expects (one inner list per image).
    text_labels = [[label.strip() for label in grounding_caption.split('.') if label.strip()]]
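    # e.g. "a person. a car." -> [["a person", "a car"]]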

    inputs = processor(images=init_image, text=text_labels, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes expects (height, width); PIL's .size is (width, height).
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=box_threshold,
        target_sizes=[(init_image.size[1], init_image.size[0])],
    )
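
    # Each entry maps "scores" and "boxes" (xyxy, pixel coordinates) plus the matched
    # phrases under "labels"; newer transformers releases also expose "text_labels".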
    result = results[0]

    image_np = np.array(init_image)

    boxes = []
    labels = []
    confidences = []
    class_ids = []

    for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
        boxes.append(box.tolist())
        labels.append(label)
        confidences.append(float(score))
        class_ids.append(i)
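
    # Each detection gets its own class id so supervision's default palette draws
    # every box in a distinct color.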

    if boxes:
        lines = []
        for label, xyxy, conf in zip(labels, boxes, confidences):
            x1, y1, x2, y2 = [int(round(v)) for v in xyxy]
            lines.append(f"{label} {conf:.3f} {x1}, {y1}, {x2}, {y2}")
        detection_text = "\n".join(lines)
    else:
        detection_text = "No detections."
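
    # A report line looks like "a cat 0.873 12, 34, 256, 312" (values illustrative).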

    if boxes:
        detections = sv.Detections(
            xyxy=np.array(boxes),
            confidence=np.array(confidences),
            class_id=np.array(class_ids, dtype=np.int32),
        )
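
        # supervision can size text and line thickness from the image resolution,
        # keeping annotations legible across input sizes.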
        text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
        line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)

        box_annotator = sv.BoxAnnotator(
            thickness=line_thickness,
            color=sv.ColorPalette.DEFAULT,
        )

        label_annotator = sv.LabelAnnotator(
            color=sv.ColorPalette.DEFAULT,
            text_color=sv.Color.WHITE,
            text_scale=text_scale,
            text_thickness=line_thickness,
            text_padding=3,
        )

        formatted_labels = [
            f"{label}: {conf:.2f}"
            for label, conf in zip(labels, confidences)
        ]

        annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image,
            detections=detections,
            labels=formatted_labels,
        )
    else:
        annotated_image = image_np

    image_with_box = Image.fromarray(annotated_image)

    return image_with_box, detection_text


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="enable Gradio debug mode")
    parser.add_argument("--share", action="store_true", help="create a public share link")
    args = parser.parse_args()

    css = """
    #mkd {
        height: 500px;
        overflow: auto;
        border: 1px solid #ccc;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("<h1><center>MM Grounding DINO (Large & Base)</center></h1>")
        gr.Markdown("<h3><center>Open-World Detection with MM Grounding DINO Models</center></h3>")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                model_choice = gr.Radio(
                    choices=list(MODEL_IDS.keys()),
                    value="MM Grounding DINO Large",
                    label="Select Model",
                    info="Choose between Large (more accurate) and Base (faster)",
                )
                grounding_caption = gr.Textbox(
                    label="Detection Prompt (lowercase phrases, each ending with a dot)",
                    value="a person. a car.",
                )
                run_button = gr.Button("Run")

                with gr.Accordion("Advanced options", open=False):
                    box_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.3, step=0.001,
                        label="Box Threshold",
                    )
                    # Hidden: kept so the function signature and example rows stay aligned;
                    # the post-processing above uses only the box threshold.
                    text_threshold = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
                        label="Text Threshold (not used by MM Grounding DINO)",
                        visible=False,
                    )

            with gr.Column():
                gallery = gr.Image(
                    label="Detection Result",
                    type="pil",
                )
                det_text = gr.Textbox(
                    label="Detections (label confidence x1, y1, x2, y2)",
                    lines=12,
                    interactive=False,
                    show_copy_button=True,
                )

        run_button.click(
            fn=run_grounding,
            inputs=[input_image, grounding_caption, model_choice, box_threshold, text_threshold],
            outputs=[gallery, det_text],
        )
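
        # cache_examples=True runs every example once at startup and stores the results,
        # so the image files below must exist next to this script.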
        gr.Examples(
            examples=[
                ["000000039769.jpg", "a cat. a remote control.", "MM Grounding DINO Large", 0.3, 0.25],
                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", "MM Grounding DINO Base", 0.3, 0.25],
            ],
            inputs=[input_image, grounding_caption, model_choice, box_threshold, text_threshold],
            outputs=[gallery, det_text],
            fn=run_grounding,
            cache_examples=True,
        )

    demo.launch(share=args.share, debug=args.debug, show_error=True)
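
# Typical invocations (assuming this file is saved as app.py):
#   python app.py            # run locally
#   python app.py --share    # also create a public Gradio link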
|
|
|