onnx模型输入输出维度修改

场景

我有一个onnx模型,需要转换为rknn格式。
但是转换的脚本总是加载onnx模型时就出错了。

 load_onnx: ValueError: Invalid input_shape = [0, 128, 128, 1] for input 0!

看来是输入的维度问题。通过netron查看,可以看到输入输出的信息
onnx模型输入输出维度修改_第1张图片
可以看到右侧inputs和outputs的信息,第一个维度都是无效的,所以我需要把他们固定为1,因为我本来也不需要批量推理。

修改onnx输入输出维度的代码

import onnx
onnx_model = onnx.load("./tf_face_landmark.onnx")

print("============input============")
print(onnx_model.graph.input)

print("============output============")
print(onnx_model.graph.output)


onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = '1'

onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = '1'

print("============new input============")
print(onnx_model.graph.input)

print("============new output============")
print(onnx_model.graph.output)

onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, 'face_landmark.onnx')
print("模型已保存")

代码运行输出(unk__207和unk__208应该是引起上述保存的原因,改了它们就好了)

# python3 change_onnx_shape.py
============input============
[name: "image_batch"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "unk__207"
      }
      dim {
        dim_value: 128
      }
      dim {
        dim_value: 128
      }
      dim {
        dim_value: 1
      }
    }
  }
}
]
============output============
[name: "Logits_out/output"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_param: "unk__208"
      }
      dim {
        dim_value: 136
      }
    }
  }
}
]
============new input============
[name: "image_batch"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 128
      }
      dim {
        dim_value: 128
      }
      dim {
        dim_value: 1
      }
    }
  }
}
]
============new output============
[name: "Logits_out/output"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 136
      }
    }
  }
}
]
模型已保存

效果

onnx模型输入输出维度修改_第2张图片

吐槽

费了半天劲搭环境,解决onnx转换过程的各种问题,最后发现rknn-toolkit2-v1.3不支持onnx的hardswish算子,白干了。

最后,在rknn-toolkit-v1.7环境下将ckpt固化为pb文件,拷贝到rknn-toolkit2-v1.3下,再转换为rknn成功。
折腾!

参考资料

onnx修改模型输出节点属性

将tensorflow 1.x & 2.x转化成onnx文件(以arcface-tf2人脸识别模型为例)

你可能感兴趣的:(深度学习论文笔记和实践,python,onnx,tensorflow)