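以下是 PyTorch 官方教程（“Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime”）中导出 ONNX 动态输入（dynamic_axes）的示例代码：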
# First we need a dummy input tensor matching the network's input layout,
# e.g. batch_size x 1 x 224 x 224.
# NOTE(review): `batch_size` and `torch_model` must be defined earlier
# (not shown in this snippet).
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)

# torch_model is the instantiated network. Put it in inference mode before
# computing the reference output: the exporter traces the model in eval mode
# by default (training=TrainingMode.EVAL), so a training-mode forward pass
# (dropout, batch-norm statistics) would not be comparable with the output of
# the exported ONNX model.
torch_model.eval()
torch_out = torch_model(x)

# Export the model -- torch.onnx.export is the main entry point.
torch.onnx.export(
    torch_model,               # model being run
    x,                         # model input (or a tuple for multiple inputs)
    "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=10,          # the ONNX opset version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    # Variable-length axes: axis 0 of both 'input' and 'output' is given the
    # symbolic name 'batch_size', so the exported graph is not tied to the
    # batch size of the dummy input above.
    dynamic_axes={'input': {0: 'batch_size'},
                  'output': {0: 'batch_size'}},
)
# The dynamic-axes mapping can also be built on its own and then handed to
# torch.onnx.export as the keyword argument `dynamic_axes=dynamic_axes`.
# Keys are names from input_names / output_names; each value maps an axis
# index to a symbolic dimension name.
dynamic_axes = {'input': {0: 'batch_size'},
                'output': {0: 'batch_size'}}
只需像上面这样定义好 dynamic_axes 字典，并作为同名关键字参数（dynamic_axes=dynamic_axes）传给 torch.onnx.export 即可。下面是一个多输入的例子：
# Multi-input example: four pillar feature tensors. Axis 2 (size 9918 in this
# dummy input) is exported as the dynamic dimension 'pillar_num'.
pillar_x = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_y = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_z = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_i = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")

# Named `pillar_inputs` (not `input`) so the builtin input() is not shadowed.
# NOTE(review): kept as a list, as originally written. torch.onnx.export
# documents `args` as a tuple or tensor; a list may reach net.forward as a
# single argument instead of being unpacked -- confirm against net's forward
# signature before switching to a tuple.
pillar_inputs = [pillar_x, pillar_y, pillar_z, pillar_i]

# Each graph name is written exactly once, so input_names / output_names and
# the dynamic_axes keys cannot drift apart.
input_names = ['Pillar_input_pillar_x', 'Pillar_input_pillar_y',
               'Pillar_input_pillar_z', 'Pillar_input_pillar_i']
output_names = ['output_loss1', 'output_loss2', 'output_loss3']

dynamic_axes = {name: {2: 'pillar_num'} for name in input_names}
# NOTE(review): outputs mark axis 0 as dynamic ('batch_size') while the inputs
# keep axis 0 fixed at 1 -- confirm axis 0 of each output really varies.
dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})

torch.onnx.export(net, pillar_inputs, 'test1.onnx', verbose=True,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes=dynamic_axes)