获取原网络中的所有节点
在训练代码中定义好图之后加入以下代码:
for node in tf.get_default_graph().as_graph_def().node:
print(node.name)
主要是要查看最后一个节点的名字
模型转化
不再重新建图时, 使用tf.train.import_meta_graph
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:ckpt模型路径
:param output_graph: pb模型保存路径
'''
output_node_names = " " # 填入第一步得到的最后一个节点名
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) # 统计图中总的操作节点数
或者修改前传代码,使用tf.train.Saver()
在前传代码里,restore模型
restorer = tf.train.Saver(tf.global_variables())
ckpt = tf.train.get_checkpoint_state(' ') # 填入ckpt模型所在文件夹路径
model_path = ckpt.model_checkpoint_path # 读取checkpoint文件里的第一行
with tf.Session() as sess:
# Create a saver.
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
try:
restorer.restore(sess, model_path)
print(model_path.split('/')[-1] + " restored!")
except IOError:
print("checkpoints not found.")
output_graph_def = tf.graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=sess.graph_def, # 等于:sess.graph_def
output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(out_pb_path, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化输出
print("%d ops in the final graph." % len(output_graph_def.node))
# 统计图中总的操作节点数
从pb模型中读取节点
#coding:utf-8
import tensorflow as tf
from tensorflow.python.framework import graph_util
tf.reset_default_graph() # 重置计算图
output_graph_path = 'model/model_tfnew.pb'
with tf.Session() as sess:
tf.global_variables_initializer().run()
output_graph_def = tf.GraphDef()
# 获得默认的图
graph = tf.get_default_graph()
with open(output_graph_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
# 得到当前图有几个操作节点
print("%d ops in the final graph." % len(output_graph_def.node))
tensor_name = [tensor.name for tensor in output_graph_def.node]
print(tensor_name)
print('---------------------------')
# 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型
summaryWriter = tf.summary.FileWriter('log_graph/', graph)
for op in graph.get_operations():
# print出tensor的name和值
print(op.name, op.values())