Install onnx
conda install onnx -c conda-forge
Install onnxruntime
pip install onnxruntime
Download a model to try it out
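import torch
import torchvision.models as models
# download pretrained weights (newer torchvision prefers weights=models.ResNet50_Weights.DEFAULT over pretrained=True)
model = models.resnet50(pretrained=True)
model.eval()  # inference mode, so BatchNorm/Dropout don't behave as in training
# save as .pth
torch.save(model, 'resnet50.pth')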
Save as ONNX
data = torch.rand(1,3,224,224)
torch.onnx.export(model, data, 'resnet50.onnx')
For reference, the full signature of torch.onnx.export as documented in older PyTorch 1.x releases (newer releases have dropped some of these parameters):
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=False, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None)
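Load the .pth and the .onnx model
import onnxruntime
# PyTorch model (on PyTorch >= 2.6, pass weights_only=False to unpickle a whole model)
torch_model = torch.load('resnet50.pth')
torch_model.eval()
# ONNX model
onnx_model = onnxruntime.InferenceSession('resnet50.onnx', providers=['CPUExecutionProvider'])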
Build the input data
import numpy as np
from timeit import timeit
data = np.random.rand(1,3,224,224).astype(np.float32)
torch_data = torch.from_numpy(data)
# PyTorch (.pth) inference
def torch_inf():
    torch_model(torch_data)
# ONNX Runtime inference
def onnx_inf():
    onnx_model.run(None, {
        onnx_model.get_inputs()[0].name: data
    })
# number of iterations
n = 100
torch_time = timeit(torch_inf, number=n)/n  # 0.139 s per run
onnx_time = timeit(onnx_inf, number=n)/n    # 0.0257 s per run
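So on CPU, onnxruntime inference is faster than PyTorch: here roughly 5x (0.0257 s vs. 0.139 s per run). I later tested RetinaFace as well, and inference time dropped by about 10x. The workstation's CPU is a Xeon E5-2678 v3.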
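For reference, the signature of InferenceSession.run is run(output_names, input_feed, run_options=None). Passing None as output_names returns every output; input_feed is a dict mapping input names to numpy arrays.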
When running torch.onnx.export on a network that uses F.interpolate, you get this warning: UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9. This operator might cause results to not match the expected results by PyTorch. ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode). We recommend using opset 11 and above for models using this operator.
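In short: if the model upsamples and opset_version is left at its default of 9 (in the PyTorch version used here), the exported model's results may not match PyTorch's.
Solution:
Pass opset_version=11 to torch.onnx.export, and that's it.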
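When calling onnxruntime's run, passing a torch tensor as the input raises RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'input.1'.
In other words, it wants the input as numpy (or a dictionary).
But the data passed to torch.onnx.export was a tensor!
Oh well, not worth agonizing over: export only uses that tensor to trace the model inside PyTorch, whereas onnxruntime is a separate runtime that knows nothing about torch tensors.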
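A minimal sketch of the conversion (to_numpy is a hypothetical helper name). It also covers tensors that require grad, where calling .numpy() directly fails, and tensors living on the GPU.

```python
import numpy as np
import torch


def to_numpy(t: torch.Tensor) -> np.ndarray:
    # detach() drops autograd history, cpu() moves GPU tensors to host memory,
    # astype makes sure a float model input receives float32
    return t.detach().cpu().numpy().astype(np.float32, copy=False)


tensor = torch.rand(1, 3, 224, 224, requires_grad=True)
arr = to_numpy(tensor)
print(type(arr), arr.dtype, arr.shape)
```

The result can then be fed as onnx_model.run(None, {onnx_model.get_inputs()[0].name: to_numpy(tensor)}).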
Just letterbox the input image, arrange it as [1, 3, 640, 640], convert it to numpy, and you're done.
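A minimal letterbox sketch in plain numpy, so it runs without OpenCV: it scales the image to fit 640x640 while keeping the aspect ratio, pads the rest with gray, then converts it to a [1, 3, 640, 640] float32 array. Assumptions worth flagging: nearest-neighbour resizing stands in for the usual cv2.resize, and the pad value 114 and the [0, 1] scaling follow the common YOLOv5 convention; your model may also expect RGB ordering or mean/std normalization. letterbox and to_model_input are hypothetical names.

```python
import numpy as np


def letterbox(img: np.ndarray, new_size: int = 640, pad_value: int = 114) -> np.ndarray:
    """Fit an HxWx3 image into new_size x new_size, keeping aspect ratio, and pad the rest."""
    h, w = img.shape[:2]
    scale = min(new_size / h, new_size / w)
    nh, nw = int(round(h * scale)), int(round(w * scale))

    # Nearest-neighbour resize: map each output pixel back to a source pixel
    rows = np.minimum((np.arange(nh) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(nw) / scale).astype(int), w - 1)
    resized = img[rows][:, cols]

    # Centre the resized image on a gray canvas
    canvas = np.full((new_size, new_size, 3), pad_value, dtype=img.dtype)
    top, left = (new_size - nh) // 2, (new_size - nw) // 2
    canvas[top:top + nh, left:left + nw] = resized
    return canvas


def to_model_input(img: np.ndarray) -> np.ndarray:
    """HWC uint8 -> [1, 3, H, W] float32 scaled to [0, 1]."""
    chw = img.transpose(2, 0, 1).astype(np.float32) / 255.0
    return np.ascontiguousarray(chw[np.newaxis])


image = np.random.randint(0, 256, size=(480, 854, 3), dtype=np.uint8)  # fake frame
boxed = letterbox(image)
blob = to_model_input(boxed)
print(blob.shape, blob.dtype)  # (1, 3, 640, 640) float32
```

In a real detector you would also return scale and the top/left offsets, so predicted boxes can be mapped back onto the original image.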
To be continued...