TensorFlow to ONNX
Converting to a PB file (Graph Freezing)
As noted above, the conventional TF export method stores the network definition and the weights in separate files, which is inconvenient for deployment. TensorFlow officially provides a Freeze Graph mechanism that packs all model information into a single *.pb file.
The official freeze_graph tool is usually added to the bin directory on your PATH when TensorFlow is installed. If it is missing, look for tensorflow/python/tools/freeze_graph.py in the TensorFlow source tree, or invoke it from the command line as a module.
Examples are shown below; if there are multiple output nodes, separate them with commas:
# 1. Invoke the tool directly
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True

# 2. Invoke it as a Python module
python -m tensorflow.python.tools.freeze_graph \
    --input_graph=my_checkpoint_dir/graphdef.pb \
    --input_binary=true \
    --output_node_names=output \
    --input_checkpoint=my_checkpoint_dir \
    --output_graph=tests/models/fc-layers/frozen.pb
One tricky point is that you need to know the exact names of the output nodes. My approach is to list all nodes via tf.get_default_graph().as_graph_def().node and then pick out the output node names from that list:
print([tensor.name for tensor in tf.get_default_graph().as_graph_def().node])
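If the graph is not already in memory, here is a minimal sketch (TF 1.x; the checkpoint paths are placeholders) that rebuilds it from a checkpoint's .meta file and lists the node names:

import tensorflow as tf

# Rebuild the graph from the .meta file saved next to the checkpoint
tf.train.import_meta_graph('ckpt/model.ckpt.meta')
# Output nodes usually appear near the end of this list
for node in tf.get_default_graph().as_graph_def().node:
    print(node.name)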
The package versions used in this article:
onnx 1.10.1
onnx-simplifier 0.3.6
onnxoptimizer 0.2.6
onnxruntime 1.8.1
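To reproduce this environment, the versions above can be pinned with pip (assuming the standard PyPI package names):

pip install onnx==1.10.1 onnx-simplifier==0.3.6 onnxoptimizer==0.2.6 onnxruntime==1.8.1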
A complete example of converting the TF version of SSD to ONNX:
https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/ConvertingSSDMobilenetToONNX.ipynb
The example used in this article is SuperPoint:
https://github.com/rpautrat/SuperPoint
MagicPoint trained on synthetic data: input (h, w, c) = (120, 160, 1)
MagicPoint exporting ground truth for the COCO data: input (h, w, c) = (240, 320, 1)
SuperPoint trained on the COCO data: input (h, w, c) = (240, 320, 1)
SuperPoint at inference time: input (h, w, c) = (480, 640, 1)
The project author provides saved models for the three stages of the network:
MagicPoint (synthetic) mp_synth-v11
MagicPoint (COCO) mp_synth-v11_ha1_trained
SuperPoint (COCO) sp_v6
By exporting to ONNX we have already determined their inputs and outputs:
mp_synth-v11
    input: image
    outputs: pred, logits, prob, prob_nms
mp_synth-v11_ha1_trained
    input: image
    outputs: pred, logits, prob, prob_nms
sp_v6
    input: image
    outputs: descriptors, prob_nms, prob, descriptors_raw, pred, logits
Conversion steps and references:
1. The model is in checkpoint format:
https://github.com/rpautrat/SuperPoint/tree/master/pretrained_models
1.5 If your model is in frozen_inference_graph.pb format and you do not know the input/output names, things get more complicated; we do not consider that case in detail here, but the graph_transforms README linked below (and the sketch after it) can help recover the names:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
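For completeness, the summarize_graph tool described in that README can report the likely inputs and outputs of a frozen graph. A sketch, assuming a TensorFlow source checkout built with bazel and a placeholder graph filename:

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=frozen_inference_graph.pb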
2. Convert the checkpoint format to the SavedModel format (pb):
https://github.com/rpautrat/SuperPoint/blob/master/superpoint/export_model.py
3. Convert the SavedModel format to ONNX:
https://github.com/onnx/tensorflow-onnx
Remember to choose an opset that suits your network.
See:
https://github.com/microsoft/onnxruntime/blob/v1.4.0/docs/OperatorKernels.md
https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md
python -m tf2onnx.convert --saved-model saved_models/model/ --output saved_models/model/saved_model.onnx --opset 11
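For reference, tf2onnx can also consume a frozen graphdef directly, but then the input and output tensor names must be spelled out explicitly (the names below are placeholders for your own model):

python -m tf2onnx.convert --graphdef frozen.pb --inputs image:0 --outputs prob_nms:0 --output model.onnx --opset 11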
4. After obtaining the ONNX file, you find that the network's inputs and outputs cannot be displayed properly.
5. Install onnx.
Print all of the network's inputs and outputs via onnx:
import onnx

name = "X.onnx"
onnx_file = name
model = onnx.load(onnx_file)
for m in range(len(model.graph.input)):
    print(model.graph.input[m].name)
print("\n")
for m in range(len(model.graph.output)):
    print(model.graph.output[m].name)
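The same graph proto also records shapes. A small sketch (continuing from the model loaded above) that prints each declared input shape; dynamic dimensions show up as dim_param strings instead of integers:

for inp in model.graph.input:
    dims = [d.dim_value if d.HasField('dim_value') else d.dim_param
            for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)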
6. Download onnx-simplifier:
https://github.com/daquexian/onnx-simplifier
7. Run onnx-simplifier with the input specified as (N, H, W, C) (note: data_format='channels_last'), then export the simplified ONNX; after that you can see the shape of every layer in the network:
import onnx
from onnxsim import simplify
onnx_file = "mp_synth-v11.onnx"
sim_onnx_path = "mp_synth-v11__simplify.onnx"
model = onnx.load(onnx_file)
onnx.checker.check_model(model)
# print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))
# change the shapes in the list to match your own model
model_simp, check = simplify(model, input_shapes={"image": [1, 120, 160, 1]})
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, sim_onnx_path)
print("Simplify onnx done!")
Reference:
https://blog.csdn.net/tangshopping/article/details/111874321
Some other references:
ONNX study notes - Zhihu
How to see each layer's shape in Netron - BokyLiu's blog, CSDN
[Model acceleration] PointPillars TensorRT acceleration experiment (2) - CSDN blog
Introduction to onnx2pytorch and the new onnx-simplifier - Zhihu
torch.onnx — PyTorch 1.9.1 documentation
GitHub - microsoft/onnxjs: ONNX.js: run ONNX models using JavaScript
GitHub - microsoft/onnxjs-demo: demos to show the capabilities of ONNX.js
Exporting the SuperPoint pb (SavedModel)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "7"
import yaml
import argparse
import logging
from pathlib import Path

logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
import tensorflow as tf  # noqa: E402
from superpoint.models import get_model  # noqa: E402


if __name__ == '__main__':
    with open("config.yml", 'r') as f:
        config = yaml.safe_load(f)
    config['model']['data_format'] = 'channels_last'

    def mkdir_os(path):
        if not os.path.exists(path):
            os.makedirs(path)

    export_root_dir = "onnx/"
    mkdir_os(export_root_dir)
    export_dir = "onnx/pb_model_onlytwo_NoneH640W480C1"
    checkpoint_path = "model"

    '''
    data_shape is consumed in superpoint/models/base_model.py:

    # Prediction network with feed_dict
    if self.data_shape is None:
        self.data_shape = {i: spec['shape'] for i, spec in self.input_spec.items()}
    self.pred_in = {i: tf.placeholder(spec['type'], shape=self.data_shape[i], name=i)
                    for i, spec in self.input_spec.items()}
    self._pred_graph(self.pred_in)
    '''
    with get_model(config['model']['name'])(
            data_shape={'image': [None, 640, 480, 1]},
            **config['model']) as net:
        net.load(str(checkpoint_path))

        # Export every output:
        # tf.saved_model.simple_save(
        #     net.sess,
        #     str(export_dir),
        #     inputs=net.pred_in,
        #     outputs=net.pred_out)

        # Available output names: logits, prob, descriptors_raw,
        # descriptors, prob_nms, pred
        tf.saved_model.simple_save(
            net.sess,
            str(export_dir),
            inputs=net.pred_in,
            outputs={"prob_nms": net.pred_out["prob_nms"],
                     "descriptors": net.pred_out["descriptors"]})
At this point the batch dimension is still fixed at 1, and it stays frozen throughout the subsequent ONNX pipeline. So when converting the pb to ONNX, see:
Dynamic Input Reshape Incorrect · Issue #1640 · onnx/tensorflow-onnx · GitHub
Dynamic Input Reshape Incorrect · Issue #8591 · microsoft/onnxruntime · GitHub
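One workaround discussed in those issues is to override the input shape at conversion time rather than afterwards: tf2onnx lets you append a shape to a tensor name in --inputs, where, to my knowledge, -1 marks a dynamic dimension. A sketch against a frozen graphdef (the tensor names are assumptions that depend on your export; whether the override combines with --saved-model varies by tf2onnx version):

python -m tf2onnx.convert --graphdef frozen.pb \
    --inputs image:0[-1,640,480,1] \
    --outputs prob_nms:0,descriptors:0 \
    --output model_dyn.onnx --opset 11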