tensroflow 模型保存、加载

参考:
[http://blog.csdn.net/scotthuang1989/article/details/77769412]
[http://blog.csdn.net/LordofRobots/article/details/77719020]

tensorflow由于其开源特性,因此API经常发生变化。保存加载模型也发生了一些变化。
本次博客针对与tensorflow1.0以后1.3以前的版本,之后又有什么变化就不知道了。下面进入正题。

常见的持久化方法有两种:一种是保存为ckpt(以前的,现在该改后缀了就叫多文件),一种是graph_def文件。

一、多文件

源代码位于 tensorflow/python/training/saver.py
Saver类可以使用保存以及从某一个检查点恢复数据。
你可以之保存固定数量的检查点(恢复点),比如你可以只保存最近的多少个检查点文件或者是在训练时每隔几个小时保存一次。

tf.train.Saver的初始化

__init__(
    var_list=None,#一系列的Variable,SaveableObject, 或者dict名称,如果没有则表示所有可保存的变量
    reshape=False,#如果为True,表示允许恢复数据室variables有着不同的shape。
    sharded=False,#如果为Ture,表示在不同的设备上通向检查点
    max_to_keep=5,#保持几个检查点,默认是五个
    keep_checkpoint_every_n_hours=10000.0,#间隔多久保存一次检查点。
    name=None,#string,可选的变量,具体没看明白
    restore_sequentially=False,#一个Bool值,可以在恢复大量模型的时候减少内存的使用
    saver_def=None,#用于替换当前的保存builder,这个builder是之前的只保存Graph的那种Saver。将其保存成proto的形式。有点没有理解。
    builder=None,#当没有提供saver_def的时候可选的,默认为BaseSaverBuilder()
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,#是用那个版本进行保存
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

之后你会得到tf.train.Saver的一个变量。

save

保存变量。
需要包含了需要保存的图的session,并且这些变量需要已经初始化过了,这个方法会返回最新的检查点的文件名称位置,这个位置可以用于调用restore().

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

示例代码:

import tensorflow as tf

# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')  # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape=[2]), name='w2')
b1 = tf.Variable(2.0, name='bias1')
feed_dict = {w1: [10, 3], w2: [5, 5]}
# define a test operation that will be restored
w3 = tf.add(w1, w2)  # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name="op_to_restore")
#最多备份四次,默认每一个小时保存一次
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(w1))
print(sess.run(w4, feed_dict))
# 保存模型,保存在metaTest文件夹下,文件的名称为my_test_model
saver.save(sess, 'metaTest/my_test_model')

生成的结果为:

tensroflow 模型保存、加载_第1张图片
2017-09-13 21-51-25 的屏幕截图.png

其中,.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。

模型恢复

import tensorflow as tf
sess = tf.Session()
#恢复网络结构
saver = tf.train.import_meta_graph('metaTest/my_test_model.meta')
#恢复参数tf.train.latest_checkpoint('save/'),获取保存在那个位置的最新的文件
saver.restore(sess,tf.train.latest_checkpoint('metaTest/'))
#获取当前的默认图结构
graph = tf.get_default_graph()
#获取某一个tensor
w1 = graph.get_tensor_by_name('w1:0')
print(sess.run(w1))
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1],w2:[4,6]}
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print(sess.run(op_to_restore,feed_dict))

如果删除saver.restore(sess,tf.train.latest_checkpoint('metaTest /')),则会报错。

二、graph_def文件

我们需要将TensorFlow的模型导出为单个文件(同时包含模型架构定义与权重),方便在其他地方使用(如在c++中部署网络)。利用tf.train.write_graph()默认情况下只导出了网络的定义(没有权重),而利用tf.train.Saver().save()导出的文件graph_def与权重是分离的,因此需要采用别的方法。

graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant,即可达到使用一个文件同时存储网络架构与权重的目标。

#保存模型以及参数
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 这里需要填入输出tensor的名字
    graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
    tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)

会在当前文件夹下生成graph.pb文件,其中包含了网络结构以及所有的参数。

#恢复参数
import tensorflow as tf
with tf.Session() as sess:
    with open('./graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read()) 
        output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
        print(sess.run(output))

恢复模型的API:
作用:将graph_def引入到默认的Graph中。该方法提供了将序列化了的graph_def的pb文件重新加载到网络中的功能,并且将graph_def中的每一个objects转化为tf.Tensor或者tf.Operation的格式。一旦使用了这个方法,这些结构就会在当前的Graph中。可看 tf.Graph.as_graph_def查看关于GraphDef更详细的定义。

import_graph_def(
    graph_def,#一个graphDef文件
    input_map=None,#将graph_def中定义的map名称的输入转化为Tensor,方便给值
    return_elements=None,#将在[]中出现的graph_def中的操作转化为Operation,或者将graph_def中的tensor names转化为Tensor。
    name=None,
    op_dict=None,
    producer_op_list=None
)

你可能感兴趣的:(tensroflow 模型保存、加载)