在使用PyTorch进行网络训练得到.pth模型文件后,我们可能会做一些模型部署和加速的工作。这里一般会涉及到将PyTorch模型转为ONNX模型的过程。
PyTorch自带了ONNX转换方法(torch.onnx.export),可以很方便地将仅包含通用算子的网络的PyTorch模型转为ONNX格式。
将PyTorch模型文件准备好,放在'./weights/torch.pth'路径下。
__author__ = 'TracelessLe'
import numpy as np
import torch
# Checkpoint file that is read with torch.load and fed to load_state_dict.
TORCH_WEIGHT_PATH = './weights/torch.pth'
# Destination file handed to torch2onnx for the exported ONNX model.
ONNX_MODEL_PATH = 'net_bs8_v1.onnx'
def get_numpy_data(batch_size=8, image_shape=(3, 128, 128)):
    """Build an all-ones float32 dummy batch.

    Args:
        batch_size: Number of samples in the batch (first dimension).
        image_shape: Per-sample shape appended after the batch dimension.

    Returns:
        A ``numpy.ndarray`` of ones with shape ``(batch_size, *image_shape)``
        and dtype ``float32``. With the defaults this is ``(8, 3, 128, 128)``.
    """
    return np.ones((batch_size, *image_shape), dtype=np.float32)
def get_torch_model():
    """Return the network instance whose weights will be loaded.

    This is a placeholder: replace the body with code that constructs and
    returns the network.

    Raises:
        NotImplementedError: Always, until the placeholder is replaced.
            Failing here is clearer than returning ``None`` and crashing
            later when the caller invokes a method on the result.
    """
    # Load Network Here
    raise NotImplementedError(
        "get_torch_model() is a placeholder; build and return the network here"
    )
def torch2onnx(img_input, onnx_model_path, device_id=0):
    """Load the trained weights into the network and export it to ONNX.

    Args:
        img_input: Dummy input data; it is converted with ``torch.Tensor``
            and passed to the exporter as the example input.
        onnx_model_path: Path of the ONNX file to write.
        device_id: CUDA device index; a negative value selects the CPU.
    """
    torch_model = get_torch_model()  # Network define
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded on CPU or another GPU.
    torch_weights = torch.load(TORCH_WEIGHT_PATH, map_location=device)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()

    dummy_img = torch.Tensor(img_input).to(device)
    torch.onnx.export(
        torch_model,
        (dummy_img,),  # explicit one-element tuple of positional model inputs
        onnx_model_path,
        input_names=['input'],
        output_names=['output'],
        export_params=True,
        verbose=True,
        do_constant_folding=False,  # or True
        opset_version=11,
    )
    print("Generate ONNX file over!")
if __name__ == "__main__":
    # Build the dummy batch first, then run the export with it.
    img_input = get_numpy_data()
    torch2onnx(img_input=img_input, onnx_model_path=ONNX_MODEL_PATH)
生成的ONNX结构可能还有简化的空间,可以使用onnx-simplifier工具进一步优化。
__author__ = 'TracelessLe'
import onnx
from onnxsim import simplify
# ONNX file to read and the file the simplified model is written to.
ONNX_MODEL_PATH = 'net_bs8_v1.onnx'
ONNX_SIM_MODEL_PATH = 'net_bs8_v1_simple.onnx'


if __name__ == "__main__":
    onnx_model = onnx.load(ONNX_MODEL_PATH)
    onnx_sim_model, check = simplify(onnx_model)
    # Raise explicitly instead of using ``assert``: assertions are stripped
    # when Python runs with -O, which would let an unvalidated model be saved.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    onnx.save(onnx_sim_model, ONNX_SIM_MODEL_PATH)
    print('ONNX file simplified!')
__author__ = 'TracelessLe'
import time
import torch
import onnxruntime
import numpy as np
def test_torch(img_input, device_id=0, loop=100):
    """Time PyTorch inference on a fixed input and return the last output.

    Args:
        img_input: Input batch; converted with ``torch.Tensor``. Its first
            dimension is used as the batch size for the FPS figure.
        device_id: CUDA device index; a negative value selects the CPU.
        loop: Number of timed forward passes.

    Returns:
        The model output of the last forward pass as a numpy array. When
        ``loop`` is 0 this is the output of the warm-up pass.
    """
    torch_model = get_torch_model()
    device = 'cpu' if device_id < 0 else f'cuda:{device_id}'
    torch_model.to(device)
    # map_location lets a checkpoint saved on another device load here.
    torch_weights = torch.load(TORCH_WEIGHT_PATH, map_location=device)
    torch_model.load_state_dict(torch_weights)
    torch_model.eval()

    dummy_img = torch.Tensor(img_input).to(device)
    batch_size = dummy_img.shape[0]

    # Untimed warm-up pass; its result also serves as the return value if
    # no timed iteration runs.
    with torch.no_grad():
        out_img = torch_model(dummy_img)
    out_img_numpy = out_img.detach().cpu().numpy()

    # perf_counter is monotonic and higher resolution than time.time.
    time1 = time.perf_counter()
    for _ in range(loop):
        time_bs1 = time.perf_counter()
        with torch.no_grad():
            out_img = torch_model(dummy_img)
        # Copying to host memory is included in the per-batch timing.
        out_img_numpy = out_img.detach().cpu().numpy()
        time_use_pt_bs = time.perf_counter() - time_bs1
        print(f'PyTorch use time {time_use_pt_bs} for bs{batch_size}')
    time_use_pt = time.perf_counter() - time1
    print(f'PyTorch use time {time_use_pt} for loop {loop}, FPS={loop*batch_size//time_use_pt}')
    return out_img_numpy
def test_onnx(inputs, loop=100):
    """Time ONNX Runtime inference on a fixed input and return the last output.

    Args:
        inputs: Numpy input batch; cast to float32 before inference. Its
            first dimension is used as the batch size for the FPS figure.
        loop: Number of timed ``sess.run`` calls.

    Returns:
        The list returned by ``sess.run`` for the last call. When ``loop``
        is 0 this is the result of the warm-up call.
    """
    inputs = inputs.astype(np.float32)
    print(onnxruntime.get_device())
    # Pass providers explicitly: GPU builds of ONNX Runtime >= 1.9 refuse to
    # create a session without them.
    sess = onnxruntime.InferenceSession(
        ONNX_SIM_MODEL_PATH,
        providers=onnxruntime.get_available_providers(),
    )
    # The feed dict never changes, so build it once outside the timed loop.
    feed = {sess.get_inputs()[0].name: inputs}
    batch_size = inputs.shape[0]

    # Untimed warm-up call, mirroring test_torch so both timings exclude
    # first-run initialisation.
    out_ort_img = sess.run(None, feed)

    time1 = time.perf_counter()
    for _ in range(loop):
        time_bs1 = time.perf_counter()
        out_ort_img = sess.run(None, feed)
        time_use_onnx_bs = time.perf_counter() - time_bs1
        print(f'ONNX use time {time_use_onnx_bs} for bs{batch_size}')
    time_use_onnx = time.perf_counter() - time1
    print(f'ONNX use time {time_use_onnx} for loop {loop}, FPS={loop*batch_size//time_use_onnx}')
    return out_ort_img
if __name__ == "__main__":
    # Run both backends once on the same dummy batch and compare outputs.
    img_input = get_numpy_data()
    out_ort_img = test_onnx(img_input, loop=1)[0]
    out_img_numpy = test_torch(img_input, loop=1)
    # Mean squared error between the ONNX Runtime and PyTorch results.
    diff = np.subtract(out_ort_img, out_img_numpy)
    mse = np.square(diff).mean()
    print('mse between pytorch and onnx result: ', mse)
可以使用Netron可视化PyTorch、ONNX等模型结构。
通过Netron打开转换得到的ONNX以及简化后的ONNX文件,可以对比查看网络结构的变化。
(1)do_constant_folding
在使用torch.onnx.export时可能会报错(与do_constant_folding相关),此时可以将do_constant_folding=True改为do_constant_folding=False。
(2)自定义算子的转换
PyTorch转ONNX目前仅支持一些通用算子(见PyTorch Doc),自定义的算子在转出时会报错。可以通过改写一部分算子或者使用功能相近的通用算子替代。
(3)PyTorch转ONNX的方式
目前PyTorch转ONNX有两种方式,除了主流的torch.onnx.export外,还有torch.jit.trace。但是后一种还处于实验阶段,可能会遇到较多问题。
本文为原创文章,独家发布在blog.csdn.net/TracelessLe。未经个人允许不得转载。如需帮助请email至[email protected]。
[1] ONNX | Home
[2] torch.onnx — PyTorch 1.9.0 documentation
[3] (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 1.9.0+cu102 documentation
[4] daquexian/onnx-simplifier: Simplify your onnx model
[5] lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models