Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from copy import deepcopy | |
| from typing import Union | |
| import mmcv | |
| import numpy as np | |
| from mmengine.structures import InstanceData | |
| from mmpose.datasets.datasets.utils import parse_pose_metainfo | |
| from mmpose.structures import PoseDataSample | |
| from mmpose.visualization import PoseLocalVisualizer | |
| # from posevis import pose_visualization | |
| # def visualize( | |
| # img: Union[np.ndarray, str], | |
| # keypoints: np.ndarray, | |
| # keypoint_score: np.ndarray = None, | |
| # metainfo: Union[str, dict] = None, | |
| # visualizer: PoseLocalVisualizer = None, | |
| # show_kpt_idx: bool = False, | |
| # skeleton_style: str = 'mmpose', | |
| # show: bool = False, | |
| # kpt_thr: float = 0.3, | |
| # ): | |
| # """Visualize 2d keypoints on an image. | |
| # Args: | |
| # img (str | np.ndarray): The image to be displayed. | |
| # keypoints (np.ndarray): The keypoint to be displayed. | |
| # keypoint_score (np.ndarray): The score of each keypoint. | |
| # metainfo (str | dict): The metainfo of dataset. | |
| # visualizer (PoseLocalVisualizer): The visualizer. | |
| # show_kpt_idx (bool): Whether to show the index of keypoints. | |
| # skeleton_style (str): Skeleton style. Options are 'mmpose' and | |
| # 'openpose'. | |
| # show (bool): Whether to show the image. | |
| # wait_time (int): Value of waitKey param. | |
| # kpt_thr (float): Keypoint threshold. | |
| # """ | |
| # kpts = keypoints.reshape(-1, 2) | |
| # kpts = np.concatenate([kpts, keypoint_score[:, None]], axis=1) | |
| # kpts[kpts[:, 2] < kpt_thr, :] = 0 | |
| # pose_results = [{ | |
| # 'keypoints': kpts, | |
| # }] | |
| # img = pose_visualization( | |
| # img, | |
| # pose_results, | |
| # format="COCO", | |
| # greyness=1.0, | |
| # show_markers=True, | |
| # show_bones=True, | |
| # line_type="solid", | |
| # width_multiplier=1.0, | |
| # bbox_width_multiplier=1.0, | |
| # show_bbox=False, | |
| # differ_individuals=False, | |
| # ) | |
| # return img | |
def visualize(
    img: Union[np.ndarray, str],
    keypoints: np.ndarray,
    keypoint_score: np.ndarray = None,
    metainfo: Union[str, dict] = None,
    visualizer: PoseLocalVisualizer = None,
    show_kpt_idx: bool = False,
    skeleton_style: str = 'mmpose',
    show: bool = False,
    kpt_thr: float = 0.3,
    wait_time: int = 0,
):
    """Visualize 2d keypoints on an image.

    Args:
        img (str | np.ndarray): The image to draw on. A string is read from
            disk with ``mmcv.imread(..., channel_order='rgb')``; an array is
            treated as BGR and converted to RGB with ``mmcv.bgr2rgb``.
        keypoints (np.ndarray): The keypoints to draw. The last axis holds
            the coordinates (presumably shape
            ``(num_instances, num_keypoints, 2)`` — confirm against callers).
        keypoint_score (np.ndarray, optional): The score of each keypoint.
            Defaults to all ones with shape ``keypoints.shape[:-1]``.
        metainfo (str | dict, optional): Dataset metainfo, given either as a
            path to a metainfo file or as a dict. It is parsed with
            ``parse_pose_metainfo`` and passed to ``set_dataset_meta``.
        visualizer (PoseLocalVisualizer, optional): The visualizer to use.
            It is deep-copied, so the caller's instance is not modified.
            A fresh ``PoseLocalVisualizer()`` is created when omitted.
        show_kpt_idx (bool): Whether to show the index of keypoints.
        skeleton_style (str): Skeleton style. Options are 'mmpose' and
            'openpose'.
        show (bool): Whether to show the image.
        kpt_thr (float): Keypoint score threshold, forwarded to
            ``add_datasample``.
        wait_time (int): Value of the waitKey param, forwarded to
            ``add_datasample``. Defaults to 0.

    Returns:
        np.ndarray: The drawn image as returned by
        ``visualizer.get_image()``.

    Raises:
        AssertionError: If ``skeleton_style`` is not a supported style.
    """
    supported_styles = ('mmpose', 'openpose')
    assert skeleton_style in supported_styles, (
        f'Only support skeleton style in {list(supported_styles)}, '
        f'but got {skeleton_style!r}')

    # Work on a copy so that set_dataset_meta below does not mutate a
    # visualizer owned by the caller.
    if visualizer is None:
        visualizer = PoseLocalVisualizer()
    else:
        visualizer = deepcopy(visualizer)

    if isinstance(metainfo, str):
        metainfo = parse_pose_metainfo(dict(from_file=metainfo))
    elif isinstance(metainfo, dict):
        metainfo = parse_pose_metainfo(metainfo)

    if metainfo is not None:
        visualizer.set_dataset_meta(metainfo, skeleton_style=skeleton_style)

    # Hand the visualizer an RGB image: files are read as RGB, and
    # in-memory arrays are assumed to be BGR (OpenCV convention).
    if isinstance(img, str):
        img = mmcv.imread(img, channel_order='rgb')
    elif isinstance(img, np.ndarray):
        img = mmcv.bgr2rgb(img)

    # One score per keypoint: drop only the trailing coordinate axis. The
    # visualizer looks up a score for every keypoint, so a per-instance
    # vector (``keypoints.shape[0]``) is not enough.
    if keypoint_score is None:
        keypoint_score = np.ones(keypoints.shape[:-1])

    tmp_instances = InstanceData()
    tmp_instances.keypoints = keypoints
    tmp_instances.keypoint_score = keypoint_score

    tmp_datasample = PoseDataSample()
    tmp_datasample.pred_instances = tmp_instances

    visualizer.add_datasample(
        'visualization',
        img,
        tmp_datasample,
        show_kpt_idx=show_kpt_idx,
        skeleton_style=skeleton_style,
        show=show,
        wait_time=wait_time,
        kpt_thr=kpt_thr)

    return visualizer.get_image()