Spaces:
Runtime error
Runtime error
| import spaces | |
| from transformers import Owlv2Processor, Owlv2ForObjectDetection, AutoProcessor, AutoModelForZeroShotObjectDetection | |
| import torch | |
| import gradio as gr | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub checkpoints used by the two detectors.
_OWL_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
_DINO_CHECKPOINT = "IDEA-Research/grounding-dino-base"

# OWLv2: model is moved to `device`; the processor stays on the host.
owl_model = Owlv2ForObjectDetection.from_pretrained(_OWL_CHECKPOINT).to(device)
owl_processor = Owlv2Processor.from_pretrained(_OWL_CHECKPOINT)

# Grounding DINO: same arrangement (processor on host, model on `device`).
dino_processor = AutoProcessor.from_pretrained(_DINO_CHECKPOINT)
dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
    _DINO_CHECKPOINT
).to(device)
def _detect_dino(img, queries, score_threshold):
    """Run Grounding DINO on ``img``; return ``(boxes, scores, labels)`` for the first image."""
    # Grounding DINO expects lowercase phrases, each terminated by a period.
    prompt = " ".join(f"{query.lower()}." for query in queries)
    # For an image array, shape[:2] is (height, width) -- the order target_sizes expects.
    height, width = img.shape[:2]
    inputs = dino_processor(text=prompt, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = dino_model(**inputs)
    # Post-process on the host; keep every tensor handed to it on the same device.
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # NOTE(review): newer transformers releases rename `box_threshold` to
    # `threshold`; confirm against the installed version.
    results = dino_processor.post_process_grounded_object_detection(
        outputs=outputs,
        input_ids=inputs.input_ids.cpu(),
        box_threshold=score_threshold,
        target_sizes=[(height, width)],
    )
    result = results[0]
    return result["boxes"], result["scores"], result["labels"]


def _detect_owl(img, queries, score_threshold):
    """Run OWLv2 on ``img``; return ``(boxes, scores, labels)`` with labels as query strings."""
    # OWLv2 preprocessing pads the image to a square (per the transformers OWLv2
    # docs), so boxes are rescaled with the longer side for both dimensions.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = owl_processor(text=queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = owl_model(**inputs)
    outputs.logits = outputs.logits.cpu()
    outputs.pred_boxes = outputs.pred_boxes.cpu()
    # Pass the caller's threshold explicitly; the processor default (0.1) would
    # otherwise discard detections whenever the slider is set below 0.1.
    results = owl_processor.post_process_object_detection(
        outputs=outputs, threshold=score_threshold, target_sizes=target_sizes
    )
    result = results[0]
    # OWL reports integer label indices into the query list.
    labels = [queries[int(index)] for index in result["labels"]]
    return result["boxes"], result["scores"], labels


# NOTE(review): `spaces` is imported at the top of the file but never used. If
# this Space runs on ZeroGPU, inference presumably needs `@spaces.GPU` -- confirm.
def infer(img, text_queries, score_threshold, model):
    """Detect ``text_queries`` in ``img`` with the chosen zero-shot detector.

    Args:
        img: Image array whose first two dimensions are (height, width), as
            delivered by ``gr.Image``.
        text_queries: List of label strings; surrounding whitespace and empty
            entries are ignored.
        score_threshold: Minimum confidence for a detection to be kept.
        model: ``"owl"`` (OWLv2) or ``"dino"`` (Grounding DINO).

    Returns:
        List of ``([x_min, y_min, x_max, y_max], label)`` tuples with integer
        pixel coordinates clamped to the image, suitable for ``gr.AnnotatedImage``.
        Empty when no non-blank query is given.

    Raises:
        ValueError: If ``model`` is not ``"owl"`` or ``"dino"``.
    """
    queries = [query.strip() for query in text_queries if query.strip()]
    if not queries:
        return []
    if model == "dino":
        boxes, scores, labels = _detect_dino(img, queries, score_threshold)
    elif model == "owl":
        boxes, scores, labels = _detect_owl(img, queries, score_threshold)
    else:
        raise ValueError(f"Unknown model {model!r}; expected 'owl' or 'dino'.")

    height, width = img.shape[:2]
    result_labels = []
    for box, score, label in zip(boxes, scores, labels):
        if score < score_threshold:
            continue
        x_min, y_min, x_max, y_max = (int(coord) for coord in box.tolist())
        # Clamp to the visible image: OWL boxes are expressed in the padded
        # square frame and may extend past the real image edges.
        clamped = [
            min(max(x_min, 0), width),
            min(max(y_min, 0), height),
            min(max(x_max, 0), width),
            min(max(y_max, 0), height),
        ]
        result_labels.append((clamped, label))
    return result_labels
def query_image(img, text_queries, owl_threshold, dino_threshold):
    """Run OWLv2 and Grounding DINO on ``img`` for a comma-separated label string.

    Args:
        img: Input image from the ``gr.Image`` component.
        text_queries: Comma-separated labels, e.g. ``"bee, flower"``.
        owl_threshold: Score threshold applied to OWLv2 detections.
        dino_threshold: Score threshold applied to Grounding DINO detections.

    Returns:
        ``((img, owl_annotations), (img, dino_annotations))``, one pair per
        ``gr.AnnotatedImage`` output.
    """
    # Strip the whitespace left by ", " separators so " flower" becomes "flower".
    queries = [query.strip() for query in text_queries.split(",")]
    owl_output = infer(img, queries, owl_threshold, "owl")
    # Each detector is filtered by its own slider value.
    dino_output = infer(img, queries, dino_threshold, "dino")
    return (img, owl_output), (img, dino_output)
# Per-model confidence sliders and annotated outputs for the Gradio UI.
owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
owl_output = gr.AnnotatedImage(label="OWL Output")
dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
demo = gr.Interface(
    query_image,
    inputs=[
        gr.Image(label="Input Image"),
        # The first positional parameter of gr.Textbox is its initial `value`,
        # so the caption must be passed as `label=`.
        gr.Textbox(label="Candidate Labels"),
        owl_threshold,
        dino_threshold,
    ],
    outputs=[owl_output, dino_output],
    title="Zero-Shot Object Detection with OWLv2",
    examples=[["./bee.jpg", "bee, flower", 0.16, 0.12]],
)
demo.launch(debug=True)