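"""Gradio demo for BBoxMaskPose (ICCV 2025): detection, pose estimation, and
SAM-based mask refinement run in a loop, deployed as a Hugging Face Space."""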
import gradio as gr
import numpy as np
import spaces
import yaml
from pathlib import Path

from mmdet.apis import init_detector
from mmpose.apis import init_model as init_pose_estimator
from mmpose.utils import adapt_mmdet_pipeline

from demo.demo_utils import DotDict, concat_instances, filter_instances, pose_nms, visualize_demo
from demo.mm_utils import run_MMDetector, run_MMPose
from demo.sam2_utils import prepare_model as prepare_sam2_model
from demo.sam2_utils import process_image_with_SAM
# Default thresholds
DEFAULT_CAT_ID: int = 0
DEFAULT_BBOX_THR: float = 0.3
DEFAULT_NMS_THR: float = 0.3
DEFAULT_KPT_THR: float = 0.3

# Global model handles
det_model = None
pose_model = None
sam2_model = None
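# NOTE: process_image_with_BMP() calls load_models() on every request, so these
# globals are rebuilt per call rather than cached at import time.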
def _parse_yaml_config(yaml_path: Path) -> DotDict:
    """
    Load BMP configuration from a YAML file.

    Args:
        yaml_path (Path): Path to YAML config.

    Returns:
        DotDict: Nested config dictionary.
    """
    with open(yaml_path, "r") as f:
        cfg = yaml.safe_load(f)
    return DotDict(cfg)
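# Expected shape of configs/bmp_D3.yaml, inferred from the attribute accesses in
# this file (paths and values elided):
#
#   num_bmp_iters: ...
#   oks_nms_thr: ...
#   detector:
#     det_config: ...
#     det_checkpoint: ...
#   pose_estimator:
#     pose_config: ...
#     pose_checkpoint: ...
#   sam2:
#     sam2_config: ...
#     sam2_checkpoint: ...
#     prompting:
#       confidence_thr: ...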
def load_models(bmp_config: DotDict):
    """Build the detector, pose estimator, and SAM 2 model from the BMP config."""
    global det_model, pose_model, sam2_model
    device = "cuda:0"

    # Build detector. Runs on CPU because of installation issues on HF.
    det_model = init_detector(bmp_config.detector.det_config, bmp_config.detector.det_checkpoint, device="cpu")
    det_model.cfg = adapt_mmdet_pipeline(det_model.cfg)

    # Build pose estimator
    pose_model = init_pose_estimator(
        bmp_config.pose_estimator.pose_config,
        bmp_config.pose_estimator.pose_checkpoint,
        device=device,
        cfg_options=dict(model=dict(test_cfg=dict(output_heatmaps=False))),
    )

    # Build SAM 2 model for mask refinement
    sam2_model = prepare_sam2_model(
        model_cfg=bmp_config.sam2.sam2_config,
        model_checkpoint=bmp_config.sam2.sam2_checkpoint,
    )

    return det_model, pose_model, sam2_model
@spaces.GPU  # Request a ZeroGPU slot for this call (the Space runs on Zero hardware)
def process_image_with_BMP(
    img: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run the full BMP pipeline on a single image: detection, pose, SAM mask refinement, and visualization.

    Args:
        img (np.ndarray): Input RGB image of shape (H, W, 3).

    Returns:
        tuple[np.ndarray, np.ndarray]: RGB visualizations of the first detection
        pass (RTMDet-L + MaskPose-B) and of the final BBoxMaskPose result.
    """
    bmp_config = _parse_yaml_config(Path("configs/bmp_D3.yaml"))
    load_models(bmp_config)

    # img: RGB -> BGR (the MMDetection/MMPose models expect OpenCV-style BGR)
    img = img[..., ::-1]
    img_for_detection = img.copy()

    rtmdet_result = None
    all_detections = None
    for iteration in range(bmp_config.num_bmp_iters):
        # Step 1: Detection
        det_instances = run_MMDetector(
            det_model,
            img_for_detection,
            det_cat_id=DEFAULT_CAT_ID,
            bbox_thr=DEFAULT_BBOX_THR,
            nms_thr=DEFAULT_NMS_THR,
        )
        if len(det_instances.bboxes) == 0:
            continue

        # Step 2: Pose estimation
        pose_instances = run_MMPose(
            pose_model,
            img.copy(),
            detections=det_instances,
            kpt_thr=DEFAULT_KPT_THR,
        )

        # Restrict to the first 17 (COCO) keypoints and append each keypoint's
        # score as a third channel: (N, 17, 2) + (N, 17, 1) -> (N, 17, 3)
        pose_instances.keypoints = pose_instances.keypoints[:, :17, :]
        pose_instances.keypoint_scores = pose_instances.keypoint_scores[:, :17]
        pose_instances.keypoints = np.concatenate(
            [pose_instances.keypoints, pose_instances.keypoint_scores[:, :, None]], axis=-1
        )
        # Step 3: Pose-NMS and SAM refinement
        all_keypoints = (
            pose_instances.keypoints
            if all_detections is None
            else np.concatenate([all_detections.keypoints, pose_instances.keypoints], axis=0)
        )
        all_bboxes = (
            pose_instances.bboxes
            if all_detections is None
            else np.concatenate([all_detections.bboxes, pose_instances.bboxes], axis=0)
        )
        num_valid_kpts = np.sum(all_keypoints[:, :, 2] > bmp_config.sam2.prompting.confidence_thr, axis=1)
        keep_indices = pose_nms(
            DotDict({"confidence_thr": bmp_config.sam2.prompting.confidence_thr, "oks_thr": bmp_config.oks_nms_thr}),
            image_kpts=all_keypoints,
            image_bboxes=all_bboxes,
            num_valid_kpts=num_valid_kpts,
        )
        keep_indices = sorted(keep_indices)  # Sort by original index

        # Split kept indices into detections from previous iterations ("old")
        # and detections found in this iteration ("new")
        num_old_detections = 0 if all_detections is None else len(all_detections.bboxes)
        keep_new_indices = [i - num_old_detections for i in keep_indices if i >= num_old_detections]
        keep_old_indices = [i for i in keep_indices if i < num_old_detections]
        if len(keep_new_indices) == 0:
            continue
        # Filter new detections and compute per-instance scores
        new_dets = filter_instances(pose_instances, keep_new_indices)
        new_dets.scores = pose_instances.keypoint_scores[keep_new_indices].mean(axis=-1)
        old_dets = None
        if len(keep_old_indices) > 0:
            old_dets = filter_instances(all_detections, keep_old_indices)

        # Refine masks of the new detections with SAM 2
        new_detections = process_image_with_SAM(
            DotDict(bmp_config.sam2.prompting),
            img.copy(),
            sam2_model,
            new_dets,
            old_dets,
        )

        # Merge the SAM-refined detections with those from previous iterations
        if all_detections is None:
            all_detections = new_detections
        else:
            all_detections = concat_instances(all_detections, new_detections)
        # Step 4: Visualization of detections accumulated so far
        img_for_detection, rtmdet_r, _ = visualize_demo(
            img.copy(),
            all_detections,
        )
        if rtmdet_result is None:
            # Keep the visualization of the first pass that found detections
            rtmdet_result = rtmdet_r

    if all_detections is None:
        # No people detected in any iteration: return the input unchanged
        original = img[..., ::-1]
        return original, original

    _, _, bmp_result = visualize_demo(
        img.copy(),
        all_detections,
    )

    # results: BGR -> RGB for display
    rtmdet_result = rtmdet_result[..., ::-1]
    bmp_result = bmp_result[..., ::-1]
    return rtmdet_result, bmp_result
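# Minimal local usage sketch (assumes OpenCV is installed; Gradio supplies RGB
# arrays via type="numpy", so the channel flip below mimics that):
#
#   import cv2
#   rgb = cv2.imread("examples/tackle1.jpg")[..., ::-1]  # BGR -> RGB
#   first_pass, final = process_image_with_BMP(rgb)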
with gr.Blocks() as app:
    gr.Markdown("# BBoxMaskPose Image Demo")
    gr.Markdown("### [M. Purkrabek](https://mirapurkrabek.github.io/), [J. Matas](https://cmp.felk.cvut.cz/~matas/)")
    gr.Markdown(
        "Official demo for the paper **Detection, Pose Estimation and Segmentation for Multiple Bodies: Closing the Virtuous Circle** [ICCV 2025]."
    )
    gr.Markdown(
        "For details, see the [project website](https://mirapurkrabek.github.io/BBox-Mask-Pose/) or the [arXiv paper](https://arxiv.org/abs/2412.01562). "
        "The demo showcases the capabilities of the BBoxMaskPose framework on any image. "
        "If you want to play around with parameters, use the [GitHub demo](https://github.com/MiraPurkrabek/BBoxMaskPose). "
        "Please note that due to Hugging Face restrictions, the demo runs much slower than the GitHub implementation."
    )
    gr.Markdown(
        "If you find the project interesting, please like ❤️ the HF demo and star ⭐ the GitHub repo to help us spread the word."
    )
    with gr.Row():
        with gr.Column():
            original_image_input = gr.Image(type="numpy", label="Original Image")
            submit_button = gr.Button("Run Inference")
        with gr.Column():
            output_standard = gr.Image(type="numpy", label="RTMDet-L + MaskPose-B")
        with gr.Column():
            output_bmp = gr.Image(type="numpy", label="BBoxMaskPose 2x")
    gr.Examples(
        label="In-the-Wild Examples",
        examples=[
            ["examples/prochazka_MMA.jpg"],
            ["examples/riner_judo.jpg"],
            ["examples/tackle3.jpg"],
            ["examples/tackle1.jpg"],
            ["examples/tackle2.jpg"],
            ["examples/tackle5.jpg"],
            ["examples/SKV_example1.jpg"],
            ["examples/SKV_example2.jpg"],
            ["examples/SKV_example3.jpg"],
            ["examples/SKV_example4.jpg"],
        ],
        inputs=[original_image_input],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    gr.Examples(
        label="OCHuman Examples",
        examples=[
            ["examples/004806.jpg"],
            ["examples/005056.jpg"],
            ["examples/004981.jpg"],
            ["examples/004655.jpg"],
            ["examples/004684.jpg"],
            ["examples/004974.jpg"],
            ["examples/004983.jpg"],
            ["examples/005017.jpg"],
            ["examples/004849.jpg"],
            ["examples/000105.jpg"],
        ],
        inputs=[original_image_input],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    gr.Examples(
        label="Failure Cases",
        examples=[
            ["examples/SKV_example_F1.jpg"],
            ["examples/tackle4.jpg"],
            ["examples/000061.jpg"],
            ["examples/000141.jpg"],
            ["examples/000287.jpg"],
        ],
        inputs=[original_image_input],
        outputs=[output_standard, output_bmp],
        fn=process_image_with_BMP,
        cache_examples=True,
    )
    submit_button.click(
        fn=process_image_with_BMP,
        inputs=[original_image_input],
        outputs=[output_standard, output_bmp],
    )

# Launch the demo
app.launch()