class tf.train.Saver
保存和恢复变量
最简单的保存和恢复模型的方法是使用tf.train.Saver 对象。构造器给graph 的所有变量,或是定义在列表里的变量,添加save 和 restore ops。saver 对象提供了方法来运行这些ops,定义检查点文件的读写路径。
检查点是专门格式的二进制文件,将变量name 映射到 tensor value。检查checkpoin 内容最好的方法是使用Saver 加载它。
Savers 可以使用提供的计数器自动计数checkpoint 文件名。这可以是你在训练一个模型时,在不同的步骤维持多个checkpoint。例如你可以使用 training step number 计数checkpoint 文件名。为了避免填满硬盘,savers 自动管理checkpoint 文件。例如,你可以最多维持N个最近的文件,或者没训练N小时保存一个checkpoint.
通过传递一个值给可选参数 global_step ,你可以编号checkpoint 名字。
saver.save(sess, 'my-model', global_step=0) ==>filename: 'my-model-0'
saver.save(sess, 'my-model', global_step=1000) ==>filename: 'my-model-1000'
另外,Saver() 构造器可选的参数可以让你控制硬盘上 checkpoint 文件的数量。
一个定期保存的训练程序如下这样:
#Create a saver
saver=tf.train.Saver(...variables...)
#Launch the graph and train, saving the model every 1,000 steps.
sess=tf.Session()
for step in xrange(1000000):
sess.run(...training_op...)
if step % 1000 ==0:
#Append the step number to the checkpoint name:
saver.save(sess,'my-model',global_step=step)
除了checkpoint 文件之外,savers 还在硬盘上保存了一个协议缓存,存储最近的checkpoint 列表。这用于管理 被编号的checkpoint 文件,并且通过latest_checkpoint() 可以很容易找到最近的checkpoint 的路径。协议缓存存储在紧挨checkpoint 文件的名为 'checkpoint' 的文件中。
如果你创建了几个savers,你可以调用save() 指定协议缓存的文件名。
tf.train.Saver.__init__(var_list=None, reshape=False, shared=False, max_to_keep=5, keep_checkpoint_every_n_hour=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)
创建一个Saver
构造器添加操作去保存和恢复变量。
var_list 指定了将要保存和恢复的变量。它可以传dict 或者list
v1=tf.Variable(..., name='v1')
v2=tf.Variable(..., name='v2')
# Pass the variables as a dict:
saver=tf.train..Saver({'v1':v1, 'v2':v2})
# Or pass them as a list
saver=tf.train..Saver([v1,v2])
# Passing a list is equivalent to passing a dict with the variable op names as keys:
saver=tf.train..Saver({v.op.name: v for v in [v1,v2]})
#!/usr/bin/env python
# coding=utf-8
import os
import tensorflow as tf
# Create some variables.
v1=tf.Variable([[1,1],[2,2],[3,3]],name="v1")
v2=tf.Variable([[4,4],[5,5],[6,7]],name="v2")
# Add an op to initialize the variables.
init_op=tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver=tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
save_path=saver.save(sess,"/home/yhk/tmp/test/model.ckpt")
print "Model saved in file: ", save_path