前情提要:在模型训练过程中通过dev进行验证,寻找最优的参数组合。在训练结束后选择最优dev的情况进行测试。可通过Saver进行保存,而后恢复的方法,常见的恢复方法有两种:1.通过restore函数进行恢复;2.通过加载meta的方法恢复图模型。本文记录Re2模型中使用的第三种方法:通过checkpoint恢复部分参数,并使用该参数重新实例化model对象,而后进行验证。该方法适用于模块化较强的大规模的训练。
具体代码如下:
def evaluate(self):
data = load_data(self.args.data_dir, self.args.test_file)
tf.reset_default_graph()
with tf.Graph().as_default():
config = tf.ConfigProto()
#当使用gpu时候,运行自动慢慢达到最大gpu的内存
config.gpu_options.allow_growth = True
#运行设备不满足时会自动换回cpu
config.allow_soft_placement = True
sess = tf.Session(config=config)
with sess.as_default():
model, checkpoint = Model.load(sess, self.model_path)
args = checkpoint['args']
interface = Interface(args)
batches = interface.pre_process(data, training=False)
score, stats = model.evaluate(sess, batches)
pprint(stats)
self.log(f'test_score {stats["score"]} .')
本文主要分析
model, checkpoint = Model.load(sess, self.model_path)
这一加载模块
该模块的函数如下:
@classmethod
def load(cls, sess, model_path):
with open(model_path + '.stat', 'rb') as f:
checkpoint = pickle.load(f)
args = checkpoint['args']
model = cls(args, sess, updates=checkpoint['updates'])
init_vars = tf.train.list_variables(model_path)
model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}
assignment_map = {name: model_vars[name] for name, _ in init_vars if name in model_vars}
tf.train.init_from_checkpoint(model_path, assignment_map)
sess.run(tf.global_variables_initializer())
return model, checkpoint
模型通过tf.train.Saver
进行保存,保存会得到checkpoint、xxx.meta、xxx.index、xxx.data-00000-of-00001这四个文件。
其中checkpoint文件包含以下信息:
model_checkpoint_path: “checkpoint-27650”
all_model_checkpoint_paths: “checkpoint-27650”
xxx.data-00000-of-00001储存网络参数值,xxx.index储存每一层的名字,xxx.meta储存图结构。
同时可通过构建键值对,并写入stat文件保存额外需要的某些自定义参数,具体代码如下:
params = {
'updates': self.updates,
'args': self.args,
...:...
}
with open(os.path.join(self.args.summary_dir, '{}-{}.stat'.format("文件前缀", name)), 'wb') as f:
pickle.dump(params, f)#写入文件
因此,在进行验证时,首先通过如下语句加载回自定义保存的参数,并取出所需要的参数
with open(model_path + 'checkpoint-best.stat', 'rb') as f:
checkpoint = pickle.load(f)
args = checkpoint['args']
而后根据参数调用model对象的__init__
,重新构造一个模型对象,在实例化的model对象中重新加载完整的图模型。至此完成图的加载。
在重新加载回图模型后,需要将保存的best dev的参数加载回来。这一步首先通过tf.train.list_variables(model_path)
来加载回所有参数对象(并没有加载回值),其返回对象类型为一组(name, shape)
的列表。
恢复图的终极目标是把加载回的参数放到图中的对应位置,所以需要进行图中参数与加载回的参数的映射,以便能正确的恢复。所以需要知道图中都有哪些参数及他们的命名,因此使用:
model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}
取出所有参数。
对于re.match("^(.*):\\d+$", var.name).group(1)
,举一个简单的例子来解释这一步在做什么。
首先创建一些变量:
a = tf.Variable( [10,10,20] ,name = "a")
b = tf.Variable( [20,20,20,30] ,name = "b")
b = tf.Variable( [30,30,30,34,56] ,name = "c")
通过tf.global_variables()
可得结果:
[<tf.Variable 'a:0' shape=(3,) dtype=int32_ref>,
<tf.Variable 'b:0' shape=(4,) dtype=int32_ref>,
<tf.Variable 'c:0' shape=(5,) dtype=int32_ref>]
通过该正则表达式进行匹配
model_vars = {re.match("^(.*):\\d+$", var.name).group(1): var for var in tf.global_variables()}
得到
{'a': <tf.Variable 'a:0' shape=(3,) dtype=int32_ref>,
'b': <tf.Variable 'b:0' shape=(4,) dtype=int32_ref>,
'c': <tf.Variable 'c:0' shape=(5,) dtype=int32_ref>}
我们定义的参数名,参数对象的一个dic。
关于正则匹配,^ 为句子的开始,$为句子的结束," .* " 表示将任意字符重复零到多次," : “匹配命名空间中的” : ",group(1)表示取出第一个圆括号中的内容。在本例中也只有一个(),所以取出该括号中的内容,即变量的名称
而后,将该图中所需要的参数通过tf.train.init_from_checkpoint(model_path, assignment_map)
进行恢复,其中assignment_map
为{加载回的就图的参数:待恢复的新图的参数},因此用
`assignment_map = {name: model_vars[name] for name, _ in init_vars if name in model_vars}`
构建assignment_map
。再通过
tf.train.init_from_checkpoint(model_path, assignment_map)
sess.run(tf.global_variables_initializer())
完成对计算图的恢复。该计算图属于model类,将该model类进行返回。
而后取出数据,并进行测试等操作,完成对模型性能的评估。