# Copyright (c) OpenMMLab. All rights reserved.
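"""Two-stage 3D human pose estimation demo.

Stage 1 detects people with an MMDetection detector and estimates their 2D
keypoints with a top-down pose estimator; stage 2 lifts the 2D keypoint
sequences to 3D with a pose-lifter model.

A minimal sketch of an invocation (the script name and all config/checkpoint
paths below are placeholders; substitute the files for the models you
actually use)::

    python body3d_pose_lifter_demo.py \
        det_config.py det_checkpoint.pth \
        pose2d_config.py pose2d_checkpoint.pth \
        lifter_config.py lifter_checkpoint.pth \
        --input demo.mp4 --output-root vis_results/
"""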
import logging
import mimetypes
import os
import time
from argparse import ArgumentParser
from functools import partial

import cv2
import json_tricks as json
import mmcv
import mmengine
import numpy as np
from mmengine.logging import print_log

from mmpose.apis import (_track_by_iou, _track_by_oks,
                         convert_keypoint_definition, extract_pose_sequence,
                         inference_pose_lifter_model, inference_topdown,
                         init_model)
from mmpose.models.pose_estimators import PoseLifter
from mmpose.models.pose_estimators.topdown import TopdownPoseEstimator
from mmpose.registry import VISUALIZERS
from mmpose.structures import (PoseDataSample, merge_data_samples,
                               split_instances)
from mmpose.utils import adapt_mmdet_pipeline

try:
    from mmdet.apis import inference_detector, init_detector
    has_mmdet = True
except (ImportError, ModuleNotFoundError):
    has_mmdet = False


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('det_config', help='Config file for detection')
    parser.add_argument('det_checkpoint', help='Checkpoint file for detection')
    parser.add_argument(
        'pose_estimator_config',
        type=str,
        default=None,
        help='Config file for the 1st-stage 2D pose estimator')
    parser.add_argument(
        'pose_estimator_checkpoint',
        type=str,
        default=None,
        help='Checkpoint file for the 1st-stage 2D pose estimator')
    parser.add_argument(
        'pose_lifter_config',
        help='Config file for the 2nd-stage pose lifter model')
    parser.add_argument(
        'pose_lifter_checkpoint',
        help='Checkpoint file for the 2nd-stage pose lifter model')
    parser.add_argument('--input', type=str, default='', help='Video path')
    parser.add_argument(
        '--show',
        action='store_true',
        default=False,
        help='Whether to show visualizations')
    parser.add_argument(
        '--disable-rebase-keypoint',
        action='store_true',
        default=False,
        help='Whether to disable rebasing the predicted 3D pose so its '
        'lowest keypoint has a height of 0 (landing on the ground). '
        'Rebasing is useful for visualization when the model does not '
        'predict the global position of the 3D pose.')
    parser.add_argument(
        '--disable-norm-pose-2d',
        action='store_true',
        default=False,
        help='Whether to scale the bbox (along with the 2D pose) to the '
        'average bbox scale of the dataset, and move the bbox (along with '
        'the 2D pose) to the average bbox center of the dataset. This is '
        'useful when the bbox is small, especially in multi-person '
        'scenarios.')
    parser.add_argument(
        '--num-instances',
        type=int,
        default=1,
        help='The number of 3D poses to be visualized in every frame. If '
        'less than 0, it will be set to the number of pose results in the '
        'first frame.')
    parser.add_argument(
        '--output-root',
        type=str,
        default='',
        help='Root of the output video file. '
        'By default, the visualization video is not saved.')
    parser.add_argument(
        '--save-predictions',
        action='store_true',
        default=False,
        help='Whether to save predicted results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--det-cat-id',
        type=int,
        default=0,
        help='Category id for the bounding box detection model')
    parser.add_argument(
        '--bbox-thr',
        type=float,
        default=0.3,
        help='Bounding box score threshold')
    parser.add_argument(
        '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
    parser.add_argument(
        '--use-oks-tracking', action='store_true', help='Use OKS tracking')
    parser.add_argument(
        '--tracking-thr', type=float, default=0.3, help='Tracking threshold')
    parser.add_argument(
        '--show-interval', type=int, default=0, help='Sleep seconds per frame')
    parser.add_argument(
        '--thickness',
        type=int,
        default=1,
        help='Link thickness for visualization')
    parser.add_argument(
        '--radius',
        type=int,
        default=3,
        help='Keypoint radius for visualization')
    parser.add_argument(
        '--online',
        action='store_true',
        default=False,
        help='Inference mode. If set to True, future frame information '
        'cannot be used when using multiple frames for inference in the '
        '2D pose detection stage. Default: False.')
    args = parser.parse_args()
    return args


def process_one_image(args, detector, frame, frame_idx, pose_estimator,
                      pose_est_results_last, pose_est_results_list, next_id,
                      pose_lifter, visualize_frame, visualizer):
    """Visualize detected and predicted keypoints of one image.

    Pipeline of this function:

                              frame
                                |
                                V
                        +-----------------+
                        |     detector    |
                        +-----------------+
                                |  det_result
                                V
                        +-----------------+
                        |  pose_estimator |
                        +-----------------+
                                |  pose_est_results
                                V
            +--------------------------------------------+
            |  convert 2d kpts into pose-lifting format  |
            +--------------------------------------------+
                                |  pose_est_results_list
                                V
                     +-----------------------+
                     | extract_pose_sequence |
                     +-----------------------+
                                |  pose_seq_2d
                                V
                         +-------------+
                         | pose_lifter |
                         +-------------+
                                |  pose_lift_results
                                V
                       +-----------------+
                       | post-processing |
                       +-----------------+
                                |  pred_3d_data_samples
                                V
                         +------------+
                         | visualizer |
                         +------------+

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        detector (mmdet.BaseDetector): The mmdet detector.
        frame (np.ndarray): The image frame read from input image or video.
        frame_idx (int): The index of the current frame.
        pose_estimator (TopdownPoseEstimator): The pose estimator for 2d pose.
        pose_est_results_last (list(PoseDataSample)): The results of pose
            estimation from the last frame for tracking instances.
        pose_est_results_list (list(list(PoseDataSample))): The list of all
            pose estimation results converted by
            ``convert_keypoint_definition`` from previous frames. In the
            pose-lifting stage it is used to obtain the 2d estimation
            sequence.
        next_id (int): The next track id to be used.
        pose_lifter (PoseLifter): The pose-lifter for estimating 3d pose.
        visualize_frame (np.ndarray): The image for drawing the results on.
        visualizer (Visualizer): The visualizer for visualizing the 2d and 3d
            pose estimation results.

    Returns:
        pose_est_results (list(PoseDataSample)): The pose estimation result of
            the current frame.
        pose_est_results_list (list(list(PoseDataSample))): The list of all
            converted pose estimation results until the current frame.
        pred_3d_instances (InstanceData): The result of pose-lifting.
            Specifically, the predicted keypoints and scores are saved at
            ``pred_3d_instances.keypoints`` and
            ``pred_3d_instances.keypoint_scores``.
        next_id (int): The next track id to be used.
    """
    pose_lift_dataset = pose_lifter.cfg.test_dataloader.dataset
    pose_lift_dataset_name = pose_lifter.dataset_meta['dataset_name']

    # First stage: conduct 2D pose detection in a top-down manner
    # use detector to obtain person bounding boxes
    det_result = inference_detector(detector, frame)
    pred_instance = det_result.pred_instances.cpu().numpy()

    # filter out the person instances with category and bbox threshold
    # e.g. 0 for person in COCO
    bboxes = pred_instance.bboxes
    bboxes = bboxes[np.logical_and(pred_instance.labels == args.det_cat_id,
                                   pred_instance.scores > args.bbox_thr)]

    # estimate pose results for current image
    pose_est_results = inference_topdown(pose_estimator, frame, bboxes)
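
    # select the instance-tracking function: OKS-based keypoint similarity
    # or bounding-box IoU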
    if args.use_oks_tracking:
        _track = partial(_track_by_oks)
    else:
        _track = _track_by_iou
    pose_det_dataset_name = pose_estimator.dataset_meta['dataset_name']
    pose_est_results_converted = []

    # convert 2d pose estimation results into the format for pose-lifting,
    # such as changing the keypoint order, flipping the keypoints, etc.
    for i, data_sample in enumerate(pose_est_results):
        pred_instances = data_sample.pred_instances.cpu().numpy()
        keypoints = pred_instances.keypoints

        # calculate area and bbox
        if 'bboxes' in pred_instances:
            areas = np.array([(bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                              for bbox in pred_instances.bboxes])
            pose_est_results[i].pred_instances.set_field(areas, 'areas')
        else:
            areas, bboxes = [], []
            for keypoint in keypoints:
                xmin = np.min(keypoint[:, 0][keypoint[:, 0] > 0], initial=1e10)
                xmax = np.max(keypoint[:, 0])
                ymin = np.min(keypoint[:, 1][keypoint[:, 1] > 0], initial=1e10)
                ymax = np.max(keypoint[:, 1])
                areas.append((xmax - xmin) * (ymax - ymin))
                bboxes.append([xmin, ymin, xmax, ymax])
            pose_est_results[i].pred_instances.areas = np.array(areas)
            pose_est_results[i].pred_instances.bboxes = np.array(bboxes)

        # track id
        track_id, pose_est_results_last, _ = _track(data_sample,
                                                    pose_est_results_last,
                                                    args.tracking_thr)
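        # unmatched instances either start a new track (if enough keypoints
        # were detected) or are suppressed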
        if track_id == -1:
            if np.count_nonzero(keypoints[:, :, 1]) >= 3:
                track_id = next_id
                next_id += 1
            else:
                # If the number of keypoints detected is small,
                # delete that person instance.
                keypoints[:, :, 1] = -10
                pose_est_results[i].pred_instances.set_field(
                    keypoints, 'keypoints')
                pose_est_results[i].pred_instances.set_field(
                    pred_instances.bboxes * 0, 'bboxes')
                pose_est_results[i].set_field(pred_instances, 'pred_instances')
                track_id = -1
        pose_est_results[i].set_field(track_id, 'track_id')

        # convert keypoints for pose-lifting
        pose_est_result_converted = PoseDataSample()
        pose_est_result_converted.set_field(
            pose_est_results[i].pred_instances.clone(), 'pred_instances')
        pose_est_result_converted.set_field(
            pose_est_results[i].gt_instances.clone(), 'gt_instances')
        keypoints = convert_keypoint_definition(keypoints,
                                                pose_det_dataset_name,
                                                pose_lift_dataset_name)
        pose_est_result_converted.pred_instances.set_field(
            keypoints, 'keypoints')
        pose_est_result_converted.set_field(pose_est_results[i].track_id,
                                            'track_id')
        pose_est_results_converted.append(pose_est_result_converted)

    pose_est_results_list.append(pose_est_results_converted.copy())

    # Second stage: Pose lifting
    # extract and pad input pose2d sequence
    pose_seq_2d = extract_pose_sequence(
        pose_est_results_list,
        frame_idx=frame_idx,
        causal=pose_lift_dataset.get('causal', False),
        seq_len=pose_lift_dataset.get('seq_len', 1),
        step=pose_lift_dataset.get('seq_step', 1))

    # conduct 2D-to-3D pose lifting
    norm_pose_2d = not args.disable_norm_pose_2d
    pose_lift_results = inference_pose_lifter_model(
        pose_lifter,
        pose_seq_2d,
        image_size=visualize_frame.shape[:2],
        norm_pose_2d=norm_pose_2d)

    # post-processing
    for idx, pose_lift_result in enumerate(pose_lift_results):
        pose_lift_result.track_id = pose_est_results[idx].get('track_id', 1e4)

        pred_instances = pose_lift_result.pred_instances
        keypoints = pred_instances.keypoints
        keypoint_scores = pred_instances.keypoint_scores
        if keypoint_scores.ndim == 3:
            keypoint_scores = np.squeeze(keypoint_scores, axis=1)
            pose_lift_results[
                idx].pred_instances.keypoint_scores = keypoint_scores
        if keypoints.ndim == 4:
            keypoints = np.squeeze(keypoints, axis=1)
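
        # exchange the y- and z-axes and reverse the x- and z-directions so
        # the lifted pose stands upright in the visualizer's coordinate frame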
        keypoints = keypoints[..., [0, 2, 1]]
        keypoints[..., 0] = -keypoints[..., 0]
        keypoints[..., 2] = -keypoints[..., 2]

        # rebase height (z-axis)
        if not args.disable_rebase_keypoint:
            keypoints[..., 2] -= np.min(
                keypoints[..., 2], axis=-1, keepdims=True)

        pose_lift_results[idx].pred_instances.keypoints = keypoints
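
    # sort by track id so each instance keeps a stable slot in the
    # visualization across frames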
    pose_lift_results = sorted(
        pose_lift_results, key=lambda x: x.get('track_id', 1e4))

    pred_3d_data_samples = merge_data_samples(pose_lift_results)
    det_data_sample = merge_data_samples(pose_est_results)
    pred_3d_instances = pred_3d_data_samples.get('pred_instances', None)

    if args.num_instances < 0:
        args.num_instances = len(pose_lift_results)

    # Visualization
    if visualizer is not None:
        visualizer.add_datasample(
            'result',
            visualize_frame,
            data_sample=pred_3d_data_samples,
            det_data_sample=det_data_sample,
            draw_gt=False,
            dataset_2d=pose_det_dataset_name,
            dataset_3d=pose_lift_dataset_name,
            show=args.show,
            draw_bbox=True,
            kpt_thr=args.kpt_thr,
            num_instances=args.num_instances,
            wait_time=args.show_interval)

    return pose_est_results, pose_est_results_list, pred_3d_instances, next_id


def main():
    assert has_mmdet, 'Please install mmdet to run the demo.'

    args = parse_args()

    assert args.show or (args.output_root != '')
    assert args.input != ''
    assert args.det_config is not None
    assert args.det_checkpoint is not None

    detector = init_detector(
        args.det_config, args.det_checkpoint, device=args.device.lower())
    detector.cfg = adapt_mmdet_pipeline(detector.cfg)

    pose_estimator = init_model(
        args.pose_estimator_config,
        args.pose_estimator_checkpoint,
        device=args.device.lower())

    assert isinstance(pose_estimator, TopdownPoseEstimator), \
        'Only "TopDown" model is supported for the 1st stage ' \
        '(2D pose detection)'

    det_kpt_color = pose_estimator.dataset_meta.get('keypoint_colors', None)
    det_dataset_skeleton = pose_estimator.dataset_meta.get(
        'skeleton_links', None)
    det_dataset_link_color = pose_estimator.dataset_meta.get(
        'skeleton_link_colors', None)

    pose_lifter = init_model(
        args.pose_lifter_config,
        args.pose_lifter_checkpoint,
        device=args.device.lower())

    assert isinstance(pose_lifter, PoseLifter), \
        'Only "PoseLifter" model is supported for the 2nd stage ' \
        '(2D-to-3D lifting)'

    pose_lifter.cfg.visualizer.radius = args.radius
    pose_lifter.cfg.visualizer.line_width = args.thickness
    pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color
    pose_lifter.cfg.visualizer.det_dataset_skeleton = det_dataset_skeleton
    pose_lifter.cfg.visualizer.det_dataset_link_color = det_dataset_link_color
    visualizer = VISUALIZERS.build(pose_lifter.cfg.visualizer)

    # the dataset_meta is loaded from the checkpoint
    visualizer.set_dataset_meta(pose_lifter.dataset_meta)

    if args.input == 'webcam':
        input_type = 'webcam'
    else:
        input_type = mimetypes.guess_type(args.input)[0].split('/')[0]

    if args.output_root == '':
        save_output = False
    else:
        mmengine.mkdir_or_exist(args.output_root)
        output_file = os.path.join(args.output_root,
                                   os.path.basename(args.input))
        if args.input == 'webcam':
            output_file += '.mp4'
        save_output = True

    if args.save_predictions:
        assert args.output_root != ''
        args.pred_save_path = f'{args.output_root}/results_' \
            f'{os.path.splitext(os.path.basename(args.input))[0]}.json'

    if save_output:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

    pose_est_results_list = []
    pred_instances_list = []
    if input_type == 'image':
        frame = mmcv.imread(args.input, channel_order='rgb')
        _, _, pred_3d_instances, _ = process_one_image(
            args=args,
            detector=detector,
            frame=frame,
            frame_idx=0,
            pose_estimator=pose_estimator,
            pose_est_results_last=[],
            pose_est_results_list=pose_est_results_list,
            next_id=0,
            pose_lifter=pose_lifter,
            visualize_frame=frame,
            visualizer=visualizer)

        if args.save_predictions:
            # save prediction results
            pred_instances_list = split_instances(pred_3d_instances)

        if save_output:
            frame_vis = visualizer.get_image()
            mmcv.imwrite(mmcv.rgb2bgr(frame_vis), output_file)

    elif input_type in ['webcam', 'video']:
        next_id = 0
        pose_est_results = []

        if args.input == 'webcam':
            video = cv2.VideoCapture(0)
        else:
            video = cv2.VideoCapture(args.input)
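
        # query the input's frame rate for the output writer; OpenCV < 3
        # exposed the property under the legacy cv2.cv namespace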
        (major_ver, minor_ver, subminor_ver) = (cv2.__version__).split('.')
        if int(major_ver) < 3:
            fps = video.get(cv2.cv.CV_CAP_PROP_FPS)
        else:
            fps = video.get(cv2.CAP_PROP_FPS)

        video_writer = None
        frame_idx = 0

        while video.isOpened():
            success, frame = video.read()
            frame_idx += 1

            if not success:
                break

            pose_est_results_last = pose_est_results

            # First stage: 2D pose detection
            # make person results for current image
            (pose_est_results, pose_est_results_list, pred_3d_instances,
             next_id) = process_one_image(
                 args=args,
                 detector=detector,
                 frame=frame,
                 frame_idx=frame_idx,
                 pose_estimator=pose_estimator,
                 pose_est_results_last=pose_est_results_last,
                 pose_est_results_list=pose_est_results_list,
                 next_id=next_id,
                 pose_lifter=pose_lifter,
                 visualize_frame=mmcv.bgr2rgb(frame),
                 visualizer=visualizer)

            if args.save_predictions:
                # save prediction results
                pred_instances_list.append(
                    dict(
                        frame_id=frame_idx,
                        instances=split_instances(pred_3d_instances)))

            if save_output:
                frame_vis = visualizer.get_image()
                if video_writer is None:
                    # the size of the image with visualization may vary
                    # depending on the presence of heatmaps
                    video_writer = cv2.VideoWriter(output_file, fourcc, fps,
                                                   (frame_vis.shape[1],
                                                    frame_vis.shape[0]))
                video_writer.write(mmcv.rgb2bgr(frame_vis))

            if args.show:
                # press ESC to exit
                if cv2.waitKey(5) & 0xFF == 27:
                    break
                time.sleep(args.show_interval)

        video.release()

        if video_writer:
            video_writer.release()
    else:
        args.save_predictions = False
        raise ValueError(
            f'file {os.path.basename(args.input)} has invalid format.')

    if args.save_predictions:
        with open(args.pred_save_path, 'w') as f:
            json.dump(
                dict(
                    meta_info=pose_lifter.dataset_meta,
                    instance_info=pred_instances_list),
                f,
                indent='\t')
        print(f'predictions have been saved at {args.pred_save_path}')

    if save_output:
        input_type = input_type.replace('webcam', 'video')
        print_log(
            f'the output {input_type} has been saved at {output_file}',
            logger='current',
            level=logging.INFO)


if __name__ == '__main__':
    main()