TensorFlow保存和恢复变量——tf.train.Saver()

声明:

  1. 参考Tensorflow官方文档
  2. tensorflow当前版本1.1
  3. 更新:现在tensorflow官网有了中文教程,很方便学习了

tf.train.Saver()

tf.train.Saver()是一个类,提供了变量、模型(也称图Graph)的保存和恢复模型方法。

TensorFlow是通过构造Graph的方式进行深度学习,任何操作(如卷积、池化等)都需要operator,保存和恢复操作也不例外。在tf.train.Saver()类初始化时,用于保存和恢复的saverestore operator会被加入Graph。所以,下列类初始化操作应在搭建Graph时完成。

saver = tf.train.Saver()

TensorFlow的保存和恢复分为两种:

  • 保存和恢复变量
  • 保存和恢复模型

保存变量

TensorFlow会讲变量保存在二进制checkpoint文件中,这类文件会将变量名称映射到张量值。

下面是保存变量的例子:

  1. 创建变量
  2. 初始化变量
  3. 实例化tf.train.Saver()
  4. 创建Session并保存
import tensorflow as tf
# Create some variables.
v1 = tf.get_variable("v1_name", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2_name", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "model/model.ckpt")  # 返回一个保存路径的字符串
    print("Model saved in path: %s" % save_path)
    
'''output
Model saved in path: model/model.ckpt
'''

保存路径中的文件为:

  • checkpoint:保存当前网络状态的文件
  • model.data-00000-of-00001
  • model.index
  • model.meta:保存Graph结构的文件

可以发现,没有名为 ‘model/model.ckpt’ 的实体文件,其中 ‘model’ 是一个与用户交互的前缀。

关于函数saver.save(),常用的参数就是前三个:

save(
	sess,  # 必需参数,Session对象
	save_path,  # 必需参数,存储路径
	global_step=None,  # 可以是Tensor, Tensor name, 整型数
	latest_filename=None,  # 协议缓冲文件名,默认为'checkpoint',不用管
	meta_graph_suffix='meta',  # 图文件的后缀,默认为'.meta',不用管
	write_meta_graph=True,  # 是否保存Graph
	write_state=True,  # 建议选择默认值True
	strip_default_attrs=False  # 是否跳过具有默认值的节点

恢复变量

从checkpoint文件中提取变量值赋给新定义的变量。

tf.reset_default_graph()
# Create some variables
# !!!variable name必须与保存时的name一致
v1 = tf.get_variable("v1_name", shape=[3])
v2 = tf.get_variable("v2_name", shape=[5])

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore variables from disk
    saver.restore(sess, "model/model.ckpt")
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 1.  1.  1.]
v2: [-1. -1. -1. -1. -1.]
'''

variable().eval()
eval(session=None)
In a session computes and returns the value of this variable.

选择要保存和恢复的变量

tf.train.Saver()的构造函数传递以下任意内容来轻松指定要保存或加载的名称和变量:

  • 变量列表(要求变量与变量名之间的一一对应)
  • Python字典,其中,key是要使用的名称,value是要管理的变量(通过键值映射自定义变量与变量名之间的对应关系)
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)
saver = tf.train.Saver({"v2_name": v2})
with tf.Session() as sess:
    v1.initializer.run()
    saver.restore(sess, "model/model.ckpt")
    print("v1: %s" % v1.eval())
    print("v2: %s" % v2.eval())
'''output
INFO:tensorflow:Restoring parameters from model/model.ckpt
v1: [ 0.  0.  0.]
v2: [-1. -1. -1. -1. -1.]
'''

variable().initializer
The initializer operation for this variable.

Note

  • 如果需要保存和恢复模型变量的不同子集,您可以根据需要创建任意数量的Saver对象。同一个变量可以列在多个Saver对象中。
  • 变量的值只有在Saver.restore()方法运行时才会更改,这些变量不需要初始化。
  • 如果您仅在会话开始时恢复模型变量的子集,则必须为其他变量运行初始化 op。

查看ckpt二进制文件中的变量

我们可以使用 inspect_checkpoint 库快速检查某个检查点的变量。

prints tensors in a checkpoint file.

If no tensor_name is provided, prints the tensor names and shapes in
the checkpoint file.

If tensor_name is provided, prints the content of the tensor.

Args Description
file_name Name of the checkpoint file
tensor_name Name of the tensor in the checkpoint file
all_tensors Boolean indicating whether to print all tensors
all_tensor_names Boolean indicating whether to print all tensor names
from tensorflow.python.tools import inspect_checkpoint as ickpt
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=False)
'''output
tensor_name:  v1_name
[ 1.  1.  1.]
'''
ickpt.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name="v1_name", all_tensors=True)
'''output
tensor_name:  v1_name
[ 1.  1.  1.]
tensor_name:  v2_name
[-1. -1. -1. -1. -1.]
'''

你可能感兴趣的:(TensorFlow)