将TensorFlow的ckpt模型转为pb文件时，需要知道网络的输出节点名称：如果不指定输出节点名称，程序就不知道该freeze哪些节点，也就无法保存模型。
获取ckpt模型中保存的变量名称（注意：ckpt只保存变量，像`output/BiasAdd`这样的运算节点不会出现在这里，需要从计算图中查看）
from tensorflow.python import pywrap_tensorflow
checkpoint_path = 'model.ckpt-xxx'

# Open the checkpoint and print the name of every variable stored in it.
# A checkpoint holds variable values, so computation ops such as
# 'output/BiasAdd' will not be listed here -- inspect the graph (e.g. the
# .meta file) to find those.
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for tensor_name in var_to_shape_map:
    print("tensor_name: ", tensor_name)
将模型转为pb文件
import tensorflow as tf
def model(input):
    """Build a small conv -> batch-norm -> separable-conv -> conv network.

    The final layer is explicitly named 'output', so its ops are created
    under the 'output/' name scope; the freezing code refers to the node
    'output/BiasAdd'.

    Args:
        input: the input tensor (here the 'image' placeholder).
            (Parameter name shadows the builtin ``input``; kept for callers.)

    Returns:
        The output tensor of the last conv layer.
    """
    conv_features = tf.layers.conv2d(input, filters=32, kernel_size=3)
    # fused=False selects the non-fused batch-norm implementation.
    normalized = tf.layers.batch_normalization(conv_features, fused=False)
    separable = tf.layers.separable_conv2d(normalized, 32, 3)
    return tf.layers.conv2d(separable, filters=32, kernel_size=3,
                            name='output')
# Demo: build the toy network, initialise its variables, and freeze the
# resulting graph (variables -> constants) into 'tftest.pb'.
input_node = tf.placeholder(tf.float32, [1, 480, 480, 3], name='image')
pb = 'tftest.pb'

with tf.Session() as sess:
    _1 = model(input_node)
    sess.run(tf.global_variables_initializer())

    # Node (op) names to keep, comma-separated: the bias-add op of the layer
    # named 'output' in model().
    output_node_names = 'output/BiasAdd'
    # sess.graph is the default graph the network was just built on.
    # To list every node name, print n.name for each n in
    # input_graph_def.node.
    input_graph_def = sess.graph.as_graph_def()
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess=sess,
        input_graph_def=input_graph_def,
        output_node_names=output_node_names.split(","),
    )

with tf.gfile.GFile(pb, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
def freeze_graph(
        ckpt,
        output_graph,
        output_node_names='head_neck_count/backbone/conv5_4_cpm_l1_upsample_output_cpm/BiasAdd'):
    """Freeze a checkpoint into a self-contained .pb GraphDef file.

    The graph structure is read from ``ckpt + '.meta'`` and the variable
    values from ``ckpt``. Variables needed by the output nodes are turned
    into constants, and the resulting GraphDef is written to
    ``output_graph``.

    Args:
        ckpt: checkpoint prefix, e.g. '.../model.ckpt-586450'; the
            '<ckpt>.meta' file must sit next to it.
        output_graph: path of the .pb file to write.
        output_node_names: node (op) names to keep -- e.g. 'scope/BiasAdd',
            not tensor names such as 'scope/BiasAdd:0'. Either a
            comma-separated string or an iterable of names. The default
            preserves the previously hard-coded node.
    """
    if isinstance(output_node_names, str):
        node_names = [name.strip() for name in output_node_names.split(',')
                      if name.strip()]
    else:
        node_names = list(output_node_names)

    # Import into a fresh graph instead of the process-wide default graph,
    # so that calling this function more than once does not stack several
    # copies of the model (with clashing node names) into one graph.
    graph = tf.Graph()
    with graph.as_default():
        saver = tf.train.import_meta_graph(ckpt + '.meta', clear_devices=True)
        input_graph_def = graph.as_graph_def()
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, ckpt)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                output_node_names=node_names,
            )

    with tf.gfile.GFile(output_graph, 'wb') as fw:
        fw.write(output_graph_def.SerializeToString())
    print('{} ops in the final graph.'.format(len(output_graph_def.node)))
# Concrete paths for this export: the checkpoint prefix to freeze and the
# .pb file to produce.
ckpt = '/home/ulsee/server/onestage_v1/model.ckpt-586450'
pb = '/home/ulsee/work/estimator-headneck/pb/head_count_v2.pb'

if __name__ == '__main__':
    freeze_graph(ckpt=ckpt, output_graph=pb)
节点名称和张量名称的区别：`output` 是节点名称，`output:0` 是张量名称（即节点 `output` 的第0个输出）。

以下是摘抄自博客的、使用pb模型进行预测的代码：
def freeze_graph_test(pb_path, image_path):
    """Load a frozen .pb model and run it on one test image.

    Prints the raw network output and the predicted class id.

    Args:
        pb_path: path to the frozen .pb file.
        image_path: path to the test image.

    NOTE(review): relies on ``read_image``, ``resize_height`` and
    ``resize_width`` being defined at module level; they are not shown in
    this snippet -- confirm they exist before running.
    """
    import numpy as np

    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()
        with open(pb_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
        # name="" imports the nodes without a name-scope prefix, so the
        # original names can be used in get_tensor_by_name below.
        tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            # A frozen graph holds constants only, so there are no variables
            # to initialize.

            # Look tensors up by *tensor* name ('<node name>:<output index>').
            # 'input:0' is the input image, 'keep_prob:0' the dropout keep
            # probability (1.0 at test time), 'is_training:0' the
            # training-mode flag.
            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
            # Output tensor of the network.
            output_tensor = sess.graph.get_tensor_by_name(
                "InceptionV3/Logits/SpatialSqueeze:0")

            # Read the test image and add a leading batch axis of size 1.
            im = read_image(image_path, resize_height, resize_width,
                            normalization=True)
            im = im[np.newaxis, :]

            # Feed by tensor (equivalently by tensor *name*, e.g.
            # {'input:0': im, ...}), never by op/node name.
            out = sess.run(output_tensor,
                           feed_dict={input_image_tensor: im,
                                      input_keep_prob_tensor: 1.0,
                                      input_is_training_tensor: False})
            print("out:{}".format(out))
            score = tf.nn.softmax(out, name='pre')
            class_id = tf.argmax(score, 1)
            print("pre class_id:{}".format(sess.run(class_id)))