import json
import time
import cv2
import tempfile
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
from transformers import AutoModelForCausalLM
import supervision as sv
import spaces
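# Load the Moondream 3 preview model; trust_remote_code pulls in its custom
# detect / point / caption / query skills used throughout this app.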
model_id = "moondream/moondream3-preview"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)
model.compile()
def create_annotated_image(image, detection_result, object_name="Object"):
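    """Draw labelled bounding boxes from a Moondream detect() result onto a PIL image."""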
    if not isinstance(detection_result, dict) or "objects" not in detection_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    bboxes = []
    labels = []
    for i, obj in enumerate(detection_result["objects"]):
        # Moondream returns normalized [0, 1] coordinates; scale to pixels.
        x_min = int(obj["x_min"] * original_width)
        y_min = int(obj["y_min"] * original_height)
        x_max = int(obj["x_max"] * original_width)
        y_max = int(obj["y_max"] * original_height)

        # Clamp to the image bounds and keep only non-degenerate boxes.
        x_min = max(0, min(x_min, original_width))
        y_min = max(0, min(y_min, original_height))
        x_max = max(0, min(x_max, original_width))
        y_max = max(0, min(y_max, original_height))

        if x_max > x_min and y_max > y_min:
            bboxes.append([x_min, y_min, x_max, y_max])
            labels.append(f"{object_name} {i+1}")
            print(f"Box {i+1}: ({x_min}, {y_min}, {x_max}, {y_max})")
    if not bboxes:
        # Nothing valid to draw.
        return image

    detections = sv.Detections(
        xyxy=np.array(bboxes, dtype=np.float32),
        class_id=np.arange(len(bboxes))
    )

    bounding_box_annotator = sv.BoxAnnotator(
        thickness=3,
        color_lookup=sv.ColorLookup.INDEX
    )
    label_annotator = sv.LabelAnnotator(
        text_thickness=2,
        text_scale=0.6,
        color_lookup=sv.ColorLookup.INDEX
    )

    annotated_image = bounding_box_annotator.annotate(
        scene=annotated_image, detections=detections
    )
    annotated_image = label_annotator.annotate(
        scene=annotated_image, detections=detections, labels=labels
    )

    return Image.fromarray(annotated_image)
def process_video_with_tracking(video_path, prompt, detection_interval=3):
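    """Track objects matching `prompt` through a video.

    Runs Moondream detection every `detection_interval` frames, associates the
    detections across frames with supervision's ByteTrack, and writes an
    annotated MP4. Returns the output path and a text summary.
    """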
    cap = cv2.VideoCapture(video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    byte_tracker = sv.ByteTrack()

    temp_dir = tempfile.mkdtemp()
    output_path = os.path.join(temp_dir, "tracked_video.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    detection_count = 0
    last_detections = None
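    # Decode frames sequentially; detection runs only every `detection_interval`
    # frames and ByteTrack carries identities in between, which keeps GPU time
    # per video manageable on ZeroGPU.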
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            run_detection = (frame_count % detection_interval == 0)

            if run_detection:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                result = model.detect(pil_image, prompt)
                detection_count += 1

                if "objects" in result and result["objects"]:
                    bboxes = []
                    confidences = []
                    for obj in result["objects"]:
                        # Normalized coordinates -> pixel coordinates, clamped to the frame.
                        x_min = max(0.0, min(1.0, obj["x_min"])) * width
                        y_min = max(0.0, min(1.0, obj["y_min"])) * height
                        x_max = max(0.0, min(1.0, obj["x_max"])) * width
                        y_max = max(0.0, min(1.0, obj["y_max"])) * height
                        if x_max > x_min and y_max > y_min:
                            bboxes.append([x_min, y_min, x_max, y_max])
                            # The detect output carries no score, so use a fixed confidence for ByteTrack.
                            confidences.append(0.8)

                    if bboxes:
                        detections = sv.Detections(
                            xyxy=np.array(bboxes, dtype=np.float32),
                            confidence=np.array(confidences, dtype=np.float32),
                            class_id=np.zeros(len(bboxes), dtype=int)
                        )
                        detections = byte_tracker.update_with_detections(detections)
                        last_detections = detections
                    else:
                        empty_detections = sv.Detections.empty()
                        detections = byte_tracker.update_with_detections(empty_detections)
                        last_detections = detections
                else:
                    empty_detections = sv.Detections.empty()
                    detections = byte_tracker.update_with_detections(empty_detections)
                    last_detections = detections
            else:
                empty_detections = sv.Detections.empty()
                detections = byte_tracker.update_with_detections(empty_detections)
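            # Draw the tracked boxes and IDs on the current frame.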
            if detections is not None and len(detections) > 0:
                box_annotator = sv.BoxAnnotator(
                    thickness=3,
                    color_lookup=sv.ColorLookup.TRACK
                )
                label_annotator = sv.LabelAnnotator(
                    text_scale=0.6,
                    text_thickness=2,
                    color_lookup=sv.ColorLookup.TRACK
                )
                labels = []
                for tracker_id in detections.tracker_id:
                    if tracker_id is not None:
                        labels.append(f"{prompt} ID: {tracker_id}")
                    else:
                        labels.append(f"{prompt} Unknown")
                frame = box_annotator.annotate(scene=frame, detections=detections)
                frame = label_annotator.annotate(scene=frame, detections=detections, labels=labels)

            out.write(frame)
            frame_count += 1

            if frame_count % 30 == 0:
                progress = (frame_count / total_frames) * 100
                print(f"Processing: {progress:.1f}% ({frame_count}/{total_frames}) - Detections: {detection_count}")
    finally:
        cap.release()
        out.release()

    summary = f"""Video processing complete:
- Total frames processed: {frame_count}
- Detection runs: {detection_count} (every {detection_interval} frames)
- Objects tracked: {prompt}
- Detection rate: ~{detection_count / max(frame_count, 1) * 100:.1f}% of frames (throttled for speed)"""

    return output_path, summary
def create_point_annotated_image(image, point_result):
    """Create an annotated image with points marking detected objects."""
    if not isinstance(point_result, dict) or "points" not in point_result:
        return image

    original_width, original_height = image.size
    annotated_image = np.array(image.convert("RGB"))

    points = []
    for point in point_result["points"]:
        x = int(point["x"] * original_width)
        y = int(point["y"] * original_height)
        points.append([x, y])

    if points:
        points_array = np.array(points).reshape(1, -1, 2)
        key_points = sv.KeyPoints(xy=points_array)
        vertex_annotator = sv.VertexAnnotator(radius=8, color=sv.Color.RED)
        annotated_image = vertex_annotator.annotate(
            scene=annotated_image, key_points=key_points
        )

    return Image.fromarray(annotated_image)
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
def detect_objects(image, prompt, task_type, max_objects):
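    """Run the selected Moondream task on a single image.

    Dispatches to detect / point / caption / query and returns the annotated
    image, the formatted model output, and the measured inference time.
    """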
    if image is None:
        return None, "Please upload an image", ""

    STANDARD_SIZE = (1024, 1024)
    image.thumbnail(STANDARD_SIZE)

    t0 = time.perf_counter()

    if task_type == "Object Detection":
        settings = {"max_objects": int(max_objects)} if max_objects and max_objects > 0 else {}
        result = model.detect(image, prompt, settings=settings)
        annotated_image = create_annotated_image(image, result, prompt)
    elif task_type == "Point Detection":
        result = model.point(image, prompt)
        annotated_image = create_point_annotated_image(image, result)
    elif task_type == "Caption":
        result = model.caption(image, length="normal")
        annotated_image = image
    else:
        result = model.query(image=image, question=prompt, reasoning=True)
        annotated_image = image
    elapsed_ms = (time.perf_counter() - t0) * 1_000

    # Format the raw model output for the response textbox.
    if isinstance(result, dict):
        if "objects" in result:
            output_text = f"Found {len(result['objects'])} objects:\n"
            for i, obj in enumerate(result['objects'], 1):
                output_text += f"\n{i}. Bounding box: "
                output_text += f"({obj['x_min']:.3f}, {obj['y_min']:.3f}, {obj['x_max']:.3f}, {obj['y_max']:.3f})"
        elif "points" in result:
            output_text = f"Found {len(result['points'])} points:\n"
            for i, point in enumerate(result['points'], 1):
                output_text += f"\n{i}. Point: ({point['x']:.3f}, {point['y']:.3f})"
        elif "caption" in result:
            output_text = result['caption']
        elif "answer" in result:
            if "reasoning" in result:
                output_text = f"Reasoning: {result['reasoning']}\n\nAnswer: {result['answer']}"
            else:
                output_text = result['answer']
        else:
            output_text = json.dumps(result, indent=2)
    else:
        output_text = str(result)

    timing_text = f"Inference time: {elapsed_ms:.0f} ms"
    return annotated_image, output_text, timing_text
@spaces.GPU(duration=120)  # ZeroGPU: 120 s is an assumed budget for longer video runs
def process_video(video_file, prompt, detection_interval):
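    """Gradio callback for the video tab: validate the upload, then run tracking."""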
    if video_file is None:
        return None, "Please upload a video file"

    output_path, summary = process_video_with_tracking(
        video_file, prompt, detection_interval
    )
    return output_path, summary
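# Build the Gradio UI: one tab for single-image tasks, one for video object tracking.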
with gr.Blocks(theme=gr.themes.Soft()) as demo:
| gr.Markdown("# Moondream3 🌝") | |
| gr.Markdown(""" | |
| *Try [Moondream3 Preview](https://huggingface.co/moondream/moondream3-preview) for following tasks:* | |
| - **Object Detection** | |
| - **Point Detection** | |
| - **Captioning** | |
| - **Visual Question Answering** | |
| - **Video Object Tracking** | |
| """) | |
    with gr.Tabs() as tabs:
        with gr.Tab("Image Processing"):
            with gr.Row():
                with gr.Column(scale=2):
                    image_input = gr.Image(label="Upload an image", type="pil", height=400)
                    task_type = gr.Radio(
                        choices=["Object Detection", "Point Detection", "Caption", "Visual Question Answering"],
                        label="Task Type",
                        value="Object Detection"
                    )
                    prompt_input = gr.Textbox(
                        label="Prompt (object to detect/question to ask)",
                        placeholder="e.g., 'car', 'person', 'What's in this image?'",
                        value="objects"
                    )
                    max_objects = gr.Number(
                        label="Max Objects (for Object Detection only)",
                        value=10,
                        minimum=1,
                        maximum=50,
                        step=1,
                        visible=True
                    )
                    generate_btn = gr.Button(value="Generate", variant="primary")

                with gr.Column(scale=2):
                    output_image = gr.Image(
                        type="pil",
                        label="Result",
                        height=400
                    )
                    output_textbox = gr.Textbox(
                        label="Model Response",
                        lines=10,
                        show_copy_button=True
                    )
                    output_time = gr.Markdown()

            gr.Markdown("### Examples")
            example_prompts = [
                [
                    "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
                    "Object Detection",
                    "candy",
                    5
                ],
                [
                    "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/candy.JPG",
                    "Point Detection",
                    "candy",
                    5
                ],
                [
                    "https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
                    "Caption",
                    "",
                    5
                ],
                [
                    "https://moondream.ai/images/blog/moondream-3-preview/benchmarks.jpg",
                    "Visual Question Answering",
                    "how well does moondream 3 perform in chartvqa?",
                    5
                ],
            ]
            gr.Examples(
                examples=example_prompts,
                inputs=[image_input, task_type, prompt_input, max_objects],
                label="Click an example to populate inputs"
            )
| with gr.Tab("Video Object Tracking"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| video_input = gr.Video( | |
| label="Upload a video file", | |
| height=400 | |
| ) | |
| video_prompt = gr.Textbox( | |
| label="Object to track", | |
| placeholder="e.g., 'person', 'car', 'ball'", | |
| value="person" | |
| ) | |
| detection_interval = gr.Slider( | |
| minimum=10, | |
| maximum=30, | |
| value=15, | |
| step=5, | |
| label="Detection Interval (frames)", | |
| info="Run detection every N frames (less is slower but more accurate, ZeroGPU might time out with long videos)" | |
| ) | |
| process_video_btn = gr.Button(value="Process Video", variant="primary") | |
| with gr.Column(scale=2): | |
| output_video = gr.Video( | |
| label="Tracked Video Result", | |
| height=400 | |
| ) | |
| video_summary = gr.Textbox( | |
| label="Processing Summary", | |
| lines=8, | |
| show_copy_button=True | |
| ) | |
| gr.Markdown("### Examples") | |
| example_prompts = [ | |
| [ | |
| "https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/IMG_8137.mp4", | |
| "snowboarder", | |
| 15 | |
| ], | |
| ] | |
| gr.Examples( | |
| examples=example_prompts, | |
| inputs=[video_input, video_prompt, detection_interval], | |
| label="Click an example to populate inputs" | |
| ) | |
    def update_max_objects_visibility(task):
        # Only the Object Detection task uses the max-objects setting.
        return gr.Number(visible=(task == "Object Detection"))

    task_type.change(
        fn=update_max_objects_visibility,
        inputs=[task_type],
        outputs=[max_objects]
    )

    generate_btn.click(
        fn=detect_objects,
        inputs=[image_input, prompt_input, task_type, max_objects],
        outputs=[output_image, output_textbox, output_time]
    )

    process_video_btn.click(
        fn=process_video,
        inputs=[video_input, video_prompt, detection_interval],
        outputs=[output_video, video_summary]
    )

if __name__ == "__main__":
    demo.launch()