# moondream3 / app.py
import json
import time
import cv2
import tempfile
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoModelForCausalLM
import supervision as sv
import spaces
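
# Load the Moondream3 preview checkpoint on the GPU. trust_remote_code is required
# because the checkpoint ships its own detect / point / caption / query methods.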
model_id = "moondream/moondream3-preview"
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map={"": "cuda"},
)
model.compile()
def create_annotated_image(image, detection_result, object_name="Object"):
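    """Draw the boxes returned by model.detect() on a copy of the image, one color per object."""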
if not isinstance(detection_result, dict) or "objects" not in detection_result:
return image
original_width, original_height = image.size
annotated_image = np.array(image.convert("RGB"))
bboxes = []
labels = []
for i, obj in enumerate(detection_result["objects"]):
x_min = int(obj["x_min"] * original_width)
y_min = int(obj["y_min"] * original_height)
x_max = int(obj["x_max"] * original_width)
y_max = int(obj["y_max"] * original_height)
x_min = max(0, min(x_min, original_width))
y_min = max(0, min(y_min, original_height))
x_max = max(0, min(x_max, original_width))
y_max = max(0, min(y_max, original_height))
if x_max > x_min and y_max > y_min:
bboxes.append([x_min, y_min, x_max, y_max])
labels.append(f"{object_name} {i+1}")
print(f"Box {i+1}: ({x_min}, {y_min}, {x_max}, {y_max})")
    # Nothing usable detected; return the original image unchanged
    if not bboxes:
        return image
    detections = sv.Detections(
        xyxy=np.array(bboxes, dtype=np.float32),
        class_id=np.arange(len(bboxes))
    )
bounding_box_annotator = sv.BoxAnnotator(
thickness=3,
color_lookup=sv.ColorLookup.INDEX
)
label_annotator = sv.LabelAnnotator(
text_thickness=2,
text_scale=0.6,
color_lookup=sv.ColorLookup.INDEX
)
annotated_image = bounding_box_annotator.annotate(
scene=annotated_image, detections=detections
)
annotated_image = label_annotator.annotate(
scene=annotated_image, detections=detections, labels=labels
)
return Image.fromarray(annotated_image)
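
# spaces.GPU() requests a ZeroGPU device for the duration of each decorated call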
@spaces.GPU()
def process_video_with_tracking(video_path, prompt, detection_interval=3):
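    """Run model.detect() every `detection_interval` frames and keep IDs consistent across frames with ByteTrack."""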
cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # fall back to 30 fps if the container reports none
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
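    # ByteTrack associates detections across frames and assigns persistent track IDs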
byte_tracker = sv.ByteTrack()
temp_dir = tempfile.mkdtemp()
output_path = os.path.join(temp_dir, "tracked_video.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
frame_count = 0
detection_count = 0
last_detections = None
try:
while True:
ret, frame = cap.read()
if not ret:
break
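            # Detection only runs every `detection_interval` frames; other frames just advance the tracker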
run_detection = (frame_count % detection_interval == 0)
if run_detection:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(frame_rgb)
result = model.detect(pil_image, prompt)
detection_count += 1
if "objects" in result and result["objects"]:
bboxes = []
confidences = []
for obj in result["objects"]:
x_min = max(0.0, min(1.0, obj["x_min"])) * width
y_min = max(0.0, min(1.0, obj["y_min"])) * height
x_max = max(0.0, min(1.0, obj["x_max"])) * width
y_max = max(0.0, min(1.0, obj["y_max"])) * height
if x_max > x_min and y_max > y_min:
bboxes.append([x_min, y_min, x_max, y_max])
                            confidences.append(0.8)  # detect() returns no scores here; use a fixed placeholder so ByteTrack accepts the boxes
if bboxes:
detections = sv.Detections(
xyxy=np.array(bboxes, dtype=np.float32),
confidence=np.array(confidences, dtype=np.float32),
class_id=np.zeros(len(bboxes), dtype=int)
)
detections = byte_tracker.update_with_detections(detections)
last_detections = detections
else:
empty_detections = sv.Detections.empty()
detections = byte_tracker.update_with_detections(empty_detections)
last_detections = detections
else:
empty_detections = sv.Detections.empty()
detections = byte_tracker.update_with_detections(empty_detections)
last_detections = detections
else:
empty_detections = sv.Detections.empty()
detections = byte_tracker.update_with_detections(empty_detections)
if detections is not None and len(detections) > 0:
box_annotator = sv.BoxAnnotator(
thickness=3,
color_lookup=sv.ColorLookup.TRACK
)
label_annotator = sv.LabelAnnotator(
text_scale=0.6,
text_thickness=2,
color_lookup=sv.ColorLookup.TRACK
)
labels = []
for tracker_id in detections.tracker_id:
if tracker_id is not None:
labels.append(f"{prompt} ID: {tracker_id}")
else:
labels.append(f"{prompt} Unknown")
frame = box_annotator.annotate(scene=frame, detections=detections)
frame = label_annotator.annotate(scene=frame, detections=detections, labels=labels)
out.write(frame)
frame_count += 1
if frame_count % 30 == 0:
                progress = (frame_count / max(total_frames, 1)) * 100
print(f"Processing: {progress:.1f}% ({frame_count}/{total_frames}) - Detections: {detection_count}")
finally:
cap.release()
out.release()
summary = f"""Video processing complete:
- Total frames processed: {frame_count}
- Detection runs: {detection_count} (every {detection_interval} frames)
- Objects tracked: {prompt}
- Processing speed: ~{detection_count/frame_count*100:.1f}% detection rate for optimization"""
return output_path, summary
def create_point_annotated_image(image, point_result):
"""Create annotated image with points for detected objects."""
if not isinstance(point_result, dict) or "points" not in point_result:
return image
original_width, original_height = image.size
annotated_image = np.array(image.convert("RGB"))
points = []
for point in point_result["points"]:
x = int(point["x"] * original_width)
y = int(point["y"] * original_height)
points.append([x, y])
if points:
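        # sv.KeyPoints expects an (instances, points, 2) array; group all points into one instance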
points_array = np.array(points).reshape(1, -1, 2)
key_points = sv.KeyPoints(xy=points_array)
vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
annotated_image = vertex_annotator.annotate(
scene=annotated_image, key_points=key_points
)
return Image.fromarray(annotated_image)
@spaces.GPU()
def detect_objects(image, prompt, task_type, max_objects):
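    """Run the selected task (detect / point / caption / query) on a single image and format the output."""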
    if image is None:
        return None, "Please upload an image", ""
    STANDARD_SIZE = (1024, 1024)
    image.thumbnail(STANDARD_SIZE)
t0 = time.perf_counter()
if task_type == "Object Detection":
settings = {"max_objects": max_objects} if max_objects > 0 else {}
result = model.detect(image, prompt, settings=settings)
annotated_image = create_annotated_image(image, result, prompt)
elif task_type == "Point Detection":
result = model.point(image, prompt)
annotated_image = create_point_annotated_image(image, result)
elif task_type == "Caption":
result = model.caption(image, length="normal")
annotated_image = image
else:
result = model.query(image=image, question=prompt, reasoning=True)
annotated_image = image
elapsed_ms = (time.perf_counter() - t0) * 1_000
if isinstance(result, dict):
if "objects" in result:
output_text = f"Found {len(result['objects'])} objects:\n"
for i, obj in enumerate(result['objects'], 1):
output_text += f"\n{i}. Bounding box: "
output_text += f"({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
elif "points" in result:
output_text = f"Found {len(result['points'])} points:\n"
for i, point in enumerate(result['points'], 1):
output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
elif "caption" in result:
output_text = result['caption']
elif "answer" in result:
if "reasoning" in result:
output_text = f"Reasoning: {result['reasoning']}\n\nAnswer: {result['answer']}"
else:
output_text = result['answer']
else:
output_text = json.dumps(result, indent=2)
else:
output_text = str(result)
timing_text = f"Inference time: {elapsed_ms:.0f} ms"
return annotated_image, output_text, timing_text
def process_video(video_file, prompt, detection_interval):
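    """Gradio wrapper around process_video_with_tracking with a friendly no-upload message."""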
if video_file is None:
return None, "Please upload a video file"
output_path, summary = process_video_with_tracking(
video_file, prompt, detection_interval
)
return output_path, summary
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Moondream3 🌝")
gr.Markdown("""
*Try [Moondream3 Preview](https://huggingface.co/moondream/moondream3-preview) for the following tasks:*
- **Object Detection**
- **Point Detection**
- **Captioning**
- **Visual Question Answering**
- **Video Object Tracking**
""")
with gr.Tabs() as tabs:
with gr.Tab("Image Processing"):
with gr.Row():
with gr.Column(scale=2):
image_input = gr.Image(label="Upload an image", type="pil", height=400)
task_type = gr.Radio(
choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
label="Task Type",
value="Object Detection"
)
prompt_input = gr.Textbox(
label="Prompt (object to detect/question to ask)",
placeholder="e.g., 'car', 'person', 'What's in this image?'",
value="objects"
)
max_objects = gr.Number(
label="Max Objects (for Object Detection only)",
value=10,
minimum=1,
maximum=50,
step=1,
visible=True
)
generate_btn = gr.Button(value="Generate", variant="primary")
with gr.Column(scale=2):
output_image = gr.Image(
type="pil",
label="Result",
height=400
)
output_textbox = gr.Textbox(
label="Model Response",
lines=10,
show_copy_button=True
)
output_time = gr.Markdown()
gr.Markdown("### Examples")
example_prompts = [
[
"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
"Object Detection",
"candy",
5
],
[
"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
"Point Detection",
"candy",
5
],
[
"https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
"Caption",
"",
5
],
[
"https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
"Visual Question Answering",
"how well does moondream 3 perform in chartvqa?",
5
],
]
gr.Examples(
examples=example_prompts,
inputs=[image_input, task_type, prompt_input, max_objects],
label="Click an example to populate inputs"
)
with gr.Tab("Video Object Tracking"):
with gr.Row():
with gr.Column(scale=2):
video_input = gr.Video(
label="Upload a video file",
height=400
)
video_prompt = gr.Textbox(
label="Object to track",
placeholder="e.g., 'person', 'car', 'ball'",
value="person"
)
detection_interval = gr.Slider(
minimum=10,
maximum=30,
value=15,
step=5,
label="Detection Interval (frames)",
info="Run detection every N frames (less is slower but more accurate, ZeroGPU might time out with long videos)"
)
process_video_btn = gr.Button(value="Process Video", variant="primary")
with gr.Column(scale=2):
output_video = gr.Video(
label="Tracked Video Result",
height=400
)
video_summary = gr.Textbox(
label="Processing Summary",
lines=8,
show_copy_button=True
)
gr.Markdown("### Examples")
            video_example_prompts = [
[
"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4",
"snowboarder",
15
],
]
gr.Examples(
                examples=video_example_prompts,
inputs=[video_input, video_prompt, detection_interval],
label="Click an example to populate inputs"
)
def update_max_objects_visibility(task):
return gr.Number(visible=(task == "Object Detection"))
task_type.change(
fn=update_max_objects_visibility,
inputs=[task_type],
outputs=[max_objects]
)
generate_btn.click(
fn=detect_objects,
inputs=[image_input, prompt_input, task_type, max_objects],
outputs=[output_image, output_textbox, output_time]
)
process_video_btn.click(
fn=process_video,
inputs=[video_input, video_prompt, detection_interval],
outputs=[output_video, video_summary]
)
if __name__ == "__main__":
demo.launch()