The conversion is done mainly with the `torch.onnx.export()` method.
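For the UNet implementation, see:
Link: https://blog.csdn.net/weixin_44791964/article/details/108866828
That blogger explains it in great detail, and there is also a step-by-step video walkthrough on Bilibili!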
```python
import onnx
import torch
import torch.onnx
from unet import Unet  # Unet class from the UNet implementation referenced above

unet = Unet()
print(unet)

# Load the trained weights onto the CPU
state_dict = torch.load('D:/PycharmProjects/test/unet-pytorch-main/model_data/Epoch19-Total_Loss0.5564-Val_Loss0.5045.pth', map_location='cpu')
unet.net.load_state_dict(state_dict)
unet.net.eval()  # switch to inference mode before exporting
print(unet.net)
print(torch.__version__)

# batch_size: 1, model_image_size: 3*512*512
dummy_input = torch.randn(1, 3, 512, 512)  # your model's input, NCHW
# Key argument: verbose=True prints a human-readable representation of the network during export
torch.onnx.export(unet.net, dummy_input, 'unetmodel.onnx', verbose=True, opset_version=11)

onnx_model = onnx.load('./unetmodel.onnx')  # load the ONNX model
onnx.checker.check_model(onnx_model)  # validate the ONNX model
print(onnx.helper.printable_graph(onnx_model.graph))  # print a human-readable graph
```
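Note: export error
This is the problem I hit the first time I ran the code, before setting `opset_version` (so the default opset was used).
Full error message: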
```text
D:\Anaconda3\envs\superglue\lib\site-packages\torch\nn\functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ..\c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
D:\Anaconda3\envs\superglue\lib\site-packages\torch\onnx\symbolic_helper.py:375: UserWarning: You are trying to export the model with onnx:Upsample for ONNX opset version 9. This operator might cause results to not match the expected results by PyTorch.
ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. Attributes to determine how to transform the input were added in onnx:Resize in opset 11 to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).
We recommend using opset 11 and above for models using this operator.
  "" + str(_export_onnx_opset_version) + ". "
D:\Anaconda3\envs\superglue\lib\site-packages\torch\onnx\symbolic_helper.py:243: UserWarning: ONNX export failed on upsample_bilinear2d because align_corners == True not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
```
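Cause:
The ONNX opset version used for the conversion was too low. At opset 9, PyTorch's ONNX exporter cannot export `upsample_bilinear2d` with `align_corners=True`. The Resize attributes that reproduce PyTorch's interpolation behavior, such as `coordinate_transformation_mode`, were only added in opset 11. In addition, the source code shows that `torch.onnx.export` uses `opset_version=9` by default (in the PyTorch version used here).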
Fix:
Set `opset_version` to 11 in the `export` call:
```python
torch.onnx.export(unet.net, dummy_input, 'unetmodel.onnx', verbose=True, opset_version=11)
```