onnx重写输入和输出的维度

onnx模型输入是静态的,比如是1x3x960x960,但是想把输入改成动态输入,相应的输出也得改成动态,以下代码可以修改onnx模型的维度:

import onnx
import onnx.checker
import onnx.utils
from onnx.tools import update_model_dims
 
 
model = onnx.load('infer_rec.onnx')
# 此处可以理解为获得了一个维度 “引用”,通过该 “引用“可以修改其对应的维度                                                                                          
dim_proto0 = model.graph.input[0].type.tensor_type.shape.dim[0]
dim_proto3 = model.graph.input[0].type.tensor_type.shape.dim[3]
# 将该维度赋值为字符串,其维度不再为和dummy_input绑定的值
dim_proto0.dim_param = '1'
dim_proto3.dim_param = 'width'
dim_proto_0 = model.graph.output[0].type.tensor_type.shape.dim[0]
dim_proto_1 = model.graph.output[0].type.tensor_type.shape.dim[1]
dim_proto_0.dim_param = '1'
dim_proto_1.dim_param = 'width'
onnx.save(model, 'infer_rec_dynamic.onnx')

这样修改以后,就变成动态输入输出的模型了。

修改前:

onnx重写输入和输出的维度_第1张图片

修改后: 

onnx重写输入和输出的维度_第2张图片

但是转onnx需要主要指定 --opset_version 的版本为11,官方的文档解释如下:

Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. Each dimension value of the output tensor is: output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input "sizes" is not specified.

Version

This version of the operator has been available since version 11 of the default ONNX operator set.

如果不指定版本为11,有可能onnx转其他推理框架的模型可能会失败,这个地方需要注意。

当然,如果是自己通过pytorch代码也可以直接重新生成onnx为动态输入或动态输出模型,生成过程如下:

torch.onnx.export(rec_model, rec_input, "crnn.onnx",

dynamic_axes={'input' : {3 : 'width'},'output':{1:'width'}},

do_constant_folding=False,training=False,export_params=True)

参考:pytorch模型转onnx后修改模型输入维度 - 知乎 (zhihu.com)

你可能感兴趣的:(深度学习,pytorch)