TensorFlow笔记:模型的存储和加载

tf.train.Saver类

TensorFlow中的模型存取都是通过这个类来实现的
my_model 为例,存储时会在目录下生成四个文件,分别是 checkpoint 文件、 my_model.meta 文件、 my_model.index 文件和 my_model.data-00000-of-00001 文件

  • checkpoint 文件:该文件在目录下只会生成一个,是用来记录和管理整个目录下的所有模型文件
  • my_model.meta文件:该文件用来存储整个模型的计算图
  • my_model.index 文件和 my_model.data-00000-of-00001 文件:在之前的版本中,这两个文件是被一个 .ckpt 文件代替,但是新的版本中会生成两个文件。这两个文件是存储模型中变量的取值

构造类

__init__(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

初始化函数中有非常多的参数,常用的主要是以下几个

  • var_list:用于指定保存的变量,如果为None则保存所有的变量
  • max_to_keep:用于指定最多保存的模型个数,因为通常在经过一定的训练轮次之后,我们会对模型进行一次保存
  • keep_checkpoint_every_n_hours:用于指定自动保存的时间间隔

一个简单的例子

v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v3')
a = tf.matmul(v1, v2)
res = a + v3

init_op = tf.global_variables_initializer()

saver = tf.train.Saver()  # 保存了所有变量
# saver = tf.train.Saver([v1, v2, v3])  # 只保存[v1, v2, v3]变量

模型存储

saver.save()函数

在初始化类之后,就可以通过调用 saver.save() 函数来对模型进行存储了

save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True,
    strip_default_attrs=False
)
  • sess:指定会话
  • global_step:指定当前的训练次数,模型在存储时会自动在文件名中加入该属性,例如指定 global_step=10 ,则生成的文件文就是 my_model-10.meta
  • write_meta_graph:因为许多模型随着训练次数的增加,只会有参数的更新,计算图是保持不变的,因此我们不需要在每次都重新生成一个 .meta 文件,这时候就可以指定该参数为 False

一个简单的例子

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, MODEL_FILE)
    # saver.save(sess, MODEL_FILE, global_step=global_step)  # 在存储时指定训练步数

模型加载

saver.restore()函数

当需要使用一个已经存储的模型时,我们就可以通过 saver.restore() 函数实现

restore(
    sess,
    save_path
)

加载函数非常简单,只需要指定会话和路径就可以了,但是如果你直接运行下面的代码,是没法成功读取的

saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, MODEL_FILE)
    print(res.eval())

因为 saver.restore() 函数只是从模型中重新载入了变量的值,但是当前的模型并没有对应的变量去存储,也就是还没有定义计算图,所以我们还需要重新定义一次计算图

v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v3')
a = tf.matmul(v1, v2)
res = a + v3
saver = tf.train.Saver()

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    # 不需要执行初始化操作
    saver.restore(sess, MODEL_FILE)

    print(res.eval())
# output:
# INFO:tensorflow:Restoring parameters from ../files/model/my_model
# [[4. 4.]
# [4. 4.]]

重命名变量

tf.train.Saver 的初始化操作中可以传入一个字典实现对模型加载时的变量映射

v1 = tf.Variable(tf.ones(shape=[2, 3]), name='v-1')
v2 = tf.Variable(tf.ones(shape=[3, 2]), name='v-2')
v3 = tf.Variable(tf.ones(shape=[2, 2]), name='v-3')
a = tf.matmul(v1, v2)
res = a + v3
init_op = tf.global_variables_initializer()

saver = tf.train.Saver({'v1': v1, 'v2': v2, 'v3': v3})  #字典的key为原变量的name属性,值为当前变量名

有的时候在当前计算图中变量的名称与存储模型中的不同,需要通过一个字典使其完成映射,就可以使用这个方法加载变量


其他的相关操作

tf.train.import_meta_graph()函数

如果说我们有一个很大的计算图,重新定义计算图工作量大,而且容易出错。难道每次载入模型的时候都需要重新定义一次所有变量吗?
当然不是,因为在存储模型的时候我们已经将模型的计算图存结构存储于 .meta 文件中,所以我们可以直接通过 tf.train.import_meta_graph() 函数载入计算图,而不需要重新定义。此时,你可以通过变量的名称来访问计算图中的变量(注意是变量的name属性,不是变量名)

saver = tf.train.import_meta_graph(MODEL_FILE + '.meta')

with tf.Session() as sess:
    saver.restore(sess, MODEL_FILE)
    
    res = tf.get_default_graph().get_tensor_by_name('add:0')
    print(res.eval())

tf.train.get_checkpoint_state()函数

如果我们在训练过程中保存了一个模型的多个不同阶段的副本,那么加载模型时如何确定最新的模型副本是比较困难的
之前我们提到过 checkpoint 文件管理了目录下的所有模型,TensorFlow提供了通过该文件获取最新模型的方法
下面是一个例子,加载模型并计算acc

with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(config.MODEL_DIR)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        
        # ...一些其他操作
        
        _acc = sess.run([acc], feed_dict=feed_dict)
        print('acc on data is %g' % _acc[0])

tf.train.get_checkpoint_state() 函数接受模型所在的目录,从中通过 checkpoint 文件获取最新模型的信息,检查非空后通过 ckpt.model_checkpoint_path 获取模型路径

你可能感兴趣的:(python,TensorFlow)