developer0hye committed on
Commit
d58546f
·
verified ·
1 Parent(s): ce43380

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -15,7 +15,7 @@ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
15
  import supervision as sv
16
 
17
  # Model ID for Hugging Face
18
- model_id = "IDEA-Research/grounding-dino-base"
19
 
20
  # Load model and processor using Transformers
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -32,8 +32,12 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
32
 
33
  init_image = input_image.convert("RGB")
34
 
 
 
 
 
35
  # Process input using transformers
36
- inputs = processor(images=init_image, text=grounding_caption, return_tensors="pt").to(device)
37
 
38
  # Run inference
39
  with torch.no_grad():
@@ -42,10 +46,8 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
42
  # Post-process results
43
  results = processor.post_process_grounded_object_detection(
44
  outputs,
45
- inputs.input_ids,
46
  threshold=box_threshold,
47
- text_threshold=text_threshold,
48
- target_sizes=[init_image.size[::-1]]
49
  )
50
 
51
  result = results[0]
@@ -140,8 +142,8 @@ if __name__ == "__main__":
140
  }
141
  """
142
  with gr.Blocks(css=css) as demo:
143
- gr.Markdown("<h1><center>Grounding DINO Base<h1><center>")
144
- gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a><h3><center>")
145
 
146
  with gr.Row():
147
  with gr.Column():
@@ -159,7 +161,8 @@ if __name__ == "__main__":
159
  )
160
  text_threshold = gr.Slider(
161
  minimum=0.0, maximum=1.0, value=0.25, step=0.001,
162
- label="Text Threshold"
 
163
  )
164
 
165
  with gr.Column():
 
15
  import supervision as sv
16
 
17
  # Model ID for Hugging Face
18
+ model_id = "rziga/mm_grounding_dino_base_all"
19
 
20
  # Load model and processor using Transformers
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
32
 
33
  init_image = input_image.convert("RGB")
34
 
35
+ # Process caption into list of list format for mm grounding dino
36
+ # Split by period and strip whitespace
37
+ text_labels = [[label.strip() for label in grounding_caption.split('.') if label.strip()]]
38
+
39
  # Process input using transformers
40
+ inputs = processor(images=init_image, text=text_labels, return_tensors="pt").to(device)
41
 
42
  # Run inference
43
  with torch.no_grad():
 
46
  # Post-process results
47
  results = processor.post_process_grounded_object_detection(
48
  outputs,
 
49
  threshold=box_threshold,
50
+ target_sizes=[(init_image.size[1], init_image.size[0])]
 
51
  )
52
 
53
  result = results[0]
 
142
  }
143
  """
144
  with gr.Blocks(css=css) as demo:
145
+ gr.Markdown("<h1><center>MM Grounding DINO Base<h1><center>")
146
+ gr.Markdown("<h3><center>Open-World Detection with <a href='https://huggingface.co/openmmlab-community/mm_grounding_dino_base_all'>MM Grounding DINO</a><h3><center>")
147
 
148
  with gr.Row():
149
  with gr.Column():
 
161
  )
162
  text_threshold = gr.Slider(
163
  minimum=0.0, maximum=1.0, value=0.25, step=0.001,
164
+ label="Text Threshold (not used in MM Grounding DINO)",
165
+ visible=False
166
  )
167
 
168
  with gr.Column():