Spaces:
Runtime error
Runtime error
| # Copyright (c) Tencent Inc. All rights reserved. | |
| # This file is modified from mmyolo/demo/video_demo.py | |
| import argparse | |
| import cv2 | |
| import mmcv | |
| import torch | |
| from mmengine.dataset import Compose | |
| from mmdet.apis import init_detector | |
| from mmengine.utils import track_iter_progress | |
| from mmyolo.registry import VISUALIZERS | |
def parse_args():
    """Parse the command-line options of the YOLO-World video demo.

    Returns:
        argparse.Namespace: the positional ``config``, ``checkpoint``,
        ``video`` and ``text`` values plus the optional ``device``,
        ``score_thr`` and ``out`` settings.
    """
    cli = argparse.ArgumentParser(description='YOLO-World video demo')

    # Positional arguments, registered in the order they must be given.
    positionals = (
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
        ('video', 'video file path'),
        ('text', 'text prompts, including categories separated by a comma '
                 'or a txt file with each line as a prompt.'),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)

    cli.add_argument('--device', default='cuda:0',
                     help='device used for inference')
    cli.add_argument('--score-thr', type=float, default=0.1,
                     help='confidence score threshold for predictions.')
    cli.add_argument('--out', type=str, help='output video file')
    return cli.parse_args()
| def inference_detector(model, image, texts, test_pipeline, score_thr=0.3): | |
| data_info = dict(img_id=0, img=image, texts=texts) | |
| data_info = test_pipeline(data_info) | |
| data_batch = dict(inputs=data_info['inputs'].unsqueeze(0), | |
| data_samples=[data_info['data_samples']]) | |
| with torch.no_grad(): | |
| output = model.test_step(data_batch)[0] | |
| pred_instances = output.pred_instances | |
| pred_instances = pred_instances[pred_instances.scores.float() > | |
| score_thr] | |
| output.pred_instances = pred_instances | |
| return output | |
def _load_text_prompts(text):
    """Build the nested prompt list handed to ``model.reparameterize``.

    Args:
        text (str): Either a path ending in ``.txt`` holding one prompt per
            line, or a comma-separated list of category names.

    Returns:
        list[list[str]]: One single-element list per prompt, followed by a
        trailing ``[' ']`` entry (presumably a padding/background prompt the
        model expects -- TODO confirm).
    """
    if text.endswith('.txt'):
        # Explicit UTF-8 so non-ASCII prompts decode identically on every
        # platform instead of depending on the locale's default encoding.
        with open(text, encoding='utf-8') as f:
            prompts = [line.rstrip('\r\n') for line in f]
    else:
        prompts = [t.strip() for t in text.split(',')]
    return [[p] for p in prompts] + [[' ']]


def main():
    """Run YOLO-World on every frame of a video.

    For each frame of ``args.video``, detect the text-prompted categories and
    draw the predictions. When ``--out`` is given, write the drawn frames to
    that file as an ``mp4v`` video.
    """
    args = parse_args()

    # Parse the prompts before loading the checkpoint so a bad prompt file
    # fails fast instead of after the (slow) model initialization.
    texts = _load_text_prompts(args.text)

    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Frames arrive as in-memory arrays, so the first pipeline step
    # (presumably the image loader) is swapped for the ndarray loader.
    model.cfg.test_dataloader.dataset.pipeline[
        0].type = 'mmdet.LoadImageFromNDArray'
    test_pipeline = Compose(model.cfg.test_dataloader.dataset.pipeline)

    # Reparameterize the model with the prompts once, before the frame loop.
    model.reparameterize(texts)

    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    # The dataset_meta is loaded from the checkpoint and then passed to the
    # model in init_detector; share it with the visualizer.
    visualizer.dataset_meta = model.dataset_meta

    video_reader = mmcv.VideoReader(args.video)
    video_writer = None
    if args.out:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(
            args.out, fourcc, video_reader.fps,
            (video_reader.width, video_reader.height))

    try:
        # Pass (iterable, length) explicitly: mmengine's track_iter_progress
        # accepts a (tasks, length) tuple or a Sequence, and VideoReader is
        # not a Sequence.
        for frame in track_iter_progress((video_reader, len(video_reader))):
            result = inference_detector(model,
                                        frame,
                                        texts,
                                        test_pipeline,
                                        score_thr=args.score_thr)
            visualizer.add_datasample(name='video',
                                      image=frame,
                                      data_sample=result,
                                      draw_gt=False,
                                      show=False,
                                      pred_score_thr=args.score_thr)
            frame = visualizer.get_image()
            if video_writer is not None:
                video_writer.write(frame)
    finally:
        # Release even on error/interrupt so the output file is finalized.
        if video_writer is not None:
            video_writer.release()
if __name__ == '__main__':
    # Script entry point; importing this module does not start the demo.
    main()