tensorflow模型ckpt如何查看输入输出节点,以及转uff模型


如何查看ckpt的输入输出节点:

def getinout(input_checkpoint):
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        file=open('./nodes.txt','a+')
    
        for n in tf.get_default_graph().as_graph_def().node:
            file.write(n.name + '\n')
        
        file.close()

n.name就是各个节点的名称,怎么找到呢?例子:
img_in---------------------》这个是输入节点
hh/conv1_1/conv2d/w/Initializer/truncated_normal/shape
hh/conv1_1/conv2d/w/Initializer/truncated_normal/mean
hh/res3_1/dw_bn/moving_mean/read
hh/fc4/w/Regularizer/l2_regularizer/L2Loss
hh/fc4/w/Regularizer/l2_regularizer
hh/fc4/fc4/Conv2D  ----------》这个就是输出节点
Reshape/shape  -----------》这个是神经网络输出后的操作步骤
Reshape
Reshape_1/shape
Reshape_1
Const
Mul
Mul_1
huber_loss/Sub

输入输出节点完全是跟你的网络定义顺序密切关联的,仔细核对就能找到。

将ckpt转换为pb:

def convt_ckpt_pb(checkpoint, metapath, output_graph):
    output_node_names = "one,two"
    # create a session
    sess = tf.Session()

    # import best model
    saver = tf.train.import_meta_graph(metapath + '.meta', clear_devices=True)
    saver.restore(sess, input_metapath) # variables

    # get graph definition
    gd = sess.graph.as_graph_def()

    # generate protobuf
    converted_graph_def = graph_util.convert_variables_to_constants(sess, gd, output_node_names.split(","))
    tf.train.write_graph(converted_graph_def, input_checkpoint + '\\', 'frozen_model.pb', as_text=False)
    tf.train.write_graph(converted_graph_def, input_checkpoint + '\\', 'frozen_model.pbtxt', as_text=True)

 

查看 输入输出节点:

cd c:/tensorflow131/source/tensorflow

bazel build tools/graph_transforms:summarize_graph

build完之后:

cd C:/_bazel_****/olqihtlt/execroot/org_tensorflow/bazel-out/x64_windows-opt/bin/tensorflow/tools/graph_transforms

summarize_graph --in_graph=D:/project/convert_model/frozen_model.pb

 

将pb模型转换为uff问题处理:
raise UffException("Transpose permutation has op " + str(tf_permutation_node.op) + ", expected Const. Only constant permuations are supported in UFF.")
uff.model.exceptions.UffException: Transpose permutation has op ConcatV2, expected Const. Only constant permuations are supported in UFF.怎么解决呢?
我们打印模型的输出节点.(在转换uff的时候打印)

import graphsurgeon as gs
import tensorflow as tf
import uff

if __name__ == "__main__":
  # USER DEFINED VALUES
  output_nodes = [“one","two"]
  input_node   = ["in","_in"]
  pb_file      = "./frozen_model.pb"
  uff_file     = "./frozen_model.uff"
  # END USER DEFINED VALUES

  # read tensorflow graph
  # NOTE: Make sure to freeze and optimize (remove training nodes, etc.)
  dynamic_graph = gs.DynamicGraph(pb_file)
  nodes=[n.name for n in dynamic_graph.as_graph_def().node]
  print(nodes)    # 在这一行打印
  ns={}
  for node in nodes:
    # replace LeakyRelu with default TRT plugin LReLU_TRT
    if "LeakyRelu" in node:
      ns[node]=gs.create_plugin_node(name=node,op="LReLU_TRT", negSlope=0.1)
    # replace Maximum with L2Norm_Helper_TRT max operation (CUDA's fmaxf)
    # if node == "orientation/l2_normalize/Maximum":
    if node == "embeddings/Maximum":
      ns[node]=gs.create_plugin_node(name=node,op="L2Norm_Helper_TRT",op_type=0,eps=1e-12)
    # replace Rsqrt with L2Norm_Helper_TRT max operation (CUDA's rsqrtf)
    if node == "embeddings/Rsqrt":
      ns[node]=gs.create_plugin_node(name=node,op="L2Norm_Helper_TRT",op_type=1)
  dynamic_graph.collapse_namespaces(ns)
  # write UFF to file
  uff_model = uff.from_tensorflow(dynamic_graph.as_graph_def(), output_nodes=output_nodes,
                                  output_filename=uff_file, text=False)


发现tf.layers.dense中会用到Transpose。那为什么这里会用到呢?因为,在dense的输入中,其shape类似于(?,1,1,128)。这才是问题的核心。
为了使得tf.layers.dense不调用Transpose,我们就需要把shape搞成(?,128)就行了。所以对于,tf.layers.dense的输入,在前面需要做一下tf.reshape(input, (-1,128))处理。
重新训练得到模型,按照上面的方法,将ckpt转为pb,然后将pb转成uff就可以解决问题了。

 

关注他,获取更多干货

tensorflow模型ckpt如何查看输入输出节点,以及转uff模型_第1张图片


 

你可能感兴趣的:(深度学习)