tensorflow获取模型节点名称及将.ckpt转为.pb文件

将 TensorFlow 的 ckpt 模型转为 pb 文件时，需要知道网络输出节点的名称：freeze 时只会保留输出节点及其依赖的节点，并把其中的变量替换为常量。如果不指定输出节点名称，程序就不知道该 freeze 哪些节点，也就没有办法保存模型。

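获取ckpt模型中保存的变量名称（注意：checkpoint 只保存变量，像 xxx/BiasAdd 这样的输出节点不会出现在这里；要查看完整的节点名称，需要先用 tf.train.import_meta_graph 导入 .meta 文件，再遍历 graph.as_graph_def().node）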

from tensorflow.python import pywrap_tensorflow
# Presumably the checkpoint prefix (e.g. model.ckpt-<step>) — confirm for your files.
checkpoint_path = 'model.ckpt-xxx'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)

# The checkpoint maps each saved variable name to its shape; only names are
# printed. Graph-only ops (e.g. 'xxx/BiasAdd') are not stored here.
var_to_shape_map = reader.get_variable_to_shape_map()
for name in var_to_shape_map:
    print("tensor_name: ", name)

在代码中新建模型并直接转为pb模型

  import tensorflow as tf
  
  def model(input):
      """Tiny demo network: conv -> batch norm -> separable conv -> conv.

      The last layer is explicitly named 'output' so its ops can be found by
      name (e.g. 'output/BiasAdd') when the graph is frozen.
      """
      features = tf.layers.conv2d(input, filters=32, kernel_size=3)
      # fused=False requests the non-fused batch-norm ops (presumably to keep
      # the exported graph simple — confirm).
      features = tf.layers.batch_normalization(features, fused=False)
      features = tf.layers.separable_conv2d(features, 32, 3)
      return tf.layers.conv2d(features, filters=32, kernel_size=3, name='output')
  
  input_node = tf.placeholder(tf.float32, [1, 480, 480, 3], name='image')
  pb = 'tftest.pb'
  # Comma-separated *node* names (no ':0' suffix) to keep when freezing.
  # The last layer is tf.layers.conv2d(..., name='output'), whose bias add is
  # the op 'output/BiasAdd'.
  output_node_names = 'output/BiasAdd'

  with tf.Session() as sess:
      # Build the network on the default graph, then initialize its variables
      # so there are values to bake into constants.
      _1 = model(input_node)
      sess.run(tf.global_variables_initializer())

      input_graph_def = tf.get_default_graph().as_graph_def()
      # Tip: to discover node names, iterate the GraphDef's nodes, e.g.
      #     for node in input_graph_def.node:
      #         print(node.name)

      # Replace every variable the output nodes depend on with a constant
      # holding its current value in `sess`.
      output_graph_def = tf.graph_util.convert_variables_to_constants(
          sess=sess,
          input_graph_def=input_graph_def,
          output_node_names=output_node_names.split(','),
      )

  # The frozen GraphDef is self-contained; write it out after the session closes.
  with tf.gfile.GFile(pb, 'wb') as f:
      f.write(output_graph_def.SerializeToString())

将已有的ckpt模型转为pb（需要 ckpt 对应的 .meta 文件来恢复图结构）

	import tensorflow as tf
	from tensorflow.python.framework import graph_util
	from tensorflow.python.platform import gfile
	
	def freeze_graph(ckpt, output_graph,
	                 output_node_names='head_neck_count/backbone/conv5_4_cpm_l1_upsample_output_cpm/BiasAdd'):
	    """Freeze an existing checkpoint into a self-contained .pb GraphDef.

	    Args:
	        ckpt: checkpoint prefix; the graph structure is read from
	            ckpt + '.meta' and the variable values are restored from ckpt.
	        output_graph: path of the .pb file to write.
	        output_node_names: comma-separated *node* names (not tensor names,
	            i.e. without a ':0' suffix). Only these nodes and what they
	            depend on are kept in the frozen graph.
	    """
	    # Strip stray spaces ("a, b") and ignore empty entries from trailing commas.
	    node_names = [name.strip() for name in output_node_names.split(',') if name.strip()]

	    # Import into a fresh graph so repeated calls in one process do not
	    # collide with nodes an earlier call left in the default graph.
	    with tf.Graph().as_default() as graph:
	        saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
	        input_graph_def = graph.as_graph_def()

	        with tf.Session(graph=graph) as sess:
	            saver.restore(sess, ckpt)
	            # Replace the variables needed by node_names with constants
	            # holding the restored values.
	            output_graph_def = graph_util.convert_variables_to_constants(
	                sess=sess,
	                input_graph_def=input_graph_def,
	                output_node_names=node_names,
	            )

	    with tf.gfile.GFile(output_graph, 'wb') as fw:
	        fw.write(output_graph_def.SerializeToString())
	    print('{} ops in the final graph.'.format(len(output_graph_def.node)))
	
	# Absolute paths (apparently from the author's machine); adjust before running.
	ckpt = '/home/ulsee/server/onestage_v1/model.ckpt-586450'  # checkpoint prefix
	pb = '/home/ulsee/work/estimator-headneck/pb/head_count_v2.pb'  # frozen output file
	
	if __name__ == '__main__':
	    freeze_graph(ckpt=ckpt, output_graph=pb)

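节点名称和张量名称的区别

  • 像 output 这样的是节点（操作，op）名称；convert_variables_to_constants 的 output_node_names 需要的是节点名称
  • 像 output:0 这样的是张量名称，表示节点 output 的第 0 个输出；get_tensor_by_name 和 feed_dict 需要的是张量名称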

以下是摘抄自某篇博客、使用pb模型进行预测的代码（其中 read_image、resize_height、resize_width 以及 numpy 的 np 需要另行定义或导入）:

def freeze_graph_test(pb_path, image_path):
    """Load a frozen .pb model and print its prediction for one image.

    Tensors are looked up and fed by *tensor* name (e.g. 'input:0'), not by
    node name. Relies on names defined elsewhere (not shown here): ``tf``,
    ``np``, ``read_image``, ``resize_height`` and ``resize_width``.

    :param pb_path: path to the frozen .pb file
    :param image_path: path to the test image
    :return: None; prints the raw network output and the predicted class id
    """
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")

        # Every sess.run must stay inside this block, because a closed
        # session cannot run anything. A frozen graph stores its weights as
        # constants, so there are no variables to initialize.
        with tf.Session() as sess:
            # Input tensors: 'input:0' is the image; 'keep_prob:0' is the
            # dropout keep probability (1.0 at test time); 'is_training:0' is
            # the training-mode flag.
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")

            # Output tensor (presumably the classifier logits, given the
            # softmax applied below).
            output_tensor = sess.graph.get_tensor_by_name(
                "InceptionV3/Logits/SpatialSqueeze:0")

            # Read the test image and add a leading batch dimension.
            im = read_image(image_path, resize_height, resize_width,
                            normalization=True)
            im = im[np.newaxis, :]

            out = sess.run(output_tensor,
                           feed_dict={input_image_tensor: im,
                                      input_keep_prob_tensor: 1.0,
                                      input_is_training_tensor: False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print("pre class_id:{}".format(sess.run(class_id)))

你可能感兴趣的:(tensorflow)