pytorch pth 模型转 onnx模型,并验证结果正确性

pytorch 模型部署很重要的一步是转存pth模型为ONNX,本文记录方法。

转存 onnx

  • 建立自己的pytorch模型,并加载权重
model = create_model(num_classes=2)
model.load_state_dict(load(model_path, map_location='cpu')["model"])
  • 转存onnx文件
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')
torch.onnx._export(model, dummy_input, "faster_rcnn.onnx", verbose=True, opset_version=11)

将模型保存在了当前目录的 faster_rcnn.onnx文件内

验证 onnx 有效性

  • 安装 onnxruntime
pip install onnxruntime
  • 加载onnx模型并测试
import onnxruntime
from onnxruntime.datasets import get_example

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# 测试数据
dummy_input = torch.randn(1, 3, 256, 256, device='cpu')

example_model = get_example(<absolute_root_to_your_onnx_model_file>)
# netron.start(example_model) 使用 netron python 包可视化网络
sess = onnxruntime.InferenceSession(example_model)

# onnx 网络输出
onnx_out = sess.run(None, {<input_layer_name_of_your_network>: to_numpy(dummy_input)})
print(onnx_out)

model.eval()
with torch.no_grad():
    # pytorch model 网络输出
    torch_out = model(dummy_input)
	print(torch_out)
  • 输出:
onnx_out
[array([[  0.       ,  93.246    , 228.95842  , 256.       ],
       [  0.       ,   2.6370468, 209.39705  , 148.17822  ]],
      dtype=float32), array([1, 1], dtype=int64), array([0.1501071 , 0.07568519], dtype=float32)]

torch_out
[{'boxes': tensor([[  0.0000,  93.2459, 228.9584, 256.0000],
        [  0.0000,   2.6370, 209.3971, 148.1782]]), 'labels': tensor([1, 1]), 'scores': tensor([0.1501, 0.0757])}]

获取自己网络输入层名称

  • 有时对网络不熟悉的情况下不清楚模型输入层的名称,可以使用Netron可视化自己的网络,获取输入层名称,喂入onnx的sess中。

注意 !!!

  • pytorch 模型在转 ONNX 模型的过程中,使用的导出器是一个基于轨迹的导出器,这意味着它执行时需要运行一次模型,然后导出实际参与运算的运算符. 这也意味着, 如果你的模型是动态的,例如,改变一些依赖于输入数据的操作,这时的导出结果是不准确的.同样,一 个轨迹可能只对一个具体的输入尺寸有效 (这是为什么我们在轨迹中需要有明确的输入的原因之一.) 我们建议检查 模型的轨迹,确保被追踪的运算符是合理的. ——— pytorch 文档

  • 也就是说,如果网络模块中存在 if… else… 类似的分支,在生成ONNX模型时会依据所使用的初始数据来选择其中某一个分支,这样所生成的ONNX模型仅会保留这一个分支的结构,在原始pytorch模型中的其他逻辑能力在该模型中不复存在。

参考资料

  • https://pytorch.apachecn.org/docs/0.3/onnx.html
  • https://www.jianshu.com/p/5a0a09fbdeba

你可能感兴趣的:(深度学习,pytorch,深度学习)