Mountchicken committed on
Commit
a44d7a3
·
verified ·
1 Parent(s): 5aa063f

Upload 2 files

Files changed (2)
  1. app.py +934 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,934 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ import argparse
+ import json
+ import os
+ import re
+ from typing import Any, Dict, List
+
+ import gradio as gr
+ import numpy as np
+ import spaces  # provides the @spaces.GPU decorator used below; without this import the decorator raises NameError
+ from gradio_image_prompter import ImagePrompter
+ from PIL import Image
+
+ from rex_omni import RexOmniVisualize, RexOmniWrapper, TaskType
+ from rex_omni.tasks import KEYPOINT_CONFIGS, TASK_CONFIGS, get_task_config
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Rex Omni Gradio Demo")
+     parser.add_argument(
+         "--model_path",
+         default="IDEA-Research/Rex-Omni",
+         help="Model path or HuggingFace repo ID",
+     )
+     parser.add_argument(
+         "--backend",
+         type=str,
+         default="transformers",
+         choices=["transformers", "vllm"],
+         help="Backend to use for inference",
+     )
+     parser.add_argument("--temperature", type=float, default=0.0)
+     parser.add_argument("--top_p", type=float, default=0.05)
+     parser.add_argument("--top_k", type=int, default=1)
+     parser.add_argument("--max_tokens", type=int, default=2048)
+     parser.add_argument("--repetition_penalty", type=float, default=1.05)
+     parser.add_argument("--min_pixels", type=int, default=16 * 28 * 28)
+     parser.add_argument("--max_pixels", type=int, default=2560 * 28 * 28)
+     parser.add_argument("--server_name", type=str, default="0.0.0.0")
+     parser.add_argument("--server_port", type=int, default=7860)
+     args = parser.parse_args()
+     return args
+
+
+ # Task configurations with detailed descriptions
+ DEMO_TASK_CONFIGS = {
+     "Detection": {
+         "task_type": TaskType.DETECTION,
+         "description": "Detect objects and return bounding boxes",
+         "example_categories": "person",
+         "supports_visual_prompt": False,
+         "supports_ocr_config": False,
+     },
+     "Pointing": {
+         "task_type": TaskType.POINTING,
+         "description": "Point to objects and return point coordinates",
+         "example_categories": "person",
+         "supports_visual_prompt": False,
+         "supports_ocr_config": False,
+     },
+     "Visual Prompting": {
+         "task_type": TaskType.VISUAL_PROMPTING,
+         "description": "Ground visual examples to find similar objects",
+         "example_categories": "",
+         "supports_visual_prompt": True,
+         "supports_ocr_config": False,
+     },
+     "Keypoint": {
+         "task_type": TaskType.KEYPOINT,
+         "description": "Detect keypoints with skeleton visualization",
+         "example_categories": "person, hand, animal",
+         "supports_visual_prompt": False,
+         "supports_ocr_config": False,
+     },
+     "OCR": {
+         "task_type": None,  # Will be determined by OCR config
+         "description": "Optical Character Recognition with customizable output format",
+         "example_categories": "text, word",
+         "supports_visual_prompt": False,
+         "supports_ocr_config": True,
+     },
+ }
+
+ # OCR configuration options
+ OCR_OUTPUT_FORMATS = {
+     "Box": {
+         "task_type": TaskType.OCR_BOX,
+         "description": "Detect text with bounding boxes",
+     },
+     "Polygon": {
+         "task_type": TaskType.OCR_POLYGON,
+         "description": "Detect text with polygon boundaries",
+     },
+ }
+
+ OCR_GRANULARITY_LEVELS = {
+     "Word Level": {"categories": "word", "description": "Detect individual words"},
+     "Text Line Level": {"categories": "text line", "description": "Detect text lines"},
+ }
+
+ # Example configurations
+ EXAMPLE_CONFIGS = [
+     {
+         "name": "Detection: Cafe Scene",
+         "image_path": "tutorials/detection_example/test_images/cafe.jpg",
+         "task": "Detection",
+         "categories": "man, woman, yellow flower, sofa, robot-shape light, blanket, microwave, laptop, cup, white chair, lamp",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Detection",
+     },
+     {
+         "name": "Referring: Boys Playing",
+         "image_path": "tutorials/detection_example/test_images/boys.jpg",
+         "task": "Detection",
+         "categories": "boys holding microphone, boy playing piano, the four guitars on the wall",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Referring",
+     },
+     {
+         "name": "GUI Grounding",
+         "image_path": "tutorials/detection_example/test_images/gui.png",
+         "task": "Detection",
+         "categories": "more information of song 'Photograph'",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "GUI Grounding",
+     },
+     {
+         "name": "Object Pointing: Point to boxes",
+         "image_path": "tutorials/pointing_example/test_images/boxes.jpg",
+         "task": "Pointing",
+         "categories": "open boxes, closed boxes",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Point to boxes in the image",
+     },
+     {
+         "name": "Affordance Pointing",
+         "image_path": "tutorials/pointing_example/test_images/cup.png",
+         "task": "Pointing",
+         "categories": "where I can hold the green cup",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Affordance Pointing",
+     },
+     {
+         "name": "Keypoint: Person",
+         "image_path": "tutorials/keypointing_example/test_images/person.png",
+         "task": "Keypoint",
+         "categories": "person",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Detect human keypoints and pose estimation",
+     },
+     {
+         "name": "Keypoint: Animal",
+         "image_path": "tutorials/keypointing_example/test_images/animal.png",
+         "task": "Keypoint",
+         "categories": "animal",
+         "keypoint_type": "animal",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "Detect animal keypoints and pose structure",
+     },
+     {
+         "name": "OCR: Box and Word",
+         "image_path": "tutorials/ocr_example/test_images/ocr.png",
+         "task": "OCR",
+         "categories": "text",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": None,
+         "description": "OCR: Box and Word",
+     },
+     {
+         "name": "OCR: Box and Text Line",
+         "image_path": "tutorials/ocr_example/test_images/ocr.png",
+         "task": "OCR",
+         "categories": "text",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Text Line Level",
+         "visual_prompt_boxes": None,
+         "description": "OCR: Box and Text Line",
+     },
+     {
+         "name": "OCR: Polygon and Text Line",
+         "image_path": "tutorials/ocr_example/test_images/ocr.png",
+         "task": "OCR",
+         "categories": "text",
+         "keypoint_type": "person",
+         "ocr_output_format": "Polygon",
+         "ocr_granularity": "Text Line Level",
+         "visual_prompt_boxes": None,
+         "description": "OCR: Polygon and Text Line",
+     },
+     {
+         "name": "Visual Prompting: Pigeons",
+         "image_path": "tutorials/visual_prompting_example/test_images/pigeons.jpeg",
+         "task": "Visual Prompting",
+         "categories": "pigeon",
+         "keypoint_type": "person",
+         "ocr_output_format": "Box",
+         "ocr_granularity": "Word Level",
+         "visual_prompt_boxes": [[644, 1210, 842, 1361], [1180, 1066, 1227, 1160]],
+         "description": "Find similar pigeons using visual prompting examples",
+     },
+ ]
+
+
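+ # NOTE: gradio_image_prompter reports each interaction as a 6-tuple
+ # [x1, y1, code_a, x2, y2, code_b]; judging from the checks below, (2, 3)
+ # marks a drawn rectangle and (1, 4) a positive click point.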
+ def parse_visual_prompt(points: List) -> List[List[float]]:
+     """Parse visual prompt points to bounding boxes"""
+     boxes = []
+     for point in points:
+         if point[2] == 2 and point[-1] == 3:  # Rectangle
+             x1, y1, _, x2, y2, _ = point
+             boxes.append([x1, y1, x2, y2])
+         elif point[2] == 1 and point[-1] == 4:  # Positive point
+             x, y, _, _, _, _ = point
+             half_width = 10
+             x1 = max(0, x - half_width)
+             y1 = max(0, y - half_width)
+             x2 = x + half_width
+             y2 = y + half_width
+             boxes.append([x1, y1, x2, y2])
+     return boxes
+
+
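+ # Rex-Omni addresses locations as integer bins in [0, 999] normalized to the
+ # image size; each box is serialized as "<x0><y0><x1><y1>".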
+ def convert_boxes_to_visual_prompt_format(
+     boxes: List[List[float]], image_width: int, image_height: int
+ ) -> str:
+     """Convert bounding boxes to visual prompt format for the model"""
+     if not boxes:
+         return ""
+
+     # Convert to normalized bins (0-999)
+     visual_prompts = []
+     for box in boxes:
+         x0, y0, x1, y1 = box
+
+         # Normalize and convert to bins
+         x0_norm = max(0.0, min(1.0, x0 / image_width))
+         y0_norm = max(0.0, min(1.0, y0 / image_height))
+         x1_norm = max(0.0, min(1.0, x1 / image_width))
+         y1_norm = max(0.0, min(1.0, y1 / image_height))
+
+         x0_bin = int(x0_norm * 999)
+         y0_bin = int(y0_norm * 999)
+         x1_bin = int(x1_norm * 999)
+         y1_bin = int(y1_norm * 999)
+
+         visual_prompt = f"<{x0_bin}><{y0_bin}><{x1_bin}><{y1_bin}>"
+         visual_prompts.append(visual_prompt)
+
+     return ", ".join(visual_prompts)
+
+
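+ # Used only for the on-page prompt preview; RexOmniWrapper builds the real
+ # prompt internally (presumably from the same templates in rex_omni.tasks).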
+ def get_task_prompt(
+     task_name: str,
+     categories: str,
+     keypoint_type: str = "",
+     visual_prompt_boxes: List = None,
+     image_width: int = 0,
+     image_height: int = 0,
+     ocr_output_format: str = "Box",
+     ocr_granularity: str = "Word Level",
+ ) -> str:
+     """Generate the actual prompt that will be sent to the model"""
+     if task_name not in DEMO_TASK_CONFIGS:
+         return "Invalid task selected."
+
+     demo_config = DEMO_TASK_CONFIGS[task_name]
+
+     if task_name == "Visual Prompting":
+         task_config = get_task_config(TaskType.VISUAL_PROMPTING)
+         if visual_prompt_boxes and len(visual_prompt_boxes) > 0:
+             visual_prompt_str = convert_boxes_to_visual_prompt_format(
+                 visual_prompt_boxes, image_width, image_height
+             )
+             return task_config.prompt_template.replace(
+                 "{visual_prompt}", visual_prompt_str
+             )
+         else:
+             return "Please draw bounding boxes on the image to provide visual examples."
+
+     elif task_name == "Keypoint":
+         task_config = get_task_config(TaskType.KEYPOINT)
+         if keypoint_type and keypoint_type in KEYPOINT_CONFIGS:
+             keypoints_list = KEYPOINT_CONFIGS[keypoint_type]
+             keypoints_str = ", ".join(keypoints_list)
+             prompt = task_config.prompt_template.replace("{categories}", keypoint_type)
+             prompt = prompt.replace("{keypoints}", keypoints_str)
+             return prompt
+         else:
+             return "Please select a keypoint type first."
+
+     elif task_name == "OCR":
+         # Get OCR task type based on output format
+         ocr_task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
+         task_config = get_task_config(ocr_task_type)
+
+         # Get categories based on granularity level
+         ocr_categories = OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]
+
+         # Replace categories in prompt template
+         return task_config.prompt_template.replace("{categories}", ocr_categories)
+
+     else:
+         # For other tasks, use the task config from tasks.py
+         task_type = demo_config["task_type"]
+         task_config = get_task_config(task_type)
+
+         # Replace {categories} placeholder
+         if categories.strip():
+             return task_config.prompt_template.replace(
+                 "{categories}", categories.strip()
+             )
+         else:
+             return task_config.prompt_template.replace("{categories}", "objects")
+
+
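+ # @spaces.GPU requests a ZeroGPU slot for the duration of each call when the
+ # demo runs on Hugging Face Spaces; elsewhere the decorator is a pass-through.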
+ @spaces.GPU
+ def run_inference(
+     image,
+     task_selection,
+     categories,
+     keypoint_type,
+     visual_prompt_data,
+     ocr_output_format,
+     ocr_granularity,
+     font_size,
+     draw_width,
+     show_labels,
+     custom_color,
+ ):
+     """Run inference using Rex Omni"""
+     if image is None:
+         return None, "Please upload an image first."
+
+     try:
+         # Convert numpy array to PIL Image if needed
+         if isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         image_width, image_height = image.size
+
+         # Parse visual prompts if needed
+         visual_prompt_boxes = []
+         if task_selection == "Visual Prompting":
+             # Check if we have predefined visual prompt boxes from examples
+             if hasattr(image, "_example_visual_prompts"):
+                 visual_prompt_boxes = image._example_visual_prompts
+             elif visual_prompt_data is not None and "points" in visual_prompt_data:
+                 visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
+
+         # Determine task type and categories based on task selection
+         if task_selection == "OCR":
+             # For OCR, use the selected output format to determine task type
+             task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
+             task_key = task_type.value
+             # Use granularity level to determine categories
+             categories_list = [OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]]
+         elif task_selection == "Visual Prompting":
+             # For visual prompting, we don't need explicit categories
+             task_key = "visual_prompting"
+             categories_list = ["object"]
+
+             # Check if visual prompt boxes are provided
+             if not visual_prompt_boxes:
+                 return (
+                     None,
+                     "Please draw bounding boxes on the image to provide visual examples for the Visual Prompting task.",
+                 )
+         elif task_selection == "Keypoint":
+             task_key = "keypoint"
+             categories_list = [keypoint_type] if keypoint_type else ["person"]
+         else:
+             # For other tasks, get task type from demo config
+             demo_config = DEMO_TASK_CONFIGS[task_selection]
+             task_type = demo_config["task_type"]
+             task_key = task_type.value
+
+             # Split categories by comma and clean up
+             categories_list = [
+                 cat.strip() for cat in categories.split(",") if cat.strip()
+             ]
+             if not categories_list:
+                 categories_list = ["object"]
+
+         # Run inference
+         if task_selection == "Visual Prompting":
+             results = rex_model.inference(
+                 images=image,
+                 task=task_key,
+                 categories=categories_list,
+                 visual_prompt_boxes=visual_prompt_boxes,
+             )
+         elif task_selection == "Keypoint":
+             results = rex_model.inference(
+                 images=image,
+                 task=task_key,
+                 categories=categories_list,
+                 keypoint_type=keypoint_type if keypoint_type else "person",
+             )
+         else:
+             results = rex_model.inference(
+                 images=image, task=task_key, categories=categories_list
+             )
+
+         result = results[0]
+
+         # Check if inference was successful
+         if not result.get("success", False):
+             error_msg = result.get("error", "Unknown error occurred during inference")
+             return None, f"Inference failed: {error_msg}"
+
+         # Get predictions and raw output
+         predictions = result["extracted_predictions"]
+         raw_output = result["raw_output"]
+
+         # Create visualization
+         try:
+             vis_image = RexOmniVisualize(
+                 image=image,
+                 predictions=predictions,
+                 font_size=font_size,
+                 draw_width=draw_width,
+                 show_labels=show_labels,
+             )
+             return vis_image, raw_output
+         except Exception as e:
+             return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
+
+     except Exception as e:
+         return None, f"Error during inference: {str(e)}"
+
+
+ def update_interface(task_selection):
+     """Update interface based on task selection"""
+     config = DEMO_TASK_CONFIGS.get(task_selection, {})
+
+     if task_selection == "Visual Prompting":
+         return (
+             gr.update(visible=False),  # categories
+             gr.update(visible=False),  # keypoint_type
+             gr.update(visible=True),  # visual_prompt_tab
+             gr.update(visible=False),  # ocr_config_group
+             gr.update(value=config.get("description", "")),  # task_description
+         )
+     elif task_selection == "Keypoint":
+         return (
+             gr.update(visible=False),  # categories
+             gr.update(visible=True),  # keypoint_type
+             gr.update(visible=False),  # visual_prompt_tab
+             gr.update(visible=False),  # ocr_config_group
+             gr.update(value=config.get("description", "")),  # task_description
+         )
+     elif task_selection == "OCR":
+         return (
+             gr.update(visible=False),  # categories
+             gr.update(visible=False),  # keypoint_type
+             gr.update(visible=False),  # visual_prompt_tab
+             gr.update(visible=True),  # ocr_config_group
+             gr.update(value=config.get("description", "")),  # task_description
+         )
+     else:
+         return (
+             gr.update(
+                 visible=True, placeholder=config.get("example_categories", "")
+             ),  # categories
+             gr.update(visible=False),  # keypoint_type
+             gr.update(visible=False),  # visual_prompt_tab
+             gr.update(visible=False),  # ocr_config_group
+             gr.update(value=config.get("description", "")),  # task_description
+         )
+
+
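+ # NOTE: example visual prompts are carried as a private attribute on the PIL
+ # image (set in load_example_image, read back via hasattr in run_inference).
+ # This only survives while the same PIL object is passed through; a numpy
+ # round-trip in Gradio will drop it.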
+ def load_example_image(image_path, visual_prompt_boxes=None):
+     """Load example image from tutorials directory"""
+     if image_path is None:
+         return None
+
+     try:
+         # Construct full path
+         full_path = os.path.join(os.path.dirname(__file__), "..", image_path)
+         if os.path.exists(full_path):
+             image = Image.open(full_path).convert("RGB")
+
+             # Attach visual prompt boxes if provided (for Visual Prompting examples)
+             if visual_prompt_boxes:
+                 image._example_visual_prompts = visual_prompt_boxes
+
+             return image
+         else:
+             print(f"Warning: Example image not found at {full_path}")
+             return None
+     except Exception as e:
+         print(f"Error loading example image: {e}")
+         return None
+
+
+ def prepare_gallery_data():
+     """Prepare gallery data for examples"""
+     gallery_images = []
+     gallery_captions = []
+
+     for config in EXAMPLE_CONFIGS:
+         # Load example image
+         image = load_example_image(config["image_path"], config["visual_prompt_boxes"])
+         if image:
+             gallery_images.append(image)
+             gallery_captions.append(f"{config['name']}\n{config['description']}")
+
+     return gallery_images, gallery_captions
+
+
+ def update_example_selection(selected_index):
+     """Update all interface elements based on gallery selection"""
+     if selected_index is None or selected_index >= len(EXAMPLE_CONFIGS):
+         return [gr.update() for _ in range(7)]  # Return no updates if invalid selection
+
+     config = EXAMPLE_CONFIGS[selected_index]
+
+     # Load example image if available
+     example_image = None
+     if config["image_path"]:
+         example_image = load_example_image(
+             config["image_path"], config["visual_prompt_boxes"]
+         )
+
+     return (
+         example_image,  # input_image
+         config["task"],  # task_selection
+         config["categories"],  # categories
+         config["keypoint_type"],  # keypoint_type
+         config["ocr_output_format"],  # ocr_output_format
+         config["ocr_granularity"],  # ocr_granularity
+         gr.update(
+             value=DEMO_TASK_CONFIGS[config["task"]]["description"]
+         ),  # task_description
+     )
+
+
+ def update_prompt_preview(
+     task_selection,
+     categories,
+     keypoint_type,
+     visual_prompt_data,
+     ocr_output_format,
+     ocr_granularity,
+ ):
+     """Update the prompt preview"""
+     if visual_prompt_data is None:
+         visual_prompt_data = {}
+
+     # Parse visual prompts
+     visual_prompt_boxes = []
+     if "points" in visual_prompt_data:
+         visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
+
+     # Generate prompt preview
+     prompt = get_task_prompt(
+         task_selection,
+         categories,
+         keypoint_type,
+         visual_prompt_boxes,
+         800,  # dummy image dimensions for preview
+         600,
+         ocr_output_format=ocr_output_format,
+         ocr_granularity=ocr_granularity,
+     )
+
+     return prompt
+
+
+ def create_demo():
+     """Create the Gradio demo interface"""
+
+     with gr.Blocks(
+         theme=gr.themes.Soft(primary_hue="blue"),
+         title="Rex Omni Demo",
+         css="""
+         .gradio-container {
+             max-width: 1400px !important;
+         }
+         .prompt-preview {
+             background-color: #f8f9fa;
+             border: 1px solid #dee2e6;
+             border-radius: 0.375rem;
+             padding: 0.75rem;
+             font-family: 'Courier New', monospace;
+             font-size: 0.875rem;
+         }
+         .preserve-aspect-ratio img {
+             object-fit: contain !important;
+             max-height: 400px !important;
+             width: auto !important;
+         }
+         .preserve-aspect-ratio canvas {
+             object-fit: contain !important;
+             max-height: 400px !important;
+             width: auto !important;
+         }
+         """,
+     ) as demo:
+
+         gr.Markdown("# Rex Omni: Detect Anything Demo")
+         gr.Markdown("Upload an image and select a task to see Rex Omni in action!")
+
+         with gr.Row():
+             # Left Column - Input Controls
+             with gr.Column(scale=1):
+                 gr.Markdown("## 📝 Task Configuration")
+
+                 # Task Selection
+                 task_selection = gr.Dropdown(
+                     label="Select Task",
+                     choices=list(DEMO_TASK_CONFIGS.keys()),
+                     value="Detection",
+                     info="Choose the vision task to perform",
+                 )
+
+                 # Task Description
+                 task_description = gr.Textbox(
+                     label="Task Description",
+                     value=DEMO_TASK_CONFIGS["Detection"]["description"],
+                     interactive=False,
+                     lines=2,
+                 )
+
+                 # Text Prompt Section
+                 with gr.Group():
+                     gr.Markdown("### 💬 Text Prompt Configuration")
+
+                     categories = gr.Textbox(
+                         label="Categories",
+                         value="person, car, dog",
+                         placeholder="person, car, dog",
+                         info="Enter object categories separated by commas",
+                         visible=True,
+                     )
+
+                     keypoint_type = gr.Dropdown(
+                         label="Keypoint Type",
+                         choices=list(KEYPOINT_CONFIGS.keys()),
+                         value="person",
+                         visible=False,
+                         info="Select the type of keypoints to detect",
+                     )
+
+                 # OCR Configuration Section
+                 ocr_config_group = gr.Group(visible=False)
+                 with ocr_config_group:
+                     gr.Markdown("### 📄 OCR Configuration")
+
+                     ocr_output_format = gr.Radio(
+                         label="Output Format",
+                         choices=list(OCR_OUTPUT_FORMATS.keys()),
+                         value="Box",
+                         info="Choose between bounding box or polygon output format",
+                     )
+
+                     ocr_granularity = gr.Radio(
+                         label="Granularity Level",
+                         choices=list(OCR_GRANULARITY_LEVELS.keys()),
+                         value="Word Level",
+                         info="Choose between word-level or text-line-level detection",
+                     )
+
+                 # Visual Prompt Section
+                 visual_prompt_tab = gr.Group(visible=False)
+                 with visual_prompt_tab:
+                     gr.Markdown("### 🎯 Visual Prompt Configuration")
+                     gr.Markdown(
+                         "Draw bounding boxes on the image to provide visual examples"
+                     )
+
+                 # Prompt Preview
+                 gr.Markdown("### 🔍 Generated Prompt Preview")
+                 prompt_preview = gr.Textbox(
+                     label="Actual Prompt",
+                     value="Detect person, car, dog.",
+                     interactive=False,
+                     lines=3,
+                     elem_classes=["prompt-preview"],
+                 )
+
+                 # Visualization Controls
+                 with gr.Accordion("🎨 Visualization Settings", open=False):
+                     font_size = gr.Slider(
+                         label="Font Size", value=20, minimum=10, maximum=50, step=1
+                     )
+                     draw_width = gr.Slider(
+                         label="Line Width", value=5, minimum=1, maximum=20, step=1
+                     )
+                     show_labels = gr.Checkbox(label="Show Labels", value=True)
+                     custom_color = gr.Textbox(
+                         label="Custom Colors (Hex)",
+                         placeholder="#FF0000,#00FF00,#0000FF",
+                         info="Comma-separated hex colors for different categories",
+                     )
+
+             # Right Column - Image and Results
+             with gr.Column(scale=2):
+                 with gr.Row():
+                     # Input Image
+                     with gr.Column():
+                         input_image = gr.Image(
+                             label="📷 Input Image", type="numpy", height=400
+                         )
+
+                         # Visual Prompt Interface (only visible for Visual Prompting task)
+                         visual_prompter = ImagePrompter(
+                             label="🎯 Visual Prompt Interface",
+                             width=420,
+                             height=315,  # 4:3 aspect ratio (420 * 3/4 = 315)
+                             visible=False,
+                             elem_classes=["preserve-aspect-ratio"],
+                         )
+
+                     # Output Visualization
+                     with gr.Column():
+                         output_image = gr.Image(
+                             label="🎨 Visualization Result", height=400
+                         )
+
+                 # Run Button
+                 run_button = gr.Button("🚀 Run Inference", variant="primary", size="lg")
+
+                 # Model Output
+                 model_output = gr.Textbox(
+                     label="🤖 Model Raw Output",
+                     lines=15,
+                     max_lines=20,
+                     show_copy_button=True,
+                 )
+
+         # Example Gallery Section
+         with gr.Row():
+             gr.Markdown("## 🖼️ Example Gallery")
+
+         with gr.Row():
+             gallery_images, gallery_captions = prepare_gallery_data()
+             example_gallery = gr.Gallery(
+                 value=list(zip(gallery_images, gallery_captions)),
+                 label="Click on an example to load it",
+                 show_label=True,
+                 elem_id="example_gallery",
+                 columns=4,
+                 rows=2,
+                 height="auto",
+                 allow_preview=True,
+             )
+
+         # Event Handlers
+
+         # Update interface when gallery example is selected
+         def handle_gallery_select(evt: gr.SelectData):
+             return update_example_selection(evt.index)
+
+         example_gallery.select(
+             fn=handle_gallery_select,
+             outputs=[
+                 input_image,
+                 task_selection,
+                 categories,
+                 keypoint_type,
+                 ocr_output_format,
+                 ocr_granularity,
+                 task_description,
+             ],
+         )
+
+         # Update interface when task changes
+         task_selection.change(
+             fn=update_interface,
+             inputs=[task_selection],
+             outputs=[
+                 categories,
+                 keypoint_type,
+                 visual_prompt_tab,
+                 ocr_config_group,
+                 task_description,
+             ],
+         )
+
+         # Update prompt preview when any input changes
+         for component in [
+             task_selection,
+             categories,
+             keypoint_type,
+             ocr_output_format,
+             ocr_granularity,
+         ]:
+             component.change(
+                 fn=update_prompt_preview,
+                 inputs=[
+                     task_selection,
+                     categories,
+                     keypoint_type,
+                     visual_prompter,
+                     ocr_output_format,
+                     ocr_granularity,
+                 ],
+                 outputs=[prompt_preview],
+             )
+
+         # Show/hide visual prompter based on task
+         def toggle_visual_prompter(task_selection):
+             if task_selection == "Visual Prompting":
+                 return gr.update(visible=False), gr.update(visible=True)
+             else:
+                 return gr.update(visible=True), gr.update(visible=False)
+
+         task_selection.change(
+             fn=toggle_visual_prompter,
+             inputs=[task_selection],
+             outputs=[input_image, visual_prompter],
+         )
+
+         # Run inference with dynamic image selection
+         def run_inference_wrapper(
+             input_image,
+             visual_prompter_data,
+             task_selection,
+             categories,
+             keypoint_type,
+             ocr_output_format,
+             ocr_granularity,
+             font_size,
+             draw_width,
+             show_labels,
+             custom_color,
+         ):
+             # For Visual Prompting task, use the visual prompter image
+             if task_selection == "Visual Prompting":
+                 if visual_prompter_data is not None and "image" in visual_prompter_data:
+                     image_to_use = visual_prompter_data["image"]
+                 else:
+                     return (
+                         None,
+                         "Please upload an image in the Visual Prompt Interface for the Visual Prompting task.",
+                     )
+             else:
+                 image_to_use = input_image
+
+             return run_inference(
+                 image_to_use,
+                 task_selection,
+                 categories,
+                 keypoint_type,
+                 visual_prompter_data,
+                 ocr_output_format,
+                 ocr_granularity,
+                 font_size,
+                 draw_width,
+                 show_labels,
+                 custom_color,
+             )
+
+         run_button.click(
+             fn=run_inference_wrapper,
+             inputs=[
+                 input_image,
+                 visual_prompter,
+                 task_selection,
+                 categories,
+                 keypoint_type,
+                 ocr_output_format,
+                 ocr_granularity,
+                 font_size,
+                 draw_width,
+                 show_labels,
+                 custom_color,
+             ],
+             outputs=[output_image, model_output],
+         )
+
+     return demo
+
+
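+ # rex_model is created at module scope below and read as a global inside
+ # run_inference, so this script is meant to be run directly, not imported.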
+ if __name__ == "__main__":
+     args = parse_args()
+
+     print("🚀 Initializing Rex Omni model...")
+     print(f"Model: {args.model_path}")
+     print(f"Backend: {args.backend}")
+
+     # Initialize Rex Omni model
+     rex_model = RexOmniWrapper(
+         model_path=args.model_path,
+         backend=args.backend,
+         max_tokens=args.max_tokens,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         top_k=args.top_k,
+         repetition_penalty=args.repetition_penalty,
+         min_pixels=args.min_pixels,
+         max_pixels=args.max_pixels,
+     )
+
+     print("✅ Model initialized successfully!")
+
+     # Create and launch demo
+     demo = create_demo()
+
+     print(f"🌐 Launching demo at http://{args.server_name}:{args.server_port}")
+     demo.launch(
+         server_name=args.server_name,
+         server_port=args.server_port,
+         share=True,
+         debug=True,
+     )
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ matplotlib==3.10.6
+ numpy==1.26.4
+ Pillow==11.3.0
+ qwen_vl_utils==0.0.14
+ torch==2.6.0+cu124
+ transformers==4.51.3
+ vllm==0.8.2
+ accelerate==1.10.1
+ flash-attn==2.7.4.post1
+ gradio==4.44.1
+ gradio_image_prompter==0.1.0
+ pydantic==2.10.6