1.掌握pytorch模型转换到onnx模型
2.顺利运行onnx模型
3.比对onnx模型和pytorch模型的输出结果
前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装
1 . pytorch 转 onnx
pytorch 转 onnx 只需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:
import torch
import torch.nn
import onnx
model = torch.load('best.pt')
model.eval()
input_names = ['input']
output_names = ['output']
x = torch.randn(1,3,32,32,requires_grad=True)
torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')
2 . 运行onnx模型
检查onnx模型,并使用onnxruntime运行。
import onnx
import onnxruntime as ort
model = onnx.load('best.onnx')
onnx.checker.check_model(model)
session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32) # 注意输入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)
outputs = session.run(None,input = { 'input' : x })
参数说明:
3.onnx模型输出与pytorch模型比对
import numpy as np
np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)
内容参考:
https://zhuanlan.zhihu.com/p/422290231