Tensorflow 模型保存与恢复的两种方法

“ 人生,其实无非是矛盾与选择的综合体,无关对错,仅仅在于我们能否有勇气在矛盾中作出选择并勇敢承担一切后果。 ​​​​”

 

最近在使用Object Detection API时

才发现Tensorflow训练的模型可以保存为一个.pb文件

下面将这两种模型保存方式整理下:

 

 保存checkpoint模型文件(.ckpt)

# 通过tf.train.Saver类实现保存和载入神经网络模型
 
# 执行本段程序时注意当前的工作路径
import tensorflow as tf
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

# 首先定义saver类
saver = tf.train.Saver()

# 定义会话
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(300):
        if epoch % 10 == 0:
            print "------------------------------------------------------"
            # 每迭代10次保存一次模型
            saver.save(sess, "Model/", global_step=epoch)
    print "------------------------------------------------------"
    saver.save(sess, "Model/model.ckpt")#也可以最后迭代完成再保存模型

 注意:

          程序生成并保存四个文件(在tensorflow版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta),即训练模型的网络和权重是分开保存的。

              1. checkpoint 文本文件,记录了模型文件的路径信息列表

              2. model.ckpt.data-00000-of-00001 网络权重信息

              3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息

              4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

 

  加载checkpoint模型文件(.ckpt)

#通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中
 
import tensorflow as tf
from tensorflow.python.framework import graph_util
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess,
                                                        graph_def, ['add'])
 
    with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:
        f.write(output_graph_def.SerializeToString())

然而,在实际生产部署时,为了方便使用,将模型保存为一个.pb文件,也就是将模型文件和权重文件整合为一个文件进行保存,这个过程的主要思路是graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant(使用graph_util.convert_variables_to_constants()函数),即可达到使用一个文件同时存储网络架构与权重的目标。

接下来我们看一看如何保存.pb文件~

 

 保存.pb模型文件

# 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中
 
import tensorflow as tf
from tensorflow.python.framework import graph_util
 
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def, ['add'])
 
    with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    #或者下面的方法均可
    output_graph_def = graph_util.convert_variables_to_constants(sess,sess.graph_def,["output"])
    with tf.gfile.FastGFile("Model/combined_model.pb", mode='wb') as f:
        f.write(output_graph_def.SerializeToString())

 程序生成并保存一个文件combined_model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息

 

☛  加载.pb模型文件

#载入包含变量及其取值的模型
 
import tensorflow as tf
from tensorflow.python.platform import gfile
 
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    model_filename = "Model/combined_model.pb"
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
 
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result)) # [array([ 3.], dtype=float32)]

    #或者下面的方法:
    with open(pb_file_path, "rb") as f:
         output_graph_def = tf.GraphDef()
         output_graph_def.ParseFromString(f.read())
         _ = tf.import_graph_def(output_graph_def, name="")
   
    #将模型读取到默认的图中(强烈建议此种方法)
    with tf.gfile.GFile(MODEL_CHECK_FILE, 'rb') as f:
        _graph = tf.GraphDef()
        _graph.ParseFromString(fd.read())
        tf.import_graph_def(_graph, name='')

恩~

就这样~

See you next!

你可能感兴趣的:(Tensorflow 模型保存与恢复的两种方法)