The complete workflow for converting SuperPoint to ONNX

TensorFlow to ONNX - it610.com

Exporting a PB file (Graph Freezing)

As mentioned above, the usual TF export method stores the network definition and the weights in separate files, which is inconvenient for deployment. TensorFlow provides an official Freeze Graph mechanism that packs all the model information into a single *.pb file.

The official tool is freeze_graph; it is normally added to the bin directory on your PATH when TensorFlow is installed. If it is not there, you can find it in the TensorFlow source at tensorflow/python/tools/freeze_graph.py, or invoke it from the command line as a module.

Examples are shown below; if there are multiple output nodes, separate their names with commas:

# 1. Invoke the CLI tool directly
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True

# 2. Invoke it as a Python module
python -m tensorflow.python.tools.freeze_graph \
    --input_graph=my_checkpoint_dir/graphdef.pb \
    --input_binary=true \
    --output_node_names=output \
    --input_checkpoint=my_checkpoint_dir \
    --output_graph=tests/models/fc-layers/frozen.pb
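
If you prefer to stay inside Python, a minimal in-process sketch (assuming TF1.x, the example checkpoint paths, and the output node fc2/add from the commands above) uses tf.graph_util.convert_variables_to_constants:

import tensorflow as tf

# In-process alternative to the freeze_graph CLI (TF1-style API).
# Paths and the output node name follow the example above and are assumptions.
with tf.Session(graph=tf.Graph()) as sess:
    saver = tf.train.import_meta_graph("/home/mnist-tf/ckpt/model.ckpt.meta")
    saver.restore(sess, "/home/mnist-tf/ckpt/model.ckpt")
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph_def, output_node_names=["fc2/add"])
    with tf.gfile.GFile("/tmp/frozen_graph.pb", "wb") as f:
        f.write(frozen.SerializeToString())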

One annoying point is that you need to know the exact name of the output node. My approach is to list all nodes via tf.get_default_graph().as_graph_def().node and then look up the output node name among them:

print([node.name for node in tf.get_default_graph().as_graph_def().node])

Package versions used in this post:
onnx                  1.10.1
onnx-simplifier       0.3.6
onnxoptimizer         0.2.6
onnxruntime           1.8.1

A complete example of converting a TensorFlow SSD model to ONNX:
https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/ConvertingSSDMobilenetToONNX.ipynb

The example used in this post is SuperPoint:
https://github.com/rpautrat/SuperPoint


The input shapes used at the different stages are:
MagicPoint trained on synthetic data: input (h, w, c) = (120, 160, 1)
MagicPoint exporting pseudo ground truth for COCO: input (h, w, c) = (240, 320, 1)
SuperPoint trained on COCO: input (h, w, c) = (240, 320, 1)
SuperPoint at inference time: input (h, w, c) = (480, 640, 1)
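
For reference, a quick sketch of shaping a grayscale image into the (N, H, W, C) channels_last tensor used at inference time; the use of OpenCV and any value scaling are assumptions, so check how your exported graph normalizes its input:

import cv2          # assumption: OpenCV is available for loading/resizing
import numpy as np

# Build a (1, 480, 640, 1) float32 input for SuperPoint inference.
img = cv2.imread("example.jpg", cv2.IMREAD_GRAYSCALE)   # (H, W) uint8
img = cv2.resize(img, (640, 480))                       # cv2.resize takes (W, H)
x = img.astype(np.float32)                              # scale/normalize here if your graph expects it
x = x[np.newaxis, :, :, np.newaxis]                     # (1, 480, 640, 1), channels_last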


The project author provides saved models for the three stages of the pipeline:
MagicPoint (synthetic): mp_synth-v11
MagicPoint (COCO): mp_synth-v11_ha1_trained
SuperPoint (COCO): sp_v6
By exporting them to ONNX we have already confirmed their inputs and outputs:
mp_synth-v11
    input: image
    output: pred, logits, prob, prob_nms
mp_synth-v11_ha1_trained
    input: image
    output: pred, logits, prob, prob_nms
sp_v6
    input: image
    output: descriptors, prob_nms, prob, descriptors_raw, pred, logits

Steps, with reference links:
1. The pretrained model checkpoint format:
https://github.com/rpautrat/SuperPoint/tree/master/pretrained_models


1.5 If your model is a frozen_inference_graph.pb and you do not know the input/output names, things get more complicated; we do not cover that case here. See:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md
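
That said, if you do end up with only a frozen .pb, a rough way to recover candidate names is to parse the GraphDef directly and list its nodes (a sketch assuming TF1.x and a file named frozen_inference_graph.pb):

import tensorflow as tf

# Parse the frozen GraphDef from disk and list node names (no session needed).
graph_def = tf.GraphDef()
with tf.gfile.GFile("frozen_inference_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())
print([n.name for n in graph_def.node])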


2. Convert the checkpoint format to the SavedModel (pb) format; the full export script used in this post is listed in the "SuperPoint export to pb" section at the end:
https://github.com/rpautrat/SuperPoint/blob/master/superpoint/export_model.py


3. Convert the SavedModel format to ONNX:
https://github.com/onnx/tensorflow-onnx
Remember to choose an opset that suits your network's operators. See:
https://github.com/microsoft/onnxruntime/blob/v1.4.0/docs/OperatorKernels.md
https://github.com/onnx/tensorflow-onnx/blob/master/support_status.md

python -m tf2onnx.convert --saved-model saved_models/model/ --output saved_models/model/saved_model.onnx --opset 11
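
To double-check which opset actually ended up in the exported file, you can read it back with the onnx package (a small sketch; the file name is the one produced by the command above):

import onnx

# Print the opset domain(s) and version(s) recorded in the exported model.
model = onnx.load("saved_models/model/saved_model.onnx")
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)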


4. After obtaining the ONNX model, it turns out the network's inputs and outputs cannot be viewed properly.


5. Install onnx and use it to print all of the network's inputs and outputs:

import onnx

onnx_file = "X.onnx"
model = onnx.load(onnx_file)

# Print every graph input name, then every graph output name.
for inp in model.graph.input:
    print(inp.name)

print("\n")

for out in model.graph.output:
    print(out.name)


6. Install onnx-simplifier:
https://github.com/daquexian/onnx-simplifier


7. Run onnx-simplifier with the input shape specified as (N, H, W, C) (note: data_format='channels_last').
In the simplified ONNX you can see the shape of every layer:

import onnx
from onnxsim import simplify



onnx_file = "mp_synth-v11.onnx"
sim_onnx_path = "mp_synth-v11__simplify.onnx"

model = onnx.load(onnx_file)
onnx.checker.check_model(model)

# print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

# change the shape in the list to match your own model
model_simp, check = simplify(model, input_shapes={"image": [1, 120, 160, 1]})

onnx.save(model_simp, sim_onnx_path)

print("Simplify onnx done !")


Reference:
https://blog.csdn.net/tangshopping/article/details/111874321

Some other references:

ONNX study notes - Zhihu

How to see each layer's shape in Netron - BokyLiu's blog, CSDN

[Model acceleration] PointPillars TensorRT acceleration experiments (2) - CSDN

Introduction to onnx2pytorch and the new onnx-simplifier - Zhihu

torch.onnx — PyTorch 1.9.1 documentation

GitHub - microsoft/onnxjs: ONNX.js: run ONNX models using JavaScript

GitHub - microsoft/onnxjs-demo: demos to show the capabilities of ONNX.js

SuperPoint export to pb

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "7"
import yaml
import argparse
import logging
from pathlib import Path

logging.basicConfig(format='[%(asctime)s %(levelname)s] %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
import tensorflow as tf  # noqa: E402

from superpoint.models import get_model  # noqa: E402

if __name__ == '__main__':


    with open("config.yml", 'r') as f:
        config = yaml.load(f)
    config['model']['data_format'] = 'channels_last'

    def mkdir_os(path):
        if not os.path.exists(path):
            os.makedirs(path)

    export_root_dir = "onnx/"
    mkdir_os(export_root_dir)
    export_dir = "onnx/pb_model_onlytwo_NoneH640W480C1"
    checkpoint_path = "model"

    '''
    data_shape
    superpoint/models/base_model.py

    # Prediction network with feed_dict
    if self.data_shape is None:
        self.data_shape = {i: spec['shape'] for i, spec in self.input_spec.items()}
    self.pred_in = {i: tf.placeholder(spec['type'], shape=self.data_shape[i], name=i)
                    for i, spec in self.input_spec.items()}
    self._pred_graph(self.pred_in)
    '''
    with get_model(config['model']['name'])(
            data_shape={'image': [None, 640, 480, 1]},
            **config['model']) as net:

        net.load(str(checkpoint_path))

        # tf.saved_model.simple_save(
        #         net.sess,
        #         str(export_dir),
        #         inputs=net.pred_in,
        #         outputs=net.pred_out)

        # name:logits
        # name:prob
        # name:descriptors_raw
        # name:descriptors
        # name:prob_nms
        # name:pred
        tf.saved_model.simple_save(
                net.sess,
                str(export_dir),
                inputs=net.pred_in,
                outputs={"prob_nms": net.pred_out["prob_nms"],
                         "descriptors": net.pred_out["descriptors"],
                })
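
Before feeding this directory to tf2onnx, it can be worth confirming what the SavedModel signature actually contains (a sketch using the TF1 loader API, with the export_dir from the script above):

import tensorflow as tf

# Load the exported SavedModel and print the tensor names in its default
# serving signature, to confirm the inputs/outputs tf2onnx will see.
export_dir = "onnx/pb_model_onlytwo_NoneH640W480C1"
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], export_dir)
    sig = meta_graph.signature_def["serving_default"]
    print("inputs: ", {k: v.name for k, v in sig.inputs.items()})
    print("outputs:", {k: v.name for k, v in sig.outputs.items()})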

At this point the batch dimension is still 1,

and it gets frozen in place during the subsequent ONNX pipeline.

So when converting the pb to ONNX, see:

Dynamic Input Reshape Incorrect · Issue #1640 · onnx/tensorflow-onnx · GitHub

Dynamic Input Reshape Incorrect · Issue #8591 · microsoft/onnxruntime · GitHub
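
If you do need a dynamic batch afterwards, one partial workaround is to relabel the batch dimension of the ONNX graph input as symbolic. As the two issues above discuss, Reshape nodes whose shape constants were baked in with batch=1 may still misbehave, so treat this only as a sketch (the file names are hypothetical):

import onnx

# Relabel the batch dimension of the first graph input as a symbolic "batch".
# Caveat: shape constants already folded in with batch=1 are not touched here.
model = onnx.load("sp_v6.onnx")                     # hypothetical file name
model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch"
onnx.save(model, "sp_v6_dynamic_batch.onnx")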
