*.pb,官方描述如下:
GraphDef(.pb)-a protobuf that represents the Tensorflow training and or computation graph. This contains operators, tensors, and variables definitions.
FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.
这里可以简单理解为*.pb文件有两种情况,一种是仅保存了计算图结构,不包含变量值,可以通过如下代码生成
tf.train.write_graph()
还有一种就是上面提到的FrozenGraphDef, 不仅包含计算图结构,还包含了训练产生的变量值,这类*.pb可以直接被加载用于推理计算,
1.将一个图直接保存为pb形式,这个在工作目录下保存了一个名为pb_file_pathmodel.pb的文件
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
pb_file_path = os.getcwd()
with tf.Session(graph=tf.Graph()) as sess:
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')#最终输出x*y+b, b的值默认是1
sess.run(tf.global_variables_initializer())
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['x','op_to_store'])
# 测试 OP
feed_dict = {x: 10, y: 3}
print(sess.run(op, feed_dict))
print(pb_file_path)
# 写入序列化的 PB 文件
with tf.gfile.GFile(pb_file_path + 'model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
# 输出
# INFO:tensorflow:Froze 1 variables.
# Converted 1 variables to const ops.
# 31
2.刚才保存好了一个pb文件,现在我们测试下他好不好用啊,这里有个弊端就是测试时,需要知道图中各个tensor的名字,你要是不知道这个图中各个tensor的名字,下面是没法测试的
from tensorflow.python.platform import gfile
import tensorflow as tf
import os
pb_file_path = os.getcwd()#用于返回当前工作目录。
print(pb_file_path + 'model.pb')#YOLO_tiny-mastermodel.pb
sess = tf.Session()
with gfile.FastGFile(pb_file_path + 'model.pb', 'rb') as f:
graph_def = tf.GraphDef()#构造一个图
graph_def.ParseFromString(f.read())#在当前图中打开YOLO_tiny-mastermodel.pb
sess.graph.as_default()#将当前图设置为默认图
tf.import_graph_def(graph_def, name='') # 导入计算图
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
print(sess.run('b:0'))
# 1
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
output= sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(output, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26
3. 为解决2中不知道节点名字的问题,可采用下面方法
3.1 先知道自己保存的ckpt文件都有那些节点
from tensorflow.python import pywrap_tensorflow
checkpoint_path = '/home/jerry/PY_project_wang/car_detect/YOLO_tiny-master/data/output/YOLO_tiny.ckpt-6'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
names=[]
for key in var_to_shape_map:
names.append(key)
print("tensor_name:", key)
print('len',len(key)) #tiny_yolo只有12个节点
3.2 先知道自己保存的pb文件都有那些节点
#r若你不知道原来的图中都有什么节点,可以先打印出来\
import tensorflow as tf
from tensorflow.python.platform import gfile
pb_path = './save/model.pb'
with tf.Session() as sess:
with gfile.FastGFile(pb_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
for i, n in enumerate(graph_def.node):
print("Name of the node -%s" % n.name)#
4. 将模型同时保存为ckpt和pb文件
import tensorflow as tf
from tensorflow.python.framework import graph_util
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
op = tf.add(xy, b, name='op_to_store')
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件,若当前目录没有save文件夹,会自行创建
#tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:
print("Model saved in file:", saver_path)
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
#convert_variables_to_constants()方法,可以固化模型结构,将计算图中的变量取值以常量的形式保存
with tf.gfile.FastGFile('save/model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())#将模型保存到save/model.pb文件,
checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表
model.ckpt.meta文件保存了Tensorflow计算图结构,可以理解为神经网络的网络结构,
ckpt.data是保存模型中每个变量的取值
model.ckpt.index保存了所有变量名
有了这三个文件,就能得到模型的信息并加载到其他项目中
5.从保存好的ckpt文件恢复网络,并保存为pb
从保存的ckpt模型读取所有的节点,可以发现ckpt不能保存所有的的变量
from tensorflow.python import pywrap_tensorflow
checkpoint_path = './save/model.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name:", key)#b,为什么只有b呢,因为只有B是变量
6.比较复杂的网络模型的保存,用同样的方法
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
def model(input):
net = tf.layers.conv2d(input, filters=32, kernel_size=3)
net = tf.layers.batch_normalization(net, fused=False)
net = tf.layers.separable_conv2d(net, 32, 3)
net = tf.layers.conv2d(net, filters=32, kernel_size=3, name='output')
return net
input_node = tf.placeholder(tf.float32, [1, 480, 480, 3], name='image')
pb = 'save/model3.pb'
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
model = model(input_node)
sess.run(tf.global_variables_initializer())
# saver_path = saver.save(sess, "save/model.ckpt") # 这里就不可以保存ckpt文件了,会报No variables to save
# tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:
output_node_names = 'output/BiasAdd'#为什么这里的名字与上面不一样?因为上面的output是层的名字,
# 这个层里面有很多的op,那怎么知道每个op的名字呢?不知道
input_graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(','))
with tf.gfile.GFile(pb, 'wb') as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点,38 ops in the final graph.
print(output_graph_def.node)
checkpoint_path = './save/model.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name:", key)#b,为什么只有b呢,因为只有B是变量