tensorflow通过checkpoint恢复模型参数

前情提要:在模型训练过程中通过dev进行验证,寻找最优的参数组合。在训练结束后选择最优dev的情况进行测试。可通过Saver进行保存,而后恢复的方法,常见的恢复方法有两种:1.通过restore函数进行恢复;2.通过加载meta的方法恢复图模型。本文记录Re2模型中使用的第三种方法:通过checkpoint恢复部分参数,并使用该参数重新实例化model对象,而后进行验证。该方法适用于模块化较强的大规模的训练。

具体代码如下:

    def evaluate(self):
        data = load_data(self.args.data_dir, self.args.test_file)
        tf.reset_default_graph()
        with tf.Graph().as_default():
            config = tf.ConfigProto()
            #当使用gpu时候,运行自动慢慢达到最大gpu的内存
            config.gpu_options.allow_growth = True
            #运行设备不满足时会自动换回cpu
            config.allow_soft_placement = True
            sess = tf.Session(config=config)
            with sess.as_default():
                model, checkpoint = Model.load(sess, self.model_path)
                args = checkpoint['args']
                interface = Interface(args)
                batches = interface.pre_process(data, training=False)
                score, stats = model.evaluate(sess, batches)
                pprint(stats)
                self.log(f'test_score {stats["score"]} .')

本文主要分析

model, checkpoint = Model.load(sess, self.model_path)

这一加载模块

该模块的函数如下:

    @classmethod
    def load(cls, sess, model_path):
        with open(model_path + '.stat', 'rb') as f:
            checkpoint = pickle.load(f)
        args = checkpoint['args']
        model = cls(args, sess, updates=checkpoint['updates'])
        init_vars = tf.train.list_variables(model_path)
        model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}
        assignment_map = {name: model_vars[name] for name, _ in init_vars if name in model_vars}
        tf.train.init_from_checkpoint(model_path, assignment_map)
        sess.run(tf.global_variables_initializer())
        return model, checkpoint

模型通过tf.train.Saver进行保存,保存会得到checkpoint、xxx.meta、xxx.index、xxx.data-00000-of-00001这四个文件。
其中checkpoint文件包含以下信息:

model_checkpoint_path: “checkpoint-27650”
all_model_checkpoint_paths: “checkpoint-27650”

xxx.data-00000-of-00001储存网络参数值,xxx.index储存每一层的名字,xxx.meta储存图结构。
同时可通过构建键值对,并写入stat文件保存额外需要的某些自定义参数,具体代码如下:

params = {
	'updates': self.updates,
	'args': self.args,
	...:...
}
with open(os.path.join(self.args.summary_dir, '{}-{}.stat'.format("文件前缀", name)), 'wb') as f:
	pickle.dump(params, f)#写入文件

因此,在进行验证时,首先通过如下语句加载回自定义保存的参数,并取出所需要的参数

with open(model_path + 'checkpoint-best.stat', 'rb') as f:
	checkpoint = pickle.load(f)
args = checkpoint['args']

而后根据参数调用model对象的__init__,重新构造一个模型对象,在实例化的model对象中重新加载完整的图模型。至此完成图的加载。

在重新加载回图模型后,需要将保存的best dev的参数加载回来。这一步首先通过tf.train.list_variables(model_path)来加载回所有参数对象(并没有加载回值),其返回对象类型为一组(name, shape)的列表。
恢复图的终极目标是把加载回的参数放到图中的对应位置,所以需要进行图中参数与加载回的参数的映射,以便能正确的恢复。所以需要知道图中都有哪些参数及他们的命名,因此使用:

model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}

取出所有参数。
对于re.match("^(.*):\\d+$", var.name).group(1),举一个简单的例子来解释这一步在做什么。
首先创建一些变量:

a = tf.Variable( [10,10,20] ,name = "a")
b = tf.Variable( [20,20,20,30] ,name = "b")
b = tf.Variable( [30,30,30,34,56] ,name = "c")

通过tf.global_variables()可得结果:

[<tf.Variable 'a:0' shape=(3,) dtype=int32_ref>,
 <tf.Variable 'b:0' shape=(4,) dtype=int32_ref>,
 <tf.Variable 'c:0' shape=(5,) dtype=int32_ref>]

通过该正则表达式进行匹配

model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}

得到

{'a': <tf.Variable 'a:0' shape=(3,) dtype=int32_ref>,
 'b': <tf.Variable 'b:0' shape=(4,) dtype=int32_ref>,
 'c': <tf.Variable 'c:0' shape=(5,) dtype=int32_ref>}

我们定义的参数名,参数对象的一个dic。

关于正则匹配,^ 为句子的开始,$为句子的结束," .* " 表示将任意字符重复零到多次," : “匹配命名空间中的” : ",group(1)表示取出第一个圆括号中的内容。在本例中也只有一个(),所以取出该括号中的内容,即变量的名称

而后,将该图中所需要的参数通过tf.train.init_from_checkpoint(model_path, assignment_map)
进行恢复,其中assignment_map为{加载回的就图的参数:待恢复的新图的参数},因此用

`assignment_map = {name: model_vars[name] for name, _ in init_vars if name in model_vars}`

构建assignment_map。再通过

tf.train.init_from_checkpoint(model_path, assignment_map)
sess.run(tf.global_variables_initializer())

完成对计算图的恢复。该计算图属于model类,将该model类进行返回。
而后取出数据,并进行测试等操作,完成对模型性能的评估。

你可能感兴趣的:(tensorflow)