【模型转化】修改onnx节点属性

    在之前的一篇文章【模型转换】onnx转tensorrt报错:Attribute not found: axes中提到squeeze操作在不明确指定axes参数时onnx转tensorrt会报错。解决办法也很简单,因为我是先将pytorch转onnx,再转的tensorrt。pytorch网络结构中,给squeeze操作指定好axes参数再重新生成onnx即可。实际上我们还可以借助onnx的helper功能直接修改onnx节点属性来解决这个问题。

【模型转化】修改onnx节点属性_第1张图片

    例如,我要将卷积层的输出[1x64x40000x1]通过Squeeze操作进行维度压缩,剔除维数为1的维度,输出tensor[64x40000]给Transpose层。

import onnx                                                                        
model = onnx.load('my.onnx')
for node_id,node in enumerate(model.graph.node):
    print("######%s######" % node_id)                                                                                                                                                                                                                                 
    print(node)                                                                   

通过简单操作,我们可以很快定位到onnx中Squeeze的基本信息。

        
######44######
input: "383"
input: "vfe.pfn_layers.0.conv3.weight"
output: "384"
name: "Conv_249"
op_type: "Conv"
attribute {
  name: "dilations"
  ints: 1
  ints: 3
  type: INTS
}           
attribute {
  name: "group"
  i: 1  
  type: INT
}           
attribute {
  name: "kernel_shape"
  ints: 1
  ints: 11
  type: INTS
}           
attribute {
  name: "pads"
  ints: 0
  ints: 0
  ints: 0
  ints: 0
  type: INTS
}           
attribute {
  name: "strides"
  ints: 1
  ints: 1
  type: INTS
}           
            
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
            
######46######
input: "385"
output: "386"
name: "Transpose_251"
op_type: "Transpose"
attribute {
  name: "perm"
  ints: 1
  ints: 0
  type: INTS
}           
            

 可见,相比前后的Conv层和Transpose层,Squeeze层压根就没有属性信息。这里通过onnx的helper功能给该节点加上属性,指定要降维的轴。我这里因为只是要修改第1个Squeeze,所以可以如下修改:

model = onnx.load('my.onnx')
for node_id,node in enumerate(model.graph.node):                                                                                                                                                                                                                      
    if node.op_type == "Squeeze":
        attr = onnx.helper.make_attribute("axes",[0,3])
        node.attribute.insert(0,attr)
        break

再看:

            
######45######
input: "384"
output: "385"
name: "Squeeze_250"
op_type: "Squeeze"
attribute {
  name: "axes"
  ints: 0
  ints: 3
  type: INTS
}           
            

【模型转化】修改onnx节点属性_第2张图片

你可能感兴趣的:(模型转换,onnx,onnx)