声明:
参考文献:
版本:Tensorflow 1.1
"""保存模型和变量"""
v1 = tf.Variable([0], name='v1')
v2 = tf.Variable([0], name='v2')
saver = tf.train.Saver() # 1. 初始化saver
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'ckp') # 2. 保存模型和变量
保存后文件的目录结构:
- checkpoint Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用
tf.train.latest_checkpoint
,可以快速找到最近一次的断点文件。- ckp.data-00000-of-00001
数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。- ckp.index 索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。
- ckp.meta 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。
"""保存模型和变量"""
with tf.Session() as sess:
saver = tf.import_meta_graph('ckp.meta') # 3. 加载模型
saver.restore(sess, 'ckp') # 4. 加载变量
总结使用tf.train.Saver()的关键步骤:
- 初始化saver
- 保存模型和变量
- 加载模型
- 加载变量
__init__(
var_list=None, # 选择要保存的变量,可以是列表,也可以是字典,默认表示保存所有变量
max_to_keep=5, # Maximum number of recent checkpoints to keep. Defaults to 5.
keep_checkpoint_every_n_hours=10000.0 # How often to keep checkpoints. Defaults to 10,000 hours.
)
Checkpoint 文件也记录了所有的断点文件列表,并且文件列表按照由旧至新的时间依次排序。当训练任务时间周期非常长,断点检查将持续进行,必将导致磁盘空间被耗尽。为了避免这个问题,存在两种基本的方法:
save( # save()的重要参数
sess, # A Session to use to save the variables.
save_path, # 文件名前缀
global_step=None, # 每`global_step`个迭代新建checkpoint filenames保存当前变量。参数值可以是tensor, tensor name或整型数
write_meta_graph=True # 是否保存模型数据
)
若调用saver.save(sess, 'ckp',global_step=1000)
,会产生文件:ckp.data-1000-of-00001。其中data文件的后缀格式是’-???-of-nnnnn’,???对应global_step,nnnnn用于分布式资源管理。
有时候我们会对一个模型训练的多个阶段分别保存,而此时模型的结构是确定的,变化的只有变量值,所以不需要重复保存meta文件:
saver.save(
sess,
'ckp',
global_step=step,
write_meta_graph=False
)
tf.train.import_meta_graph( # 重要参数
meta_graph_or_file
)
该函数从.meta文件汇总重建Graph。
restore(
sess,
save_path
)
要使用加载的模型,就需要有一个tensor变量作为handle,此时需要函数graph.get_tensor_by_name()
,下面是一个例子:
import tensorflow as tf
sess=tf.Session()
saver = tf.train.import_meta_graph('ckp.meta')
saver.restore(sess,tf.train.latest_checkpoint('ckp'))
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1")
w2 = graph.get_tensor_by_name("w2")
train_op = graph.get_tensor_by_name("train_op")
print sess.run(train_op, feed_dict={w1:13.0, w2:17.0})
我们以VGG模型为例:
......
saver = tf.train.import_meta_graph('vgg.meta')
graph = tf.get_default_graph()
fc7 = graph.get_tensor_by_name('fc7:0')
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()
new_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)