Spaces:
Runtime error
Runtime error
| # # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| import json | |
| import warnings | |
| import argparse | |
| from io import BytesIO | |
| import onnx | |
| import torch | |
| from mmdet.apis import init_detector | |
| from mmengine.config import ConfigDict | |
| from mmengine.logging import print_log | |
| from mmengine.utils.path import mkdir_or_exist | |
| from easydeploy.model import DeployModel, MMYOLOBackend # noqa E402 | |
# Silence the tracer/scripting, deprecation and resource warnings that are
# emitted while tracing and exporting the model, so the export log stays
# readable. Filters are installed in the same order as listed here.
for _category in (torch.jit.TracerWarning,
                  torch.jit.ScriptWarning,
                  UserWarning,
                  FutureWarning,
                  ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_category)
del _category
def parse_args(argv=None):
    """Parse the command-line options for the ONNX export script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the previous behaviour.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list; a single value is used for both.

    Raises:
        SystemExit: Via ``parser.error`` when ``--img-size`` does not have one
            or two positive values, or ``--batch-size`` is not positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument('--custom-text',
                        type=str,
                        help='custom text inputs (text json) for YOLO-World.')
    parser.add_argument('--add-padding',
                        action="store_true",
                        help="add an empty padding to texts.")
    parser.add_argument('--model-only',
                        action='store_true',
                        help='Export model only')
    parser.add_argument('--without-nms',
                        action='store_true',
                        help='Export model without NMS')
    parser.add_argument('--without-bbox-decoder',
                        action='store_true',
                        help='Export model without Bbox Decoder (for INT8 Quantization)')
    parser.add_argument('--work-dir',
                        default='./work_dirs',
                        help='Path to save export model')
    parser.add_argument('--img-size',
                        nargs='+',
                        type=int,
                        default=[640, 640],
                        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument('--device',
                        default='cuda:0',
                        help='Device used for inference')
    parser.add_argument('--simplify',
                        action='store_true',
                        help='Simplify onnx model by onnx-sim')
    parser.add_argument('--opset',
                        type=int,
                        default=11,
                        help='ONNX opset version')
    parser.add_argument('--backend',
                        type=str,
                        default='onnxruntime',
                        help='Backend for export onnx')
    parser.add_argument('--pre-topk',
                        type=int,
                        default=1000,
                        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument('--keep-topk',
                        type=int,
                        default=100,
                        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument('--iou-threshold',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    parser.add_argument('--score-threshold',
                        type=float,
                        default=0.25,
                        help='Score threshold for NMS')
    args = parser.parse_args(argv)

    # A single --img-size value means a square input.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    # img_size is unpacked into the dummy input tensor's spatial dims, so
    # anything other than (height, width) would add bogus extra dimensions.
    if len(args.img_size) != 2:
        parser.error('--img-size expects one value (square) or two values '
                     f'(height width), got {len(args.img_size)}')
    if min(args.img_size) <= 0:
        parser.error(f'--img-size values must be positive, got {args.img_size}')
    if args.batch_size <= 0:
        parser.error(f'--batch-size must be positive, got {args.batch_size}')
    return args
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create a detector from a config file and checkpoint, ready for export.

    Args:
        config_path: Path to the model config, forwarded to ``init_detector``.
        checkpoint_path: Path to the checkpoint, forwarded to ``init_detector``.
        device: Device string such as ``'cuda:0'``.

    Returns:
        The model returned by ``init_detector``, switched to eval mode.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    # Export must trace inference behaviour (no dropout / BN statistics update).
    detector.eval()
    return detector
def _load_texts(custom_text, add_padding):
    """Return the list of class-name prompts to bake into the exported model.

    Args:
        custom_text: Path to a JSON file of prompts, or ``None``/empty to fall
            back to the COCO class names.
        add_padding: When true, append a single ``' '`` entry.
    """
    if custom_text:
        # JSON is UTF-8 by spec; be explicit so non-ASCII class names load
        # regardless of the platform's default encoding.
        with open(custom_text, encoding='utf-8') as f:
            raw_texts = json.load(f)
        # Entries are expected as ``[["name"], ...]`` (the first element of
        # each inner list is used). Plain ``["name", ...]`` lists are also
        # accepted instead of being truncated to their first character.
        texts = [x if isinstance(x, str) else x[0] for x in raw_texts]
    else:
        from mmdet.datasets import CocoDataset
        # METAINFO['classes'] may be a tuple; copy it into a list so the
        # padding concatenation below cannot raise TypeError.
        texts = list(CocoDataset.METAINFO['classes'])
    if add_padding:
        texts = texts + [' ']
    return texts


def _onnx_save_path(work_dir, checkpoint):
    """Return ``<work_dir>/<checkpoint stem>.onnx``."""
    # Replace only the extension: ``str.replace('pth', 'onnx')`` also rewrote
    # 'pth' inside the stem and left non-'.pth' names (e.g. 'model.pt')
    # unchanged, so the ONNX bytes could land under a checkpoint-like name.
    stem, _ = os.path.splitext(os.path.basename(checkpoint))
    return os.path.join(work_dir, stem + '.onnx')


def _simplify_onnx(onnx_model):
    """Best-effort onnx-sim pass; returns the input model if simplification fails."""
    try:
        import onnxsim
        simplified, check = onnxsim.simplify(onnx_model)
    except Exception as e:  # onnxsim missing or crashed: keep the original.
        print_log(f'Simplify failure: {e}')
        return onnx_model
    if not check:
        # Never ship a simplified graph that failed onnxsim's own validation.
        print_log('Simplify failure: assert check failed')
        return onnx_model
    return simplified


def main():
    """Export the detector described by the command-line arguments to ONNX.

    Resolves the backend and post-processing config, loads the class-name
    prompts, builds the model, and reparameterizes it with the prompts when it
    supports ``reparameterize``. It then traces the model to ONNX in memory and
    runs ``onnx.checker``. For TensorRT with NMS it relabels the output
    dimensions. Finally it optionally simplifies the graph and saves it under
    ``args.work_dir``.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)
    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Post-processing export is only supported for the backends above.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(pre_top_k=args.pre_topk,
                                     keep_top_k=args.keep_topk,
                                     iou_threshold=args.iou_threshold,
                                     score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']
        if args.without_bbox_decoder or args.without_nms:
            output_names = ['scores', 'boxes']

    texts = _load_texts(args.custom_text, args.add_padding)

    base_model = build_model_from_cfg(args.config, args.checkpoint,
                                      args.device)
    if hasattr(base_model, 'reparameterize'):
        # reparameterize text into YOLO-World
        base_model.reparameterize([texts])
    deploy_model = DeployModel(baseModel=base_model,
                               backend=backend,
                               postprocess_cfg=postprocess_cfg,
                               with_nms=not args.without_nms,
                               without_bbox_decoder=args.without_bbox_decoder)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # Dry run to surface model errors before tracing; gradients are not needed.
    with torch.no_grad():
        deploy_model(fake_input)

    save_onnx_path = _onnx_save_path(args.work_dir, args.checkpoint)
    # Export into memory first so the graph can be checked/patched before save.
    with BytesIO() as f:
        torch.onnx.export(deploy_model,
                          fake_input,
                          f,
                          input_names=['images'],
                          output_names=output_names,
                          opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
    onnx.checker.check_model(onnx_model)

    # Fix tensorrt onnx output shape, just for view. Only applies when the
    # four NMS outputs are present (output_names above becomes
    # ['scores', 'boxes'] when --without-bbox-decoder or --without-nms is set).
    if (not args.model_only and not args.without_nms
            and not args.without_bbox_decoder
            and backend in (MMYOLOBackend.TENSORRT8,
                            MMYOLOBackend.TENSORRT7)):
        # Presumably num_dets (B, 1), boxes (B, K, 4), scores (B, K),
        # labels (B, K) in output order -- TODO confirm against DeployModel.
        shapes = [
            args.batch_size, 1, args.batch_size, args.keep_topk, 4,
            args.batch_size, args.keep_topk, args.batch_size,
            args.keep_topk
        ]
        for output in onnx_model.graph.output:
            for dim in output.type.tensor_type.shape.dim:
                dim.dim_param = str(shapes.pop(0))

    if args.simplify:
        onnx_model = _simplify_onnx(onnx_model)
    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')


if __name__ == '__main__':
    main()