Update app.py
app.py CHANGED

@@ -15,7 +15,7 @@ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 import supervision as sv
 
 # Model ID for Hugging Face
-model_id = "
+model_id = "rziga/mm_grounding_dino_base_all"
 
 # Load model and processor using Transformers
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -32,8 +32,12 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
 
     init_image = input_image.convert("RGB")
 
+    # Process caption into list of list format for mm grounding dino
+    # Split by period and strip whitespace
+    text_labels = [[label.strip() for label in grounding_caption.split('.') if label.strip()]]
+
     # Process input using transformers
-    inputs = processor(images=init_image, text=
+    inputs = processor(images=init_image, text=text_labels, return_tensors="pt").to(device)
 
     # Run inference
     with torch.no_grad():
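For reviewers, the new list-of-lists shape can be sanity-checked on its own. A minimal standalone sketch, reusing only the splitting expression from the hunk above (the sample caption is illustrative):

# Standalone check of the caption-splitting step (sample caption is
# hypothetical): one inner list per image, one entry per phrase.
grounding_caption = "a cat. a dog."
text_labels = [[label.strip() for label in grounding_caption.split('.') if label.strip()]]
print(text_labels)  # [['a cat', 'a dog']]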
@@ -42,10 +46,8 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
     # Post-process results
     results = processor.post_process_grounded_object_detection(
         outputs,
-        inputs.input_ids,
         threshold=box_threshold,
-
-        target_sizes=[init_image.size[::-1]]
+        target_sizes=[(init_image.size[1], init_image.size[0])]
     )
 
     result = results[0]
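One detail worth noting here: PIL's Image.size is (width, height), while target_sizes expects (height, width) per image, so the added tuple is equivalent to the removed init_image.size[::-1]. A quick equivalence check, assuming any PIL image:

# Hypothetical standalone snippet: both spellings produce (height, width).
from PIL import Image
init_image = Image.new("RGB", (640, 480))  # size == (width, height)
assert (init_image.size[1], init_image.size[0]) == init_image.size[::-1]  # both (480, 640)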
@@ -140,8 +142,8 @@ if __name__ == "__main__":
     }
     """
     with gr.Blocks(css=css) as demo:
-        gr.Markdown("<h1><center>Grounding DINO Base<h1><center>")
-        gr.Markdown("<h3><center>Open-World Detection with <a href='https://
+        gr.Markdown("<h1><center>MM Grounding DINO Base<h1><center>")
+        gr.Markdown("<h3><center>Open-World Detection with <a href='https://huggingface.co/openmmlab-community/mm_grounding_dino_base_all'>MM Grounding DINO</a><h3><center>")
 
         with gr.Row():
             with gr.Column():
@@ -159,7 +161,8 @@ if __name__ == "__main__":
                 )
                 text_threshold = gr.Slider(
                     minimum=0.0, maximum=1.0, value=0.25, step=0.001,
-                    label="Text Threshold"
+                    label="Text Threshold (not used in MM Grounding DINO)",
+                    visible=False
                 )
 
             with gr.Column():
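On the slider change: a Gradio component created with visible=False still exists and still passes its value to any event handler it is wired into, so run_grounding can keep its four-argument signature. A minimal sketch of the pattern (standalone, with illustrative component names, not the app's actual layout):

import gradio as gr

# Hidden components still feed their values to handlers, so function
# signatures stay stable (sketch only; names are hypothetical).
with gr.Blocks() as demo:
    caption = gr.Textbox(label="Detection Prompt")
    text_threshold = gr.Slider(
        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
        label="Text Threshold (not used in MM Grounding DINO)",
        visible=False,
    )
    out = gr.Textbox(label="Echo")
    # The hidden slider's value still arrives as the second argument.
    caption.submit(lambda c, t: f"{c} @ {t}", [caption, text_threshold], out)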
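Taken together, the updated inference path reads roughly as below. A minimal end-to-end sketch: the model ID, caption splitting, and post-processing arguments come from the diff above; the sample image path, caption, and threshold value are illustrative:

# Minimal end-to-end sketch of the updated pipeline. Assumptions: model ID,
# caption splitting, and post-processing arguments are taken from the diff;
# everything else (image path, caption, threshold) is illustrative.
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

model_id = "rziga/mm_grounding_dino_base_all"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image = Image.open("example.jpg").convert("RGB")
text_labels = [[label.strip() for label in "a cat. a dog.".split(".") if label.strip()]]

inputs = processor(images=image, text=text_labels, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    threshold=0.3,
    target_sizes=[(image.size[1], image.size[0])],  # (height, width)
)
print(results[0])  # boxes, scores, and matched labels for the first image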