developer0hye committed
Commit e70d81e · verified · 1 Parent(s): 795585e

Update app.py

Files changed (1)
  1. app.py +61 -41
app.py CHANGED
@@ -5,7 +5,6 @@ import cv2
 from PIL import Image
 import numpy as np
 
-
 import warnings
 import torch
 warnings.filterwarnings("ignore")
@@ -30,16 +29,16 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
     if input_image.ndim == 3:
         input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
         input_image = Image.fromarray(input_image)
-
+
     init_image = input_image.convert("RGB")
-
+
     # Process input using transformers
     inputs = processor(images=init_image, text=grounding_caption, return_tensors="pt").to(device)
-
+
     # Run inference
     with torch.no_grad():
         outputs = model(**inputs)
-
+
     # Post-process results
     results = processor.post_process_grounded_object_detection(
         outputs,
@@ -48,43 +47,54 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
         text_threshold=text_threshold,
         target_sizes=[init_image.size[::-1]]
     )
-
+
     result = results[0]
-
+
     # Convert image for supervision visualization
     image_np = np.array(init_image)
-
+
     # Create detections for supervision
     boxes = []
     labels = []
     confidences = []
     class_ids = []
-
+
     for i, (box, score, label) in enumerate(zip(result["boxes"], result["scores"], result["labels"])):
-        # Convert box to xyxy format
+        # box is xyxy format [xmin, ymin, xmax, ymax]
         xyxy = box.tolist()
         boxes.append(xyxy)
         labels.append(label)
         confidences.append(float(score))
         class_ids.append(i)  # Use index as class_id (integer)
-
-    # Create Detections object for supervision
+
+    # Build the text summary in the requested format
+    if boxes:
+        lines = []
+        for label, xyxy, conf in zip(labels, boxes, confidences):
+            x1, y1, x2, y2 = [int(round(v)) for v in xyxy]
+            # Format: class top_left_x, top_left_y, bot_x, bot_y
+            lines.append(f"{label} {x1}, {y1}, {x2}, {y2}")
+        detection_text = "\n".join(lines)
+    else:
+        detection_text = "No detections."
+
+    # Create Detections object for supervision & annotate
     if boxes:
         detections = sv.Detections(
             xyxy=np.array(boxes),
             confidence=np.array(confidences),
-            class_id=np.array(class_ids, dtype=np.int32),  # Ensure it's an integer array
+            class_id=np.array(class_ids, dtype=np.int32),
         )
-
+
         text_scale = sv.calculate_optimal_text_scale(resolution_wh=init_image.size)
         line_thickness = sv.calculate_optimal_line_thickness(resolution_wh=init_image.size)
-
+
         # Create annotators
         box_annotator = sv.BoxAnnotator(
             thickness=2,
             color=sv.ColorPalette.DEFAULT,
         )
-
+
         label_annotator = sv.LabelAnnotator(
             color=sv.ColorPalette.DEFAULT,
             text_color=sv.Color.WHITE,
@@ -92,40 +102,41 @@ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold)
             text_thickness=line_thickness,
             text_padding=3
         )
-
+
         # Create formatted labels for each detection
         formatted_labels = [
-            f"{label}: {conf:.2f}"
+            f"{label}: {conf:.2f}"
            for label, conf in zip(labels, confidences)
         ]
-
+
         # Apply annotations to the image
         annotated_image = box_annotator.annotate(scene=image_np, detections=detections)
         annotated_image = label_annotator.annotate(
-            scene=annotated_image,
-            detections=detections,
+            scene=annotated_image,
+            detections=detections,
             labels=formatted_labels
         )
     else:
         annotated_image = image_np
-
+
     # Convert back to PIL Image
     image_with_box = Image.fromarray(annotated_image)
-
-    return image_with_box
+
+    # Return both the annotated image and the detection text
+    return image_with_box, detection_text
 
 if __name__ == "__main__":
-
+
     parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
     parser.add_argument("--debug", action="store_true", help="using debug mode")
     parser.add_argument("--share", action="store_true", help="share the app")
     args = parser.parse_args()
-
+
     css = """
     #mkd {
-        height: 500px;
-        overflow: auto;
-        border: 1px solid #ccc;
+        height: 500px;
+        overflow: auto;
+        border: 1px solid #ccc;
     }
     """
     with gr.Blocks(css=css) as demo:
@@ -135,16 +146,19 @@ if __name__ == "__main__":
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(label="Input Image", type="pil")
-                grounding_caption = gr.Textbox(label="Detection Prompt(VERY important: text queries need to be lowercased + end with a dot, example: a cat. a remote control.)", value="a person. a car.")
+                grounding_caption = gr.Textbox(
+                    label="Detection Prompt (lowercase + each ends with a dot)",
+                    value="a person. a car."
+                )
                 run_button = gr.Button("Run")
-
+
                 with gr.Accordion("Advanced options", open=False):
                     box_threshold = gr.Slider(
-                        minimum=0.0, maximum=1.0, value=0.3, step=0.001,
+                        minimum=0.0, maximum=1.0, value=0.3, step=0.001,
                         label="Box Threshold"
                     )
                     text_threshold = gr.Slider(
-                        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
+                        minimum=0.0, maximum=1.0, value=0.25, step=0.001,
                         label="Text Threshold"
                     )
 
@@ -153,22 +167,28 @@ if __name__ == "__main__":
                     label="Detection Result",
                     type="pil"
                 )
+                det_text = gr.Textbox(
+                    label="Detections (class top_left_x, top_left_y, bot_x, bot_y)",
+                    lines=12,
+                    interactive=False,
+                    show_copy_button=True
+                )
 
         run_button.click(
-            fn=run_grounding,
-            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
-            outputs=[gallery]
+            fn=run_grounding,
+            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
+            outputs=[gallery, det_text]
         )
-
+
         gr.Examples(
            examples=[
                ["000000039769.jpg", "a cat. a remote control.", 0.3, 0.25],
                ["KakaoTalk_20250430_163200504.jpg", "cup. screen. hand.", 0.3, 0.25]
-            ],
+            ],
            inputs=[input_image, grounding_caption, box_threshold, text_threshold],
-            outputs=[gallery],
+            outputs=[gallery, det_text],
            fn=run_grounding,
            cache_examples=True,
        )
-
-    demo.launch(share=args.share, debug=args.debug, show_error=True)
+
+    demo.launch(share=args.share, debug=args.debug, show_error=True)
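
Note: the changed hunks call module-level processor, model, and device, which app.py defines above the first hunk and which this commit does not touch. A minimal sketch of what that setup typically looks like with transformers' zero-shot object-detection classes follows; the checkpoint name is an assumption, not taken from this repo.

    # Sketch only: app.py defines processor/model/device above the changed hunks.
    import torch
    from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_id = "IDEA-Research/grounding-dino-tiny"  # assumed checkpoint; any Grounding DINO ID works
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)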
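With this commit run_grounding returns two values (the annotated image and the detection text), so calling it outside Gradio looks roughly like the sketch below. This assumes the unchanged start of the function, which is not shown in the diff, still accepts the image as a NumPy array, as the cv2.cvtColor call suggests; the file names are placeholders.

    # Sketch only: exercising the updated two-output run_grounding directly.
    import cv2

    image_bgr = cv2.imread("example.jpg")  # placeholder path, read as a BGR array
    annotated, detection_text = run_grounding(image_bgr, "a cat. a remote control.", 0.3, 0.25)
    print(detection_text)            # one "label x1, y1, x2, y2" line per detection
    annotated.save("annotated.jpg")  # annotated is a PIL.Image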