将 PyTorch 训练得到的 .pth 模型导出为 ONNX，再简化 ONNX 模型

import torch

import onnx
from onnxsim import simplify

# Export a pickled PyTorch model to ONNX, then simplify the ONNX graph in place.

# Run the export on the first GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the serialized PyTorch model.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only use it on
# checkpoints from a trusted source.
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU machine can still be loaded on a CPU-only host.
model = torch.load(
    "/home/shangzaixing/code/LaneNet-PyTorch-RNN/pytorch-crnn.pth",
    map_location=device,
)

# The model and the dummy input must live on the same device, otherwise the
# tracing forward pass inside torch.onnx.export fails with a device mismatch.
model.to(device)

# Batch size of the dummy input used for tracing.
batch_size = 1

# Per-sample input shape; change this to match your own model's input.
input_shape = (1, 1024, 1024)

# Set the model to inference mode before tracing.
model.eval()

# Random dummy tensor of shape (batch_size, *input_shape), placed on `device`.
x = torch.randn(batch_size, *input_shape)
x = x.to(device)

# Destination ONNX file.
export_onnx_file = "/home/shangzaixing/code/LaneNet-PyTorch-RNN/pytorch-crnn.onnx"

torch.onnx.export(
    model,
    x,
    export_onnx_file,
    opset_version=10,
    do_constant_folding=True,  # fold constant sub-graphs at export time
    input_names=["input"],  # name of the graph input
    output_names=["output1"],  # name of the graph output
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    # Keys must match the names declared in input_names / output_names;
    # a key that matches neither is ignored by the exporter, which would
    # leave the output batch dimension fixed.
    dynamic_axes={
        "input": {0: "batch_size"},
        "output1": {0: "batch_size"},
    },
)

# Reload the exported graph and simplify it with onnx-simplifier.
onnx_model = onnx.load(export_onnx_file)
model_simp, check = simplify(onnx_model)

# Explicit raise instead of `assert`, so the validation is not stripped when
# Python runs with -O; the exception type and message are unchanged.
if not check:
    raise AssertionError("faild!")

# Overwrite the exported file with the simplified model.
onnx.save(model_simp, export_onnx_file)

print("sucess finished")

你可能感兴趣的：(深度学习, python, pytorch)