# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Any, Optional, Union

import mmengine

from .core import PIPELINE_MANAGER


@PIPELINE_MANAGER.register_pipeline()
def torch2onnx(img: Any,
work_dir: str,
save_file: str,
deploy_cfg: Union[str, mmengine.Config],
model_cfg: Union[str, mmengine.Config],
model_checkpoint: Optional[str] = None,
               append_input: Optional[list] = None,
device: str = 'cuda:0'):
"""Convert PyTorch model to ONNX model.
Examples:
>>> from mmdeploy.apis import torch2onnx
>>> img = 'demo.jpg'
>>> work_dir = 'work_dir'
>>> save_file = 'fcos.onnx'
>>> deploy_cfg = ('configs/mmdet/detection/'
'detection_onnxruntime_dynamic.py')
>>> model_cfg = ('mmdetection/configs/fcos/'
'fcos_r50_caffe_fpn_gn-head_1x_coco.py')
>>> model_checkpoint = ('checkpoints/'
'fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth')
>>> device = 'cpu'
        >>> torch2onnx(img, work_dir, save_file, deploy_cfg, \
            model_cfg, model_checkpoint, device=device)

    Args:
img (str | np.ndarray | torch.Tensor): Input image used to assist
converting model.
work_dir (str): A working directory to save files.
save_file (str): Filename to save onnx model.
deploy_cfg (str | mmengine.Config): Deployment config file or
Config object.
model_cfg (str | mmengine.Config): Model config file or Config object.
        model_checkpoint (str): A checkpoint path of the PyTorch model.
            Defaults to `None`.
        append_input (list): Additional inputs besides the image, e.g. the
            text features of multimodal models such as Grounding DINO.
            Defaults to `None`.
device (str): A string specifying device type, defaults to 'cuda:0'.
"""
from mmdeploy.apis.core.pipeline_manager import no_mp
from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes,
get_input_shape, get_onnx_config, load_config)
from .onnx import export

    # load deploy_cfg if necessary
deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
mmengine.mkdir_or_exist(osp.abspath(work_dir))
input_shape = get_input_shape(deploy_cfg)
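    # `get_input_shape` reads the static input size from the deploy config and
    # is expected to return None when the config uses dynamic shapes.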

    # create model and inputs
from mmdeploy.apis import build_task_processor
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
torch_model = task_processor.build_pytorch_model(model_checkpoint)
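    # `build_pytorch_model` instantiates the model described by `model_cfg`
    # and, if `model_checkpoint` is given, loads its weights onto `device`.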
data, model_inputs = task_processor.create_input(
img,
input_shape,
data_preprocessor=getattr(torch_model, 'data_preprocessor', None))
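    # `create_input` runs the test pipeline of `model_cfg` on `img`, returning
    # the packed data dict plus the preprocessed tensor(s) to be traced.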
if isinstance(model_inputs, list) and len(model_inputs) == 1:
model_inputs = model_inputs[0]
if isinstance(append_input, list):
temp = [model_inputs]
temp.extend(append_input)
model_inputs = temp
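    # When `append_input` is given, the image tensor and the extra inputs are
    # packed into one list so that export traces a multi-input graph; the
    # onnx config's `input_names` should then carry one name per input.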
data_samples = data['data_samples']
input_metas = {'data_samples': data_samples, 'mode': 'predict'}
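    # The metas are forwarded as keyword arguments during tracing;
    # mode='predict' selects the inference branch of mmengine-style models.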

    # export to onnx
context_info = dict()
context_info['deploy_cfg'] = deploy_cfg
output_prefix = osp.join(work_dir,
osp.splitext(osp.basename(save_file))[0])
backend = get_backend(deploy_cfg).value
onnx_cfg = get_onnx_config(deploy_cfg)
opset_version = onnx_cfg.get('opset_version', 11)
input_names = onnx_cfg['input_names']
output_names = onnx_cfg['output_names']
axis_names = input_names + output_names
dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names)
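    # `dynamic_axes` maps input/output names to the axes left symbolic in the
    # exported graph, e.g. for a typical dynamic detection config (values are
    # illustrative, not read from the config here):
    #   {'input': {0: 'batch', 2: 'height', 3: 'width'},
    #    'dets': {0: 'batch', 1: 'num_dets'}}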
verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get(
'verbose', False)
keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs',
True)
optimize = onnx_cfg.get('optimize', False)
if backend == Backend.NCNN.value:
"""NCNN backend needs a precise blob counts, while using onnx optimizer
will merge duplicate initilizers without reference count."""
optimize = False
with no_mp():
export(
torch_model,
model_inputs,
input_metas=input_metas,
output_path_prefix=output_prefix,
backend=backend,
input_names=input_names,
output_names=output_names,
context_info=context_info,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs,
optimize=optimize)
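
    # `export` is expected to write the graph to f'{output_prefix}.onnx',
    # e.g. 'work_dir/fcos.onnx' for the docstring example above.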