Mountchicken committed on
Commit
a8932c9
·
verified ·
1 Parent(s): 1bd75ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -92
app.py CHANGED
@@ -362,103 +362,99 @@ def run_inference(
362
  if image is None:
363
  return None, "Please upload an image first."
364
 
365
- try:
366
- # Convert numpy array to PIL Image if needed
367
- if isinstance(image, np.ndarray):
368
- image = Image.fromarray(image)
369
-
370
- image_width, image_height = image.size
371
-
372
- # Parse visual prompts if needed
373
- visual_prompt_boxes = []
374
- if task_selection == "Visual Prompting":
375
- # Check if we have predefined visual prompt boxes from examples
376
- if hasattr(image, "_example_visual_prompts"):
377
- visual_prompt_boxes = image._example_visual_prompts
378
- elif visual_prompt_data is not None and "points" in visual_prompt_data:
379
- visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
380
-
381
- # Determine task type and categories based on task selection
382
- if task_selection == "OCR":
383
- # For OCR, use the selected output format to determine task type
384
- task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
385
- task_key = task_type.value
386
- # Use granularity level to determine categories
387
- categories_list = [OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]]
388
- elif task_selection == "Visual Prompting":
389
- # For visual prompting, we don't need explicit categories
390
- task_key = "visual_prompting"
391
- categories_list = ["object"]
392
 
393
- # Check if visual prompt boxes are provided
394
- if not visual_prompt_boxes:
395
- return (
396
- None,
397
- "Please draw bounding boxes on the image to provide visual examples for Visual Prompting task.",
398
- )
399
- elif task_selection == "Keypoint":
400
- task_key = "keypoint"
401
- categories_list = [keypoint_type] if keypoint_type else ["person"]
402
- else:
403
- # For other tasks, get task type from demo config
404
- demo_config = DEMO_TASK_CONFIGS[task_selection]
405
- task_type = demo_config["task_type"]
406
- task_key = task_type.value
407
-
408
- # Split categories by comma and clean up
409
- categories_list = [
410
- cat.strip() for cat in categories.split(",") if cat.strip()
411
- ]
412
- if not categories_list:
413
- categories_list = ["object"]
414
-
415
- # Run inference
416
- if task_selection == "Visual Prompting":
417
- results = rex_model.inference(
418
- images=image,
419
- task=task_key,
420
- categories=categories_list,
421
- visual_prompt_boxes=visual_prompt_boxes,
422
- )
423
- elif task_selection == "Keypoint":
424
- results = rex_model.inference(
425
- images=image,
426
- task=task_key,
427
- categories=categories_list,
428
- keypoint_type=keypoint_type if keypoint_type else "person",
429
- )
430
- else:
431
- results = rex_model.inference(
432
- images=image, task=task_key, categories=categories_list
433
- )
434
 
435
- result = results[0]
436
-
437
- # Check if inference was successful
438
- if not result.get("success", False):
439
- error_msg = result.get("error", "Unknown error occurred during inference")
440
- return None, f"Inference failed: {error_msg}"
441
-
442
- # Get predictions and raw output
443
- predictions = result["extracted_predictions"]
444
- raw_output = result["raw_output"]
445
-
446
- # Create visualization
447
- try:
448
- vis_image = RexOmniVisualize(
449
- image=image,
450
- predictions=predictions,
451
- font_size=font_size,
452
- draw_width=draw_width,
453
- show_labels=show_labels,
 
 
 
 
 
 
 
 
 
454
  )
455
- return vis_image, raw_output
456
- except Exception as e:
457
- return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
 
 
 
 
 
458
 
459
- except Exception as e:
460
- return None, f"Error during inference: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
 
 
 
 
 
 
 
 
 
 
 
 
462
 
463
  def update_interface(task_selection):
464
  """Update interface based on task selection"""
 
362
  if image is None:
363
  return None, "Please upload an image first."
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
+ # Convert numpy array to PIL Image if needed
367
+ if isinstance(image, np.ndarray):
368
+ image = Image.fromarray(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
370
+ image_width, image_height = image.size
371
+
372
+ # Parse visual prompts if needed
373
+ visual_prompt_boxes = []
374
+ if task_selection == "Visual Prompting":
375
+ # Check if we have predefined visual prompt boxes from examples
376
+ if hasattr(image, "_example_visual_prompts"):
377
+ visual_prompt_boxes = image._example_visual_prompts
378
+ elif visual_prompt_data is not None and "points" in visual_prompt_data:
379
+ visual_prompt_boxes = parse_visual_prompt(visual_prompt_data["points"])
380
+
381
+ # Determine task type and categories based on task selection
382
+ if task_selection == "OCR":
383
+ # For OCR, use the selected output format to determine task type
384
+ task_type = OCR_OUTPUT_FORMATS[ocr_output_format]["task_type"]
385
+ task_key = task_type.value
386
+ # Use granularity level to determine categories
387
+ categories_list = [OCR_GRANULARITY_LEVELS[ocr_granularity]["categories"]]
388
+ elif task_selection == "Visual Prompting":
389
+ # For visual prompting, we don't need explicit categories
390
+ task_key = "visual_prompting"
391
+ categories_list = ["object"]
392
+
393
+ # Check if visual prompt boxes are provided
394
+ if not visual_prompt_boxes:
395
+ return (
396
+ None,
397
+ "Please draw bounding boxes on the image to provide visual examples for Visual Prompting task.",
398
  )
399
+ elif task_selection == "Keypoint":
400
+ task_key = "keypoint"
401
+ categories_list = [keypoint_type] if keypoint_type else ["person"]
402
+ else:
403
+ # For other tasks, get task type from demo config
404
+ demo_config = DEMO_TASK_CONFIGS[task_selection]
405
+ task_type = demo_config["task_type"]
406
+ task_key = task_type.value
407
 
408
+ # Split categories by comma and clean up
409
+ categories_list = [
410
+ cat.strip() for cat in categories.split(",") if cat.strip()
411
+ ]
412
+ if not categories_list:
413
+ categories_list = ["object"]
414
+
415
+ # Run inference
416
+ if task_selection == "Visual Prompting":
417
+ results = rex_model.inference(
418
+ images=image,
419
+ task=task_key,
420
+ categories=categories_list,
421
+ visual_prompt_boxes=visual_prompt_boxes,
422
+ )
423
+ elif task_selection == "Keypoint":
424
+ results = rex_model.inference(
425
+ images=image,
426
+ task=task_key,
427
+ categories=categories_list,
428
+ keypoint_type=keypoint_type if keypoint_type else "person",
429
+ )
430
+ else:
431
+ results = rex_model.inference(
432
+ images=image, task=task_key, categories=categories_list
433
+ )
434
+
435
+ result = results[0]
436
+
437
+ # Check if inference was successful
438
+ if not result.get("success", False):
439
+ error_msg = result.get("error", "Unknown error occurred during inference")
440
+ return None, f"Inference failed: {error_msg}"
441
+
442
+ # Get predictions and raw output
443
+ predictions = result["extracted_predictions"]
444
+ raw_output = result["raw_output"]
445
 
446
+ # Create visualization
447
+ try:
448
+ vis_image = RexOmniVisualize(
449
+ image=image,
450
+ predictions=predictions,
451
+ font_size=font_size,
452
+ draw_width=draw_width,
453
+ show_labels=show_labels,
454
+ )
455
+ return vis_image, raw_output
456
+ except Exception as e:
457
+ return image, f"Visualization failed: {str(e)}\n\nRaw output:\n{raw_output}"
458
 
459
  def update_interface(task_selection):
460
  """Update interface based on task selection"""