Update app.py
Fix Visual Prompting bug

app.py CHANGED
@@ -1,26 +1,32 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-import spaces
 import argparse
 import json
 import os
 
+import spaces
 
-os.system(
-
+os.system(
+    "pip install torch==2.4.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu124"
+)
+os.system("pip install gradio_bbox_annotator")
 import subprocess
-subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+subprocess.run(
+    "pip install flash-attn==2.7.4.post1 --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
+import re
 import sys
 import threading
-import re
 from typing import Any, Dict, List
 
 import gradio as gr
 import numpy as np
-from
+from gradio_bbox_annotator import BBoxAnnotator
 from PIL import Image
-
 from rex_omni import RexOmniVisualize, RexOmniWrapper, TaskType
 from rex_omni.tasks import KEYPOINT_CONFIGS, TASK_CONFIGS, get_task_config
 
@@ -234,22 +240,33 @@ EXAMPLE_CONFIGS = [
 ]
 
 
-def parse_visual_prompt(
-    """Parse
+def parse_visual_prompt(bbox_data) -> List[List[float]]:
+    """Parse BBoxAnnotator output to bounding boxes"""
+    if bbox_data is None:
+        return []
+
+    try:
+        # BBoxAnnotator returns format: (image, boxes_list)
+        # where boxes_list contains [x, y, width, height] for each box
+        if isinstance(bbox_data, tuple) and len(bbox_data) >= 2:
+            boxes_list = bbox_data[1]
+        else:
+            boxes_list = bbox_data
+
+        if not boxes_list:
+            return []
+
+        # Convert from [x, y, width, height] to [x1, y1, x2, y2] format
+        boxes = []
+        for box in boxes_list:
+            if len(box) >= 4:
+                x1, y1, x2, y2 = box[:4]
+                boxes.append([x1, y1, x2, y2])
+
+        return boxes
+    except Exception as e:
+        print(f"Error parsing visual prompt: {e}")
+        return []
 
 
 def convert_boxes_to_visual_prompt_format(
@@ -344,6 +361,7 @@ def get_task_prompt(
     else:
         return task_config.prompt_template.replace("{categories}", "objects")
 
+
 @spaces.GPU
 def run_inference(
     image,
@@ -362,7 +380,6 @@ def run_inference(
     if image is None:
         return None, "Please upload an image first."
 
-
     # Convert numpy array to PIL Image if needed
     if isinstance(image, np.ndarray):
         image = Image.fromarray(image)
@@ -375,8 +392,8 @@
     # Check if we have predefined visual prompt boxes from examples
     if hasattr(image, "_example_visual_prompts"):
         visual_prompt_boxes = image._example_visual_prompts
-    elif visual_prompt_data is not None
-        visual_prompt_boxes = parse_visual_prompt(visual_prompt_data
+    elif visual_prompt_data is not None:
+        visual_prompt_boxes = parse_visual_prompt(visual_prompt_data)
 
     # Determine task type and categories based on task selection
     if task_selection == "OCR":
@@ -406,9 +423,7 @@
     task_key = task_type.value
 
     # Split categories by comma and clean up
-    categories_list = [
-        cat.strip() for cat in categories.split(",") if cat.strip()
-    ]
+    categories_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
     if not categories_list:
         categories_list = ["object"]
 
@@ -456,6 +471,7 @@
     except Exception as e:
         return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
 
+
 def update_interface(task_selection):
     """Update interface based on task selection"""
     config = DEMO_TASK_CONFIGS.get(task_selection, {})
@@ -580,8 +596,8 @@ def update_prompt_preview(
 
     # Parse visual prompts
     visual_prompt_boxes = []
-    if
-        visual_prompt_boxes = parse_visual_prompt(visual_prompt_data
+    if visual_prompt_data is not None:
+        visual_prompt_boxes = parse_visual_prompt(visual_prompt_data)
 
     # Generate prompt preview
     prompt = get_task_prompt(
@@ -697,7 +713,7 @@ def create_demo():
         with visual_prompt_tab:
             gr.Markdown("### 🎯 Visual Prompt Configuration")
             gr.Markdown(
-                "
+                "Select the pen tool and draw one or multiple boxes on the image. "
             )
 
             # Prompt Preview
@@ -735,10 +751,9 @@ def create_demo():
             )
 
             # Visual Prompt Interface (only visible for Visual Prompting task)
-            visual_prompter =
+            visual_prompter = BBoxAnnotator(
                 label="🎯 Visual Prompt Interface",
-
-                height=315,  # 4:3 aspect ratio (420 * 3/4 = 315)
+                categories="D",
                 visible=False,
                 elem_classes=["preserve-aspect-ratio"],
             )
@@ -857,15 +872,30 @@ def create_demo():
         show_labels,
         custom_color,
     ):
-        # For Visual Prompting task,
+        # For Visual Prompting task, extract image from BBoxAnnotator data
        if task_selection == "Visual Prompting":
-            if
-
-
+            if (
+                visual_prompter_data is None
+                or not isinstance(visual_prompter_data, tuple)
+                or len(visual_prompter_data) < 1
+            ):
                 return (
                     None,
-                    "Please upload an image in the Visual Prompt Interface for Visual Prompting task.",
+                    "Please upload an image and draw bounding boxes in the Visual Prompt Interface for Visual Prompting task.",
                 )
+            # Extract image from BBoxAnnotator data (first element of the tuple)
+            image_to_use = visual_prompter_data[0]
+            # If image_to_use is a string (file path), convert to PIL Image
+            if isinstance(image_to_use, str):
+                try:
+                    from PIL import Image
+
+                    image_to_use = Image.open(image_to_use).convert("RGB")
+                except Exception as e:
+                    return (
+                        None,
+                        f"Error loading image from path: {e}",
+                    )
         else:
             image_to_use = input_image
 
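For reference, here is a minimal standalone sketch of the parsing behaviour this commit adds. It assumes the annotator value reaches the callback as an (image, boxes) tuple, as the comments inside the new parse_visual_prompt describe; the file name and box values below are made-up examples, not data from the Space.

# Sketch of the parsing path introduced in this commit, assuming the
# (image, boxes) tuple format described in parse_visual_prompt's comments.
from typing import Any, List


def parse_visual_prompt(bbox_data: Any) -> List[List[float]]:
    """Mirror of the new helper: normalize annotator output to four-number boxes."""
    if bbox_data is None:
        return []
    # The annotator value is assumed to be (image, boxes_list); fall back to a bare list.
    boxes_list = bbox_data[1] if isinstance(bbox_data, tuple) and len(bbox_data) >= 2 else bbox_data
    if not boxes_list:
        return []
    # Keep the first four numbers of each entry; extra fields (e.g. a category label) are dropped.
    return [list(box[:4]) for box in boxes_list if len(box) >= 4]


if __name__ == "__main__":
    # Hypothetical annotator value: an image path plus two drawn boxes with a category label.
    sample = ("example.jpg", [[34, 50, 220, 310, "D"], [400, 120, 560, 260, "D"]])
    print(parse_visual_prompt(sample))  # [[34, 50, 220, 310], [400, 120, 560, 260]]
    print(parse_visual_prompt(None))    # []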
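And a rough sketch of how the fixed wiring could look outside the full demo: a BBoxAnnotator used as an input component, with a handler that unpacks the (image, boxes) tuple and opens the image when it arrives as a file path, mirroring the updated Visual Prompting branch. Only the label constructor argument is taken from the commit; the handler, component names, and the tuple format are assumptions based on the commit's own comments, not the component's documented API.

# Minimal sketch, assuming gradio_bbox_annotator is installed and its value
# arrives in the callback as the (image, boxes) tuple this commit expects.
import gradio as gr
from PIL import Image
from gradio_bbox_annotator import BBoxAnnotator


def describe_prompt(annotator_value):
    # Guard clause mirroring the fixed handler: require the annotator tuple.
    if annotator_value is None or not isinstance(annotator_value, tuple) or len(annotator_value) < 1:
        return "Please upload an image and draw at least one box."
    image = annotator_value[0]
    boxes = annotator_value[1] if len(annotator_value) > 1 else []
    # The image may come back as a file path; open it the same way the commit does.
    if isinstance(image, str):
        image = Image.open(image).convert("RGB")
    return f"Got an image of type {type(image).__name__} with {len(boxes or [])} box(es)."


with gr.Blocks() as demo:
    annotator = BBoxAnnotator(label="🎯 Visual Prompt Interface")
    run_btn = gr.Button("Describe")
    out = gr.Textbox(label="Result")
    run_btn.click(describe_prompt, inputs=[annotator], outputs=[out])

if __name__ == "__main__":
    demo.launch()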