Mountchicken committed on
Commit
e2e6048
·
verified ·
1 Parent(s): 60f587b

Update app.py

Browse files

Fix Visual Prompting bug

Files changed (1) hide show
  1. app.py +70 -40
app.py CHANGED
@@ -1,26 +1,32 @@
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
3
- import spaces
4
  import argparse
5
  import json
6
  import os
7
 
 
8
 
9
- os.system("pip install torch==2.4.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu124")
10
-
 
 
11
  import subprocess
12
- subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13
 
 
 
 
 
 
 
 
14
  import sys
15
  import threading
16
- import re
17
  from typing import Any, Dict, List
18
 
19
  import gradio as gr
20
  import numpy as np
21
- from gradio_image_prompter import ImagePrompter
22
  from PIL import Image
23
-
24
  from rex_omni import RexOmniVisualize, RexOmniWrapper, TaskType
25
  from rex_omni.tasks import KEYPOINT_CONFIGS, TASK_CONFIGS, get_task_config
26
 
@@ -234,22 +240,33 @@ EXAMPLE_CONFIGS = [
234
  ]
235
 
236
 
237
def parse_visual_prompt(points: List) -> List[List[float]]:
    """Convert raw image-prompter points into [x1, y1, x2, y2] boxes.

    Each entry is a 6-tuple whose 3rd and last elements are marker codes:
    (2, 3) marks a drawn rectangle, (1, 4) marks a positive click.
    Clicks are expanded into a small box centered on the click, clamped
    to the image origin.
    """
    half_width = 10  # half-size of the box synthesized around a click
    boxes: List[List[float]] = []
    for entry in points:
        if entry[2] == 2 and entry[-1] == 3:
            # Rectangle: corners are stored directly in the entry.
            left, top, _, right, bottom, _ = entry
            boxes.append([left, top, right, bottom])
        elif entry[2] == 1 and entry[-1] == 4:
            # Positive click: grow a 2*half_width square around (cx, cy).
            cx, cy, _, _, _, _ = entry
            boxes.append(
                [
                    max(0, cx - half_width),
                    max(0, cy - half_width),
                    cx + half_width,
                    cy + half_width,
                ]
            )
    return boxes
 
 
 
 
 
 
 
 
 
 
 
253
 
254
 
255
  def convert_boxes_to_visual_prompt_format(
@@ -344,6 +361,7 @@ def get_task_prompt(
344
  else:
345
  return task_config.prompt_template.replace("{categories}", "objects")
346
 
 
347
  @spaces.GPU
348
  def run_inference(
349
  image,
@@ -362,7 +380,6 @@ def run_inference(
362
  if image is None:
363
  return None, "Please upload an image first."
364
 
365
-
366
  # Convert numpy array to PIL Image if needed
367
  if isinstance(image, np.ndarray):
368
  image = Image.fromarray(image)
@@ -375,8 +392,8 @@ def run_inference(
375
  # Check if we have predefined visual prompt boxes from examples
376
  if hasattr(image, "_example_visual_prompts"):
377
  visual_prompt_boxes = image._example_visual_prompts
378
- elif visual_prompt_data is not None and "points" in visual_prompt_data:
379
- visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
380
 
381
  # Determine task type and categories based on task selection
382
  if task_selection == "OCR":
@@ -406,9 +423,7 @@ def run_inference(
406
  task_key = task_type.value
407
 
408
  # Split categories by comma and clean up
409
- categories_list = [
410
- cat.strip() for cat in categories.split(",") if cat.strip()
411
- ]
412
  if not categories_list:
413
  categories_list = ["object"]
414
 
@@ -456,6 +471,7 @@ def run_inference(
456
  except Exception as e:
457
  return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
458
 
 
459
  def update_interface(task_selection):
460
  """Update interface based on task selection"""
461
  config = DEMO_TASK_CONFIGS.get(task_selection, {})
@@ -580,8 +596,8 @@ def update_prompt_preview(
580
 
581
  # Parse visual prompts
582
  visual_prompt_boxes = []
583
- if "points" in visual_prompt_data:
584
- visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
585
 
586
  # Generate prompt preview
587
  prompt = get_task_prompt(
@@ -697,7 +713,7 @@ def create_demo():
697
  with visual_prompt_tab:
698
  gr.Markdown("### 🎯 Visual Prompt Configuration")
699
  gr.Markdown(
700
- "Draw bounding boxes on the image to provide visual examples"
701
  )
702
 
703
  # Prompt Preview
@@ -735,10 +751,9 @@ def create_demo():
735
  )
736
 
737
  # Visual Prompt Interface (only visible for Visual Prompting task)
738
- visual_prompter = ImagePrompter(
739
  label="🎯 Visual Prompt Interface",
740
- width=420,
741
- height=315, # 4:3 aspect ratio (420 * 3/4 = 315)
742
  visible=False,
743
  elem_classes=["preserve-aspect-ratio"],
744
  )
@@ -857,15 +872,30 @@ def create_demo():
857
  show_labels,
858
  custom_color,
859
  ):
860
- # For Visual Prompting task, use the visual prompter image
861
  if task_selection == "Visual Prompting":
862
- if visual_prompter_data is not None and "image" in visual_prompter_data:
863
- image_to_use = visual_prompter_data["image"]
864
- else:
 
 
865
  return (
866
  None,
867
- "Please upload an image in the Visual Prompt Interface for Visual Prompting task.",
868
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
869
  else:
870
  image_to_use = input_image
871
 
 
1
  #!/usr/bin/env python
2
  # -*- coding: utf-8 -*-
 
3
  import argparse
4
  import json
5
  import os
6
 
7
+ import spaces
8
 
9
+ os.system(
10
+ "pip install torch==2.4.0 torchvision==0.18.0 --index-url https://download.pytorch.org/whl/cu124"
11
+ )
12
+ os.system("pip install gradio_bbox_annotator")
13
  import subprocess
 
14
 
15
+ subprocess.run(
16
+ "pip install flash-attn==2.7.4.post1 --no-build-isolation",
17
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
18
+ shell=True,
19
+ )
20
+
21
+ import re
22
  import sys
23
  import threading
 
24
  from typing import Any, Dict, List
25
 
26
  import gradio as gr
27
  import numpy as np
28
+ from gradio_bbox_annotator import BBoxAnnotator
29
  from PIL import Image
 
30
  from rex_omni import RexOmniVisualize, RexOmniWrapper, TaskType
31
  from rex_omni.tasks import KEYPOINT_CONFIGS, TASK_CONFIGS, get_task_config
32
 
 
240
  ]
241
 
242
 
243
def parse_visual_prompt(bbox_data) -> List[List[float]]:
    """Parse BBoxAnnotator output into [x1, y1, x2, y2] bounding boxes.

    Args:
        bbox_data: Either the raw ``(image, boxes_list)`` tuple emitted by
            ``BBoxAnnotator`` or a bare list of boxes. Each box is expected
            to start with four corner coordinates ``[x1, y1, x2, y2]``;
            any trailing fields (e.g. the category label) are ignored.
            NOTE(review): boxes are taken as corner coordinates as-is —
            if the widget ever emits ``[x, y, width, height]`` this would
            need an explicit conversion; confirm against the component.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes; empty list when the input is
        missing, empty, or unparsable.
    """
    if bbox_data is None:
        return []

    try:
        # BBoxAnnotator emits (image, boxes_list); also accept a bare list.
        if isinstance(bbox_data, tuple) and len(bbox_data) >= 2:
            boxes_list = bbox_data[1]
        else:
            boxes_list = bbox_data

        if not boxes_list:
            return []

        boxes = []
        for box in boxes_list:
            # Keep only the first four fields (corner coordinates);
            # skip malformed boxes that are too short.
            if len(box) >= 4:
                x1, y1, x2, y2 = box[:4]
                boxes.append([x1, y1, x2, y2])

        return boxes
    except Exception as e:
        # Best-effort parsing: a malformed widget payload should not crash
        # the demo, so log and fall back to "no visual prompts".
        print(f"Error parsing visual prompt: {e}")
        return []
270
 
271
 
272
  def convert_boxes_to_visual_prompt_format(
 
361
  else:
362
  return task_config.prompt_template.replace("{categories}", "objects")
363
 
364
+
365
  @spaces.GPU
366
  def run_inference(
367
  image,
 
380
  if image is None:
381
  return None, "Please upload an image first."
382
 
 
383
  # Convert numpy array to PIL Image if needed
384
  if isinstance(image, np.ndarray):
385
  image = Image.fromarray(image)
 
392
  # Check if we have predefined visual prompt boxes from examples
393
  if hasattr(image, "_example_visual_prompts"):
394
  visual_prompt_boxes = image._example_visual_prompts
395
+ elif visual_prompt_data is not None:
396
+ visual_prompt_boxes = parse_visual_prompt(visual_prompt_data)
397
 
398
  # Determine task type and categories based on task selection
399
  if task_selection == "OCR":
 
423
  task_key = task_type.value
424
 
425
  # Split categories by comma and clean up
426
+ categories_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
 
 
427
  if not categories_list:
428
  categories_list = ["object"]
429
 
 
471
  except Exception as e:
472
  return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
473
 
474
+
475
  def update_interface(task_selection):
476
  """Update interface based on task selection"""
477
  config = DEMO_TASK_CONFIGS.get(task_selection, {})
 
596
 
597
  # Parse visual prompts
598
  visual_prompt_boxes = []
599
+ if visual_prompt_data is not None:
600
+ visual_prompt_boxes = parse_visual_prompt(visual_prompt_data)
601
 
602
  # Generate prompt preview
603
  prompt = get_task_prompt(
 
713
  with visual_prompt_tab:
714
  gr.Markdown("### 🎯 Visual Prompt Configuration")
715
  gr.Markdown(
716
+ "Select the pen tool and draw one or multiple boxes on the image. "
717
  )
718
 
719
  # Prompt Preview
 
751
  )
752
 
753
  # Visual Prompt Interface (only visible for Visual Prompting task)
754
+ visual_prompter = BBoxAnnotator(
755
  label="🎯 Visual Prompt Interface",
756
+ categories="D",
 
757
  visible=False,
758
  elem_classes=["preserve-aspect-ratio"],
759
  )
 
872
  show_labels,
873
  custom_color,
874
  ):
875
+ # For Visual Prompting task, extract image from BBoxAnnotator data
876
  if task_selection == "Visual Prompting":
877
+ if (
878
+ visual_prompter_data is None
879
+ or not isinstance(visual_prompter_data, tuple)
880
+ or len(visual_prompter_data) < 1
881
+ ):
882
  return (
883
  None,
884
+ "Please upload an image and draw bounding boxes in the Visual Prompt Interface for Visual Prompting task.",
885
  )
886
+ # Extract image from BBoxAnnotator data (first element of the tuple)
887
+ image_to_use = visual_prompter_data[0]
888
+ # If image_to_use is a string (file path), convert to PIL Image
889
+ if isinstance(image_to_use, str):
890
+ try:
891
+ from PIL import Image
892
+
893
+ image_to_use = Image.open(image_to_use).convert("RGB")
894
+ except Exception as e:
895
+ return (
896
+ None,
897
+ f"Error loading image from path: {e}",
898
+ )
899
  else:
900
  image_to_use = input_image
901