Pytorch模型转ONNX

参考https://blog.csdn.net/qq_37546267/article/details/106767640 利用下面代码将pth模型转换为onnx

import torch
from torch.autograd import Variable
import onnx
print(torch.__version__)
# torch  -->  onnx
input_name = ['input']
output_name = ['output']
input = Variable(torch.randn(1, 3, 224, 224)).cuda()
# model = torchvision.models.resnet50(pretrained=True).cuda()
model = torch.load('resnet50.pth', map_location="cuda:0")
torch.onnx.export(model, input, 'resnet50.onnx', input_names=input_name, output_names=output_name, verbose=True)
# 模型可视化
# netron.start('resnet50.onnx')

转换时出现 AttributeError: 'dict' object has no attribute 'training' 的问题。

通过查看https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html 关于如何使用pytroch到处到onnx,和https://pytorch.org/tutorials/beginner/saving_loading_models.html 关于torch.save和torch.load的使用。找到了之前代码中使用的是

torch.save(model.state_dict(), PATH)

这种保存的是模型的字典。

model = torch.load(PATH)

再使用torch.load时,会报错。

所以,在保存模型是,使用:

torch.save(model, PATH)

再运行最上面的转换代码。可以正常工作。

你可能感兴趣的:(大数据)