tensorflow保存模型为pb文件的各种方式

*.pb,官方描述如下:

GraphDef(.pb)-a protobuf that represents the Tensorflow training and or computation graph. This contains operators, tensors, and variables definitions.

FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint.

这里可以简单理解为*.pb文件有两种情况,一种是仅保存了计算图结构,不包含变量值,可以通过如下代码生成

tf.train.write_graph()

还有一种就是上面提到的FrozenGraphDef, 不仅包含计算图结构,还包含了训练产生的变量值,这类*.pb可以直接被加载用于推理计算,

1.将一个图直接保存为pb形式,这个在工作目录下保存了一个名为pb_file_pathmodel.pb的文件

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')#最终输出x*y+b, b的值默认是1

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['x','op_to_store'])

    # 测试 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))
    print(pb_file_path)

    # 写入序列化的 PB 文件
    with tf.gfile.GFile(pb_file_path + 'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())


    # 输出
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31

2.刚才保存好了一个pb文件,现在我们测试下他好不好用啊,这里有个弊端就是测试时,需要知道图中各个tensor的名字,你要是不知道这个图中各个tensor的名字,下面是没法测试的

from tensorflow.python.platform import gfile
import tensorflow as tf
import os

pb_file_path = os.getcwd()#用于返回当前工作目录。
print(pb_file_path + 'model.pb')#YOLO_tiny-mastermodel.pb
sess = tf.Session()
with gfile.FastGFile(pb_file_path + 'model.pb', 'rb') as f:
    graph_def = tf.GraphDef()#构造一个图
    graph_def.ParseFromString(f.read())#在当前图中打开YOLO_tiny-mastermodel.pb
    sess.graph.as_default()#将当前图设置为默认图
    tf.import_graph_def(graph_def, name='')  # 导入计算图

# 需要有一个初始化的过程    
sess.run(tf.global_variables_initializer())

# 需要先复原变量
print(sess.run('b:0'))
# 1

# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

output= sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(output, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

3. 为解决2中不知道节点名字的问题,可采用下面方法

   3.1 先知道自己保存的ckpt文件都有那些节点

from tensorflow.python import pywrap_tensorflow
checkpoint_path = '/home/jerry/PY_project_wang/car_detect/YOLO_tiny-master/data/output/YOLO_tiny.ckpt-6'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
names=[]
for key in var_to_shape_map:
    names.append(key)
    print("tensor_name:", key)
print('len',len(key)) #tiny_yolo只有12个节点

3.2  先知道自己保存的pb文件都有那些节点

 #r若你不知道原来的图中都有什么节点,可以先打印出来\
import tensorflow as tf
from tensorflow.python.platform import gfile
pb_path = './save/model.pb'
with tf.Session() as sess:
    with gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        for i, n in enumerate(graph_def.node):
            print("Name of the node -%s" % n.name)#

4. 将模型同时保存为ckpt和pb文件

import tensorflow as tf
from tensorflow.python.framework import graph_util
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
op = tf.add(xy, b, name='op_to_store')
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    sess.run(init_op)
    saver_path = saver.save(sess, "save/model.ckpt")  # 将模型保存到save/model.ckpt文件,若当前目录没有save文件夹,会自行创建
    #tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:
    print("Model saved in file:", saver_path)
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
    #convert_variables_to_constants()方法,可以固化模型结构,将计算图中的变量取值以常量的形式保存
    with tf.gfile.FastGFile('save/model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())#将模型保存到save/model.pb文件,

tensorflow保存模型为pb文件的各种方式_第1张图片

checkpoint是检查点的文件,文件保存了一个目录下所有的模型文件列表

model.ckpt.meta文件保存了Tensorflow计算图结构,可以理解为神经网络的网络结构,

ckpt.data是保存模型中每个变量的取值

model.ckpt.index保存了所有变量名

有了这三个文件,就能得到模型的信息并加载到其他项目中

5.从保存好的ckpt文件恢复网络,并保存为pb

 

从保存的ckpt模型读取所有的节点,可以发现ckpt不能保存所有的的变量

from tensorflow.python import pywrap_tensorflow

checkpoint_path = './save/model.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name:", key)#b,为什么只有b呢,因为只有B是变量

6.比较复杂的网络模型的保存,用同样的方法


import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
def model(input):
    net = tf.layers.conv2d(input, filters=32, kernel_size=3)
    net = tf.layers.batch_normalization(net, fused=False)
    net = tf.layers.separable_conv2d(net, 32, 3)
    net = tf.layers.conv2d(net, filters=32, kernel_size=3, name='output')
    return net
input_node = tf.placeholder(tf.float32, [1, 480, 480, 3], name='image')
pb = 'save/model3.pb'

saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
    model = model(input_node)
    sess.run(tf.global_variables_initializer())
    # saver_path = saver.save(sess, "save/model.ckpt")  # 这里就不可以保存ckpt文件了,会报No variables to save
    # tf.train.saver()保持模型的时候会产生多个文件,会把计算图的结构和图上参数取值分成了不同文件存储,这种方法是在TensorFlow中最常用的保存方式:   
    output_node_names = 'output/BiasAdd'#为什么这里的名字与上面不一样?因为上面的output是层的名字,
    # 这个层里面有很多的op,那怎么知道每个op的名字呢?不知道
    input_graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names.split(','))
with tf.gfile.GFile(pb, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点,38 ops in the final graph.
    print(output_graph_def.node)

checkpoint_path = './save/model.ckpt'
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name:", key)#b,为什么只有b呢,因为只有B是变量

 

你可能感兴趣的:(tensorflow保存模型为pb文件的各种方式)