TensorFlow SavedModel保存和加载模型

声明:

  1. 参考文献:

    • TensorFlow官方文档
    • A quick complete tutorial to save and restore Tensorflow models
    • TensorFlow: Save and Restore Models
  2. 版本:Tensorflow 1.1

利用tf.train.Saver()保存和加载模型

"""保存模型和变量"""
v1 = tf.Variable([0], name='v1')
v2 = tf.Variable([0], name='v2')

saver = tf.train.Saver()  # 1. 初始化saver

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'ckp')  # 2. 保存模型和变量

保存后文件的目录结构:

  • checkpoint Checkpoint 文件会记录最近一次的断点文件(Checkpoint File) 的前缀,根据前缀可以找对对应的索引和数据文件。当调用tf.train.latest_checkpoint,可以快速找到最近一次的断点文件。
  • ckp.data-00000-of-00001
    数据(data) 文件记录了所有变量(Variable) 的值。当restore 某个变量时,首先从索引文件中找到相应变量在哪个数据文件,然后根据索引直接获取变量的值,从而实现变量数据的恢复。
  • ckp.index 索引(index)文件,保存了一个不可变表的数据。其中,关键字为Tensor 的名称,其值描述该Tensor 的元数据信息,包括该Tensor 存储在哪个数据(data) 文件中,及其在该数据文件中的偏移,及其校验和等信息。
  • ckp.meta 元文件(meta) 中保存了MetaGraphDef 的持久化数据,即模型数据,它包括GraphDef, SaverDef 等元数据。
"""保存模型和变量"""
with tf.Session() as sess:
    saver = tf.import_meta_graph('ckp.meta')  # 3. 加载模型
    saver.restore(sess, 'ckp')  # 4. 加载变量

总结使用tf.train.Saver()的关键步骤:

  1. 初始化saver
  2. 保存模型和变量
  3. 加载模型
  4. 加载变量

1. 初始化saver

__init__(
	var_list=None,  # 选择要保存的变量,可以是列表,也可以是字典,默认表示保存所有变量
	max_to_keep=5,  # Maximum number of recent checkpoints to keep. Defaults to 5.
	keep_checkpoint_every_n_hours=10000.0  # How often to keep checkpoints. Defaults to 10,000 hours.
)

Checkpoint 文件也记录了所有的断点文件列表,并且文件列表按照由旧至新的时间依次排序。当训练任务时间周期非常长,断点检查将持续进行,必将导致磁盘空间被耗尽。为了避免这个问题,存在两种基本的方法:

  1. max_to_keep: 配置最近有效文件的最大数目,当新的断点文件生成时,且文件数目超过max_to_keep,则删除最旧的断点文件;其中,max_to_keep 默认值为5;
  2. keep_checkpoint_every_n_hours: 在训练过程中每n 小时做一次断点检查,保证只有一个断点文件;其中,该选项默认是关闭的。

2. 保存模型和变量

save(  # save()的重要参数
    sess,  # A Session to use to save the variables.
    save_path,  # 文件名前缀
    global_step=None,  # 每`global_step`个迭代新建checkpoint filenames保存当前变量。参数值可以是tensor, tensor name或整型数
    write_meta_graph=True  # 是否保存模型数据
)

若调用saver.save(sess, 'ckp',global_step=1000),会产生文件:ckp.data-1000-of-00001。其中data文件的后缀格式是’-???-of-nnnnn’,???对应global_step,nnnnn用于分布式资源管理。

有时候我们会对一个模型训练的多个阶段分别保存,而此时模型的结构是确定的,变化的只有变量值,所以不需要重复保存meta文件:

saver.save(
	sess, 
	'ckp',
	global_step=step, 
	write_meta_graph=False
)

3. 加载模型

tf.train.import_meta_graph(  # 重要参数
	meta_graph_or_file
)

该函数从.meta文件汇总重建Graph。

4. 加载变量

restore(
	sess, 
	save_path
)

附加:如何使用加载的模型

要使用加载的模型,就需要有一个tensor变量作为handle,此时需要函数graph.get_tensor_by_name(),下面是一个例子:

import tensorflow as tf
 
sess=tf.Session()    

saver = tf.train.import_meta_graph('ckp.meta')
saver.restore(sess,tf.train.latest_checkpoint('ckp'))

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1")
w2 = graph.get_tensor_by_name("w2")
train_op = graph.get_tensor_by_name("train_op")
 
print sess.run(train_op, feed_dict={w1:13.0, w2:17.0})

使用加载的模型进行fine-tune

我们以VGG模型为例:

......
saver = tf.train.import_meta_graph('vgg.meta')
graph = tf.get_default_graph()
fc7 = graph.get_tensor_by_name('fc7:0')
fc7 = tf.stop_gradient(fc7)
fc7_shape = fc7.get_shape().as_list()

new_outputs = 2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)

你可能感兴趣的:(TensorFlow)