PyTorch模型转换为ONNX格式模型

测试PyTorch版本:1.2.0

测试代码:

import torch.onnx

model        = ******** # model类型:torch.nn.Module
dummy_input  = torch.randn(1, 3, 224, 224) # 模型输入维度为(1, 3, 224, 224)
input_names  = ["input"]
output_names = ["output"]
torch.onnx.export(model, dummy_input, "model_onnx.onnx", verbose=True,
                        input_names=input_names, output_names=output_names)

 

你可能感兴趣的:(pytorch,ONNX)