TensorFlow中的模型存取都是通过这个类来实现的
以 my_model
为例,存储时会在目录下生成四个文件,分别是 checkpoint
文件、 my_model.meta
文件、 my_model.index
文件和 my_model.data-00000-of-00001
文件
checkpoint
文件:该文件在目录下只会生成一个,是用来记录和管理整个目录下的所有模型文件my_model.meta
文件:该文件用来存储整个模型的计算图my_model.index
文件和 my_model.data-00000-of-00001
文件:在之前的版本中,这两个文件是被一个 .ckpt
文件代替,但是新的版本中会生成两个文件。这两个文件是存储模型中变量的取值__init__(
var_list=None,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
)
初始化函数中有非常多的参数,常用的主要是以下几个
一个简单的例子
v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v3')
a = tf.matmul(v1, v2)
res = a + v3
init_op = tf.global_variables_initializer()
saver = tf.train.Saver() # 保存了所有变量
# saver = tf.train.Saver([v1, v2, v3]) # 只保存[v1, v2, v3]变量
在初始化类之后,就可以通过调用 saver.save()
函数来对模型进行存储了
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True,
strip_default_attrs=False
)
global_step=10
,则生成的文件文就是 my_model-10.meta
.meta
文件,这时候就可以指定该参数为 False
一个简单的例子
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess, MODEL_FILE)
# saver.save(sess, MODEL_FILE, global_step=global_step) # 在存储时指定训练步数
当需要使用一个已经存储的模型时,我们就可以通过 saver.restore()
函数实现
restore(
sess,
save_path
)
加载函数非常简单,只需要指定会话和路径就可以了,但是如果你直接运行下面的代码,是没法成功读取的
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, MODEL_FILE)
print(res.eval())
因为 saver.restore()
函数只是从模型中重新载入了变量的值,但是当前的模型并没有对应的变量去存储,也就是还没有定义计算图,所以我们还需要重新定义一次计算图
v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v3')
a = tf.matmul(v1, v2)
res = a + v3
saver = tf.train.Saver()
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
# 不需要执行初始化操作
saver.restore(sess, MODEL_FILE)
print(res.eval())
# output:
# INFO:tensorflow:Restoring parameters from ../files/model/my_model
# [[4. 4.]
# [4. 4.]]
在 tf.train.Saver
的初始化操作中可以传入一个字典实现对模型加载时的变量映射
v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v-1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v-2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v-3')
a = tf.matmul(v1, v2)
res = a + v3
init_op = tf.global_variables_initializer()
saver = tf.train.Saver({'v1': v1, 'v2': v2, 'v3': v3}) #字典的key为原变量的name属性,值为当前变量名
有的时候在当前计算图中变量的名称与存储模型中的不同,需要通过一个字典使其完成映射,就可以使用这个方法加载变量
如果说我们有一个很大的计算图,重新定义计算图工作量大,而且容易出错。难道每次载入模型的时候都需要重新定义一次所有变量吗?
当然不是,因为在存储模型的时候我们已经将模型的计算图存结构存储于 .meta
文件中,所以我们可以直接通过 tf.train.import_meta_graph()
函数载入计算图,而不需要重新定义。此时,你可以通过变量的名称来访问计算图中的变量(注意是变量的name属性,不是变量名)
saver = tf.train.import_meta_graph(MODEL_FILE + '.meta')
with tf.Session() as sess:
saver.restore(sess, MODEL_FILE)
res = tf.get_default_graph().get_tensor_by_name('add:0')
print(res.eval())
如果我们在训练过程中保存了一个模型的多个不同阶段的副本,那么加载模型时如何确定最新的模型副本是比较困难的
之前我们提到过 checkpoint
文件管理了目录下的所有模型,TensorFlow提供了通过该文件获取最新模型的方法
下面是一个例子,加载模型并计算acc
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(config.MODEL_DIR)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# ...一些其他操作
_acc = sess.run([acc], feed_dict=feed_dict)
print('acc on data is %g' % _acc[0])
tf.train.get_checkpoint_state()
函数接受模型所在的目录,从中通过 checkpoint
文件获取最新模型的信息,检查非空后通过 ckpt.model_checkpoint_path
获取模型路径