Pytorch模型转成onnx并可视化

文章目录

  • 转换模型
    • 前提
    • 转换方法
  • 模型可视化
  • 可能出现的报错信息
    • ValueError: torch.nn.DataParallel is not supported by ONNX exporter, please use 'attribute' module to unwrap model from torch.nn.DataParallel. Try torch.onnx.export(model.module, ...)
    • RuntimeError: ONNX export failed on an operator with unrecognized namespace torchvision::roi_align. If you are trying to export a custom operator, make sure you registered it with the right domain and version.
  • 参考资料

本文介绍 pth 模型转为 onnx 模型、使用 onnx 模型进行可视化以及过程中可能出现的问题。

转换模型

前提

首先,你需要有一个自己的pytorch格式的模型,通常的后缀为.pth。
可以通过

save_path = './Mlp.pth'
torch.save(net.state_dict(), save_path)

进行保存。

可以通过

net.load_state_dict(torch.load('Mlp.pth'))

进行加载参数使用。

当然如果你只想可视化的话,不需要训练得到pth参数,只要有模型就可以。

转换方法

import torch.nn

model = MLP().cuda()	# 声明模型
model.load_state_dict(torch.load('Mlp.pth'))	# 加载参数文件(可以没有)
model.eval()

input_names = ['input']
output_names = ['output']
# 自己起名字

x = torch.randn(1,3,512,512,requires_grad=True)
# 这里要把握住网络的输入大小,如果模型是在gpu上进行训练,则将x变为
# x = torch.randn(1,3,128,128,requires_grad=True,device="cuda")

torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

如果你的模型需要传入两个参数的话,那就再声明一个 y 变量,和 x 一起以元组的方式 (x, y) 传入.export()方法中。

这样,你就得到了一个onnx模型。
过程中可能出现的问题见最后一章。

模型可视化

import netron
modelPath = "best.onnx"
netron.start(modelPath)
# 先安装netron模块(pip就可以)
# 这里加载的模型可以是torch,也可以是onnx

可视化会进入 http://localhost:8080/ 网址。
效果大致如下图:
Pytorch模型转成onnx并可视化_第1张图片

可能出现的报错信息

ValueError: torch.nn.DataParallel is not supported by ONNX exporter, please use ‘attribute’ module to unwrap model from torch.nn.DataParallel. Try torch.onnx.export(model.module, …)

这是因为你的模型使用了 DataParallel 包装。
只需要按照报错信息修改为如下

torch.onnx.export(model.module, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')

即可。

RuntimeError: ONNX export failed on an operator with unrecognized namespace torchvision::roi_align. If you are trying to export a custom operator, make sure you registered it with the right domain and version.

修改为

torch.onnx.export(model.module, (x, y), 'best.onnx', input_names=input_names, output_names=output_names, verbose='True', opset_version=11)

即添加了 opset_version 参数。

类似的报错都可以去查看一下官方文档的版本信息。(查看方法和文档网址见参考资料一章)

参考资料

模型部署入门教程(三):PyTorch 转 ONNX 详解
onnx/docs/Operators.md
Pytorch模型转onnx,模型可视化
UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9
onnx.export报警告:WARNING: The shape inference of prim::Constant type is missing…解决方法

你可能感兴趣的:(深度学习,pytorch,深度学习,神经网络,onnx,可视化)