原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)
使用深度学习开源框架Pytorch训练完网络模型后,在部署之前通常需要进行格式转换,例如地平线工具链模型转换目前仅支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将Pytorch格式的模型导出到ONNX格式的模型。
本文以Python3.6为例,涉及到的whl包及版本信息如下:
torch 1.10.2
onnx 1.8.0
onnxruntime 1.10.0
numpy 1.19.5
torch.onnx.export函数实现了Pytorch模型导出到ONNX模型,在pytorch1.10.2中,torch.onnx.export函数参数如下:
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, operator_export_type=None,
opset_version=None, _retain_param_name=None, do_constant_folding=True,
example_outputs=None, strip_doc_string=None, dynamic_axes=None,
keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
use_external_data_format=None):
大多数参数使用默认配置即可,下面对常用的几个参数进行介绍:
torch.onnx.export(
model, # 需要转换的网络模型
args, # ONNX模型输入,通常为 tuple 或 torch.Tensor
f, # ONNX模型导出路径
input_names=None, # 按顺序定义ONNX模型输入结点名称,格式为:list of str,若不指定,会使用默认名字
output_names=None, # 按顺序定义ONNX模型输出结点名称,格式为:list of str,若不指定,会使用默认名字
opset_version=11 # opset版本,地平线目前仅支持设置为 10 or 11
)
其它参数的介绍可参考官方torch.onnx.export()函数手册。
该节内容主要包括单输入网络构建、模型导出生成ONNX格式、导出的ONNX模型有效性验证三个部分。可直接运行下方代码得到对应的ONNX模型,欢迎参考代码中的注释进行理解。
import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime
# -----------------------------------#
# 定义一个简单的单输入网络
# -----------------------------------#
class MyNet(nn.Module):
def __init__(self, num_classes=10):
super(MyNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), # input[3, 28, 28] output[32, 28, 28]
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # output[64, 14, 14]
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2) # output[64, 7, 7]
)
self.fc = nn.Linear(64 * 7 * 7, num_classes)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, start_dim=1)
x = self.fc(x)
return x
# -----------------------------------#
# 导出ONNX模型函数
# -----------------------------------#
def model_convert_onnx(model, input_shape, output_path):
dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1])
input_names = ["input1"] # 导出的ONNX模型输入节点名称
output_names = ["output1"] # 导出的ONNX模型输出节点名称
torch.onnx.export(
model,
dummy_input,
output_path,
verbose=False, # 如果指定为True,在导出的ONNX中会有详细的导出过程信息description
keep_initializers_as_inputs=False, # 若为True,会出现需要warning消除的问题
opset_version=11, # 版本通常为10 or 11
input_names=input_names,
output_names=output_names,
)
if __name__ == '__main__':
model = MyNet()
# print(model)
# 建议将模型转成 eval 模式
model.eval()
# 网络模型的输入尺寸
input_shape = (28, 28)
# ONNX模型输出路径
output_path = './MyNet.onnx'
# 导出为ONNX模型
model_convert_onnx(model, input_shape, output_path)
print("model convert onnx finsh.")
# -----------------------------------#
# 复杂模型可以使用下面的方法进行简化
# -----------------------------------#
# import onnxsim
# MyNet_sim = onnxsim.simplify(onnx.load(output_path))
# onnx.save(MyNet_sim[0], "MyNet_sim.onnx")
# -----------------------------------------------------------------------#
# 第一轮ONNX模型有效性验证,用来检查模型是否满足 ONNX 标准
# 这一步是必要的,因为无论模型是否满足标准,ONNX 都允许使用 onnx.save 存储模型,
# 我们都不会希望生成一个不满足标准的模型~
# -----------------------------------------------------------------------#
onnx_model = onnx.load(output_path)
onnx.checker.check_model(onnx_model)
print("onnx model check_1 finsh.")
# ----------------------------------------------------------------#
# 第二轮ONNX模型有效性验证,用来验证ONNX模型与Pytorch模型的推理一致性
# ----------------------------------------------------------------#
# 随机初始化一个模型输入,注意输入分辨率
x = torch.randn(size=(1, 3, input_shape[0], input_shape[1]))
# torch模型推理
with torch.no_grad():
torch_out = model(x)
print(torch_out) # tensor([[-0.5728, 0.1695, ..., -0.3256, 1.1357, -0.4081]])
# print(type(torch_out)) #
# 初始化ONNX模型
ort_session = onnxruntime.InferenceSession(output_path)
# ONNX模型输入初始化
ort_inputs = {ort_session.get_inputs()[0].name: x.numpy()}
# ONNX模型推理
ort_outs = ort_session.run(None, ort_inputs)
# print(ort_outs) # [array([[-0.5727689 , 0.16947027, ..., -0.32555276, 1.13574252, -0.40812433]], dtype=float32)]
# print(type(ort_outs)) # ,里面是个numpy矩阵
# print(type(ort_outs[0])) #
ort_outs = ort_outs[0] # 把内部numpy矩阵取出来,这一步很有必要
# print(torch_out.numpy().shape) # (1, 10)
# print(ort_outs.shape) # (1, 10)
# ----------------------------------------------------------------#
# 比较实际值与期望值的差异,通过继续往下执行,不通过引发AssertionError
# 需要两个numpy输入
# ----------------------------------------------------------------#
np.testing.assert_allclose(torch_out.numpy(), ort_outs, rtol=1e-03, atol=1e-05)
print("onnx model check_2 finsh.")
该节内容主要包括多输入网络构建、模型导出生成ONNX格式、导出的ONNX模型有效性验证三个部分。可直接运行下方代码得到对应的ONNX模型,欢迎参考代码中的注释进行理解。
import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime
# -----------------------------------#
# 定义一个简单的双输入网络
# -----------------------------------#
class MyNet_multi_input(nn.Module):
def __init__(self, num_classes=10):
super(MyNet_multi_input, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1) # input[3, 28, 28] output[32, 14, 14]
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1) # input[1, 28, 28] output[32, 14, 14]
self.bn2 = nn.BatchNorm2d(16)
self.relu2 = nn.ReLU(inplace=True)
self.fc = nn.Linear(48 * 14 * 14, num_classes)
def forward(self, x, y):
x = self.relu1(self.bn1(self.conv1(x)))
y = self.relu2(self.bn2(self.conv2(y)))
z = torch.cat((x, y), 1)
z = torch.flatten(z, start_dim=1)
z = self.fc(z)
return z
# -----------------------------------#
# 导出ONNX模型函数
# -----------------------------------#
def multi_input_model_convert_onnx(model, input_shape, output_path):
dummy_input1 = torch.randn(1, 3, input_shape[0], input_shape[1])
dummy_input2 = torch.randn(1, 1, input_shape[0], input_shape[1])
input_names = ["input1", "input2"] # 导出的ONNX模型输入节点名称
output_names = ["output1"] # 导出的ONNX模型输出节点名称
torch.onnx.export(
model,
(dummy_input1, dummy_input2),
output_path,
verbose=False, # 如果指定为True,在导出的ONNX中会有详细的导出过程信息description
keep_initializers_as_inputs=False, # 若为True,会出现需要warning消除的问题
opset_version=11, # 版本通常为10 or 11
input_names=input_names,
output_names=output_names,
)
if __name__ == '__main__':
multi_input_model = MyNet_multi_input()
# print(multi_input_model)
# 建议将模型转成 eval 模式
multi_input_model.eval()
# 网络模型的输入尺寸
input_shape = (28, 28)
# ONNX模型输出路径
multi_input_model_output_path = './multi_input_model.onnx'
# 导出为ONNX模型
multi_input_model_convert_onnx(multi_input_model, input_shape, multi_input_model_output_path)
print("multi_input_model convert onnx finsh.")
# -----------------------------------#
# 复杂模型可以使用下面的方法进行简化
# -----------------------------------#
# import onnxsim
# multi_input_model_sim = onnxsim.simplify(onnx.load(multi_input_model_output_path))
# onnx.save(multi_input_model_sim[0], "multi_input_model_sim.onnx")
# -----------------------------------------------------------------------#
# 第一轮ONNX模型有效性验证,用来检查模型是否满足 ONNX 标准
# 这一步是必要的,因为无论模型是否满足标准,ONNX 都允许使用 onnx.save 存储模型,
# 我们都不会希望生成一个不满足标准的模型~
# -----------------------------------------------------------------------#
onnx_model = onnx.load(multi_input_model_output_path)
onnx.checker.check_model(multi_input_model_output_path)
print("onnx model check_1 finsh.")
# ----------------------------------------------------------------#
# 第二轮ONNX模型有效性验证,用来验证ONNX模型与Pytorch模型的推理一致性
# ----------------------------------------------------------------#
# 随机初始化一个模型输入,注意输入分辨率
x = torch.randn(size=(1, 3, input_shape[0], input_shape[1]))
y = torch.randn(size=(1, 1, input_shape[0], input_shape[1]))
# torch模型推理
with torch.no_grad():
torch_out = multi_input_model(x, y)
# print(torch_out) # tensor([[-0.5728, 0.1695, ..., -0.3256, 1.1357, -0.4081]])
# print(type(torch_out)) #
# 初始化ONNX模型
ort_session = onnxruntime.InferenceSession(multi_input_model_output_path)
# ONNX模型输入初始化
ort_inputs = {ort_session.get_inputs()[0].name: x.numpy(), ort_session.get_inputs()[1].name: y.numpy()}
# ONNX模型推理
ort_outs = ort_session.run(None, ort_inputs)
# print(ort_outs) # [array([[-0.5727689 , 0.16947027, ..., -0.32555276, 1.13574252, -0.40812433]], dtype=float32)]
# print(type(ort_outs)) # ,里面是个numpy矩阵
# print(type(ort_outs[0])) #
ort_outs = ort_outs[0] # 把内部numpy矩阵取出来,这一步很有必要
# print(torch_out.numpy().shape) # (1, 10)
# print(ort_outs.shape) # (1, 10)
# ----------------------------------------------------------------#
# 比较实际值与期望值的差异,通过继续往下执行,不通过引发AssertionError
# 需要两个numpy输入
# ----------------------------------------------------------------#
np.testing.assert_allclose(torch_out.numpy(), ort_outs, rtol=1e-03, atol=1e-05)
print("onnx model check_2 finsh.")
更多内容可参考 PyTorch官方导出ONNX模型教程。
导出成ONNX模型后,可以使用开源可视化工具Netron来查看网络结构及相关配置信息。Netron的使用方式主要分为两种,一种是使用在线网页版,另一种是下载安装程序。下面以在线网页版打开第4节中导出单输入ONNX模型为例,进行介绍。点击在线网页版链接,打开导出的ONNX模型,可视化效果为:
地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7。
地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时怎么办呢?
如果有条件修改代码重新导出的话,这是一种解决方案。另外一种可尝试的解决方案是直接修改ONNX模型的对应属性,代码示例如下:
import onnx
model = onnx.load("./MyNet.onnx")
model.ir_version = 6
model.opset_import[0].version = 10
onnx.save_model(model, "MyNetOutput.onnx")
注意:高版本向低版本切换时可能会出现问题,这里只是一种可尝试的解决方案。
使用Netron可视化MyNetoutput.onnx,如下图所示:
原文来自于地平线开发者社区,未来会持续发布深度学习、板端部署的相关优质文章与视频,如果文章对您有帮助,麻烦给点个赞,如果您有兴趣一起学习,欢迎点个关注:寻找永不遗憾(CSDN用户名)