Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -362,103 +362,99 @@ def run_inference(
|
|
| 362 |
if image is None:
|
| 363 |
return None, "Please upload an image first."
|
| 364 |
|
| 365 |
-
try:
|
| 366 |
-
# Convert numpy array to PIL Image if needed
|
| 367 |
-
if isinstance(image, np.ndarray):
|
| 368 |
-
image = Image.fromarray(image)
|
| 369 |
-
|
| 370 |
-
image_width, image_height = image.size
|
| 371 |
-
|
| 372 |
-
# Parse visual prompts if needed
|
| 373 |
-
visual_prompt_boxes = []
|
| 374 |
-
if task_selection == "Visual Prompting":
|
| 375 |
-
# Check if we have predefined visual prompt boxes from examples
|
| 376 |
-
if hasattr(image, "_example_visual_prompts"):
|
| 377 |
-
visual_prompt_boxes = image._example_visual_prompts
|
| 378 |
-
elif visual_prompt_data is not None and "points" in visual_prompt_data:
|
| 379 |
-
visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
|
| 380 |
-
|
| 381 |
-
# Determine task type and categories based on task selection
|
| 382 |
-
if task_selection == "OCR":
|
| 383 |
-
# For OCR, use the selected output format to determine task type
|
| 384 |
-
task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
|
| 385 |
-
task_key = task_type.value
|
| 386 |
-
# Use granularity level to determine categories
|
| 387 |
-
categories_list = [OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]]
|
| 388 |
-
elif task_selection == "Visual Prompting":
|
| 389 |
-
# For visual prompting, we don't need explicit categories
|
| 390 |
-
task_key = "visual_prompting"
|
| 391 |
-
categories_list = ["object"]
|
| 392 |
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
None,
|
| 397 |
-
"Please draw bounding boxes on the image to provide visual examples for Visual Prompting task.",
|
| 398 |
-
)
|
| 399 |
-
elif task_selection == "Keypoint":
|
| 400 |
-
task_key = "keypoint"
|
| 401 |
-
categories_list = [keypoint_type] if keypoint_type else ["person"]
|
| 402 |
-
else:
|
| 403 |
-
# For other tasks, get task type from demo config
|
| 404 |
-
demo_config = DEMO_TASK_CONFIGS[task_selection]
|
| 405 |
-
task_type = demo_config["task_type"]
|
| 406 |
-
task_key = task_type.value
|
| 407 |
-
|
| 408 |
-
# Split categories by comma and clean up
|
| 409 |
-
categories_list = [
|
| 410 |
-
cat.strip() for cat in categories.split(",") if cat.strip()
|
| 411 |
-
]
|
| 412 |
-
if not categories_list:
|
| 413 |
-
categories_list = ["object"]
|
| 414 |
-
|
| 415 |
-
# Run inference
|
| 416 |
-
if task_selection == "Visual Prompting":
|
| 417 |
-
results = rex_model.inference(
|
| 418 |
-
images=image,
|
| 419 |
-
task=task_key,
|
| 420 |
-
categories=categories_list,
|
| 421 |
-
visual_prompt_boxes=visual_prompt_boxes,
|
| 422 |
-
)
|
| 423 |
-
elif task_selection == "Keypoint":
|
| 424 |
-
results = rex_model.inference(
|
| 425 |
-
images=image,
|
| 426 |
-
task=task_key,
|
| 427 |
-
categories=categories_list,
|
| 428 |
-
keypoint_type=keypoint_type if keypoint_type else "person",
|
| 429 |
-
)
|
| 430 |
-
else:
|
| 431 |
-
results = rex_model.inference(
|
| 432 |
-
images=image, task=task_key, categories=categories_list
|
| 433 |
-
)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
)
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
|
| 463 |
def update_interface(task_selection):
|
| 464 |
"""Update interface based on task selection"""
|
|
|
|
| 362 |
if image is None:
|
| 363 |
return None, "Please upload an image first."
|
| 364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
| 366 |
+
# Convert numpy array to PIL Image if needed
|
| 367 |
+
if isinstance(image, np.ndarray):
|
| 368 |
+
image = Image.fromarray(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
+
image_width, image_height = image.size
|
| 371 |
+
|
| 372 |
+
# Parse visual prompts if needed
|
| 373 |
+
visual_prompt_boxes = []
|
| 374 |
+
if task_selection == "Visual Prompting":
|
| 375 |
+
# Check if we have predefined visual prompt boxes from examples
|
| 376 |
+
if hasattr(image, "_example_visual_prompts"):
|
| 377 |
+
visual_prompt_boxes = image._example_visual_prompts
|
| 378 |
+
elif visual_prompt_data is not None and "points" in visual_prompt_data:
|
| 379 |
+
visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
|
| 380 |
+
|
| 381 |
+
# Determine task type and categories based on task selection
|
| 382 |
+
if task_selection == "OCR":
|
| 383 |
+
# For OCR, use the selected output format to determine task type
|
| 384 |
+
task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
|
| 385 |
+
task_key = task_type.value
|
| 386 |
+
# Use granularity level to determine categories
|
| 387 |
+
categories_list = [OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]]
|
| 388 |
+
elif task_selection == "Visual Prompting":
|
| 389 |
+
# For visual prompting, we don't need explicit categories
|
| 390 |
+
task_key = "visual_prompting"
|
| 391 |
+
categories_list = ["object"]
|
| 392 |
+
|
| 393 |
+
# Check if visual prompt boxes are provided
|
| 394 |
+
if not visual_prompt_boxes:
|
| 395 |
+
return (
|
| 396 |
+
None,
|
| 397 |
+
"Please draw bounding boxes on the image to provide visual examples for Visual Prompting task.",
|
| 398 |
)
|
| 399 |
+
elif task_selection == "Keypoint":
|
| 400 |
+
task_key = "keypoint"
|
| 401 |
+
categories_list = [keypoint_type] if keypoint_type else ["person"]
|
| 402 |
+
else:
|
| 403 |
+
# For other tasks, get task type from demo config
|
| 404 |
+
demo_config = DEMO_TASK_CONFIGS[task_selection]
|
| 405 |
+
task_type = demo_config["task_type"]
|
| 406 |
+
task_key = task_type.value
|
| 407 |
|
| 408 |
+
# Split categories by comma and clean up
|
| 409 |
+
categories_list = [
|
| 410 |
+
cat.strip() for cat in categories.split(",") if cat.strip()
|
| 411 |
+
]
|
| 412 |
+
if not categories_list:
|
| 413 |
+
categories_list = ["object"]
|
| 414 |
+
|
| 415 |
+
# Run inference
|
| 416 |
+
if task_selection == "Visual Prompting":
|
| 417 |
+
results = rex_model.inference(
|
| 418 |
+
images=image,
|
| 419 |
+
task=task_key,
|
| 420 |
+
categories=categories_list,
|
| 421 |
+
visual_prompt_boxes=visual_prompt_boxes,
|
| 422 |
+
)
|
| 423 |
+
elif task_selection == "Keypoint":
|
| 424 |
+
results = rex_model.inference(
|
| 425 |
+
images=image,
|
| 426 |
+
task=task_key,
|
| 427 |
+
categories=categories_list,
|
| 428 |
+
keypoint_type=keypoint_type if keypoint_type else "person",
|
| 429 |
+
)
|
| 430 |
+
else:
|
| 431 |
+
results = rex_model.inference(
|
| 432 |
+
images=image, task=task_key, categories=categories_list
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
result = results[0]
|
| 436 |
+
|
| 437 |
+
# Check if inference was successful
|
| 438 |
+
if not result.get("success", False):
|
| 439 |
+
error_msg = result.get("error", "Unknown error occurred during inference")
|
| 440 |
+
return None, f"Inference failed: {error_msg}"
|
| 441 |
+
|
| 442 |
+
# Get predictions and raw output
|
| 443 |
+
predictions = result["extracted_predictions"]
|
| 444 |
+
raw_output = result["raw_output"]
|
| 445 |
|
| 446 |
+
# Create visualization
|
| 447 |
+
try:
|
| 448 |
+
vis_image = RexOmniVisualize(
|
| 449 |
+
image=image,
|
| 450 |
+
predictions=predictions,
|
| 451 |
+
font_size=font_size,
|
| 452 |
+
draw_width=draw_width,
|
| 453 |
+
show_labels=show_labels,
|
| 454 |
+
)
|
| 455 |
+
return vis_image, raw_output
|
| 456 |
+
except Exception as e:
|
| 457 |
+
return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
|
| 458 |
|
| 459 |
def update_interface(task_selection):
|
| 460 |
"""Update interface based on task selection"""
|