ONNX动态输入和动态输出问题

记录一下最近遇到的ONNX动态输入问题

1. 一个tensor的动态输入数据

首先是使用到的onnx的torch.onnx.export()函数:

贴一下官方的代码示意地址:ONNX动态输入

#首先我们要有个tensor输入,比如网络的输入是batch_size*1*224*224
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
#torch_model是模型的实例化
torch_out = torch_model(x)
#下面是导出的主要函数
# Export the model
torch.onnx.export(torch_model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})
上面我们主要是设置dynamic_axes的相关属性,这个属性的Key是从input_names和output_names里面获取的,所以在这两个里面一定要有相关的属性值,否则会有warning。这里面的batch_size则为动态输入的值,当然我们也可以在外面设置dynamic的属性,比如下面:
dynamic_axes = {'input' : {0 : 'batch_size'},   
                                'output' : {0 : 'batch_size'}}
然后在外面将dynamic_axes=dynamic_axes赋值一下就OK了。

2.多个tensor的动态输入问题

那么以上以只有一个tensor输入的情况我们进行的操作。
下面我们说一下有多个动态的tensor输入的情况下如何进行相关的操作:

比如下面的:

pillar_x = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_y = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_z = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
pillar_i = torch.ones([1, 1, 9918, 100], dtype=torch.float32, device="cuda:0")
加入上面的tensor是网络的一个输入,比如9918这个数据是我们网络输入的时候可能发生变化的情况,那么我们就要将其输入成动态输入的模式,详情如下:
input = [pillar_x, pillar_y, pillar_z, pillar_i]
dynamic_axes = {'Pillar_input_pillar_x':{2:'pillar_num'},
                'Pillar_input_pillar_y': {2: 'pillar_num'},
                'Pillar_input_pillar_z': {2: 'pillar_num'},
                'Pillar_input_pillar_i': {2: 'pillar_num'},
                'output_loss1':{0:'batch_size'},
                'output_loss2': {0: 'batch_size'},
                'output_loss3': {0: 'batch_size'}}
torch.onnx.export(net,input,'test1.onnx',verbose=True,input_names=['Pillar_input_pillar_x','Pillar_input_pillar_y','Pillar_input_pillar_z',                                                                    'Pillar_input_pillar_i'],
                                                                output_names=['output_loss1','output_loss2','output_loss3'],dynamic_axes=dynamic_axes)
我们单拿一组数据来说:dynamic_axes字典里面的item就是我们要设置动态输入数据,key是我们要动态输入的某一项数据,val的值是这一项数据中的哪一维度要设置成动态的。
那么我们dynamic_axes中的数据是要在input_names和output_names找到的,其实这两项只是数据的别名,然后根据dynamic里面的名字找到我们需要设置的具体动态项,也就是我们可以设置动态的输入,也可以设置动态的输出。
有一点要强调的是input_names中的序列是和input输入的数据相对应的。

实际效果如下:
ONNX动态输入和动态输出问题_第1张图片
ONNX动态输入和动态输出问题_第2张图片
ONNX动态输入和动态输出问题_第3张图片
以上。
目前正在研究怎么使用trt进行动态的数据调用做inference工作,待补充。

你可能感兴趣的:(Python)