【自然语言处理】使用Tensorflow-Bert进行分类任务时输出每个Train Epoch的信息

前言

最近任务需要用到Bert,一个头疼的地方是官方代码只有在跑完指定的epoch次数之后才进行评估。可是基于任务的要求,需要输出每轮的评估信息(比如Acc, Loss)。
相似的需求类似:How to get a total training set loss for an epoch using Tensorflow Estimator api

方法

由于Bert在Tensorflow使用了Estimator,所以一个办法是使用Hook。思路如下:
1.需要明确的是Estimator是如何进行Eval的,这段代码可以参考run_classifier中if FLAGS.do_eval:之后的代码。
2.在训练过程中加入Hook,然后判断是否为一个epoch(或者指定的步数),随后进行eval即可。
一个可用的Hook是tf.train.SessionRunHook(),具体用法请点击tf.train.SessionRunHook()用法。在tf.train.SessionRunHook()中有个after_run方法,方法在每个session.run()之后会调用。而session.run()按我的理解是run一个batch(Bert有定义)。所以思路很明确,run到一个epoch(num_of_example//batch_size)的时候,进行模型的评估即可。
具体先定义一个hook。

class EvaluationHook(tf.train.SessionRunHook):
    def __init__(self, **wrapper):
        self.estimator = wrapper['estimator']
        self.eval_steps = wrapper['eval_steps']
        self.eval_input_fn = wrapper['eval_input_fn']
        self.train_steps = wrapper['train_steps']
        self.step = 0

    def after_run(self, run_context, run_values):
        self.step += 1
        if (self.step % self.train_steps == 0):   #it means an epoch
            epoch = self.step // self.train_steps
            result = self.estimator.evaluate(input_fn=self.eval_input_fn, steps=self.eval_steps)
            print("epoch:{}, eval_accuracy: {}".format(epoch, result["eval_accuracy"]))
            print("epoch:{}, eval_loss: {}".format(epoch, result["eval_loss"]))

首先Hook的创建需要一个Estimator,eval_steps,eval_input_fn,这些用于模型的评估。train_steps用于判断是否到达一个epoch。self.step代表运行的步数(sess.run一次+1)。
其次在after_run中定义模型的输出。
接着回到run_classifier中,

    """used for evaluation"""
    eval_examples = processor.get_dev_examples(FLAGS.data_dir)
    num_actual_eval_examples = len(eval_examples)
    if FLAGS.use_tpu:
        # TPU requires a fixed batch size for all batches, therefore the number
        # of examples must be a multiple of the batch size, or else examples
        # will get dropped. So we pad with fake examples which are ignored
        # later on. These do NOT count towards the metric (all tf.metrics
        # support a per-instance weight, and these get a weight of 0.0).
        while len(eval_examples) % FLAGS.eval_batch_size != 0:
            eval_examples.append(PaddingInputExample())

    eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
    file_based_convert_examples_to_features(
        eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file)

    tf.logging.info("***** Running evaluation *****")
    tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                    len(eval_examples), num_actual_eval_examples,
                    len(eval_examples) - num_actual_eval_examples)
    tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

    # This tells the estimator to run through the entire set.
    eval_steps = None
    # However, if running eval on the TPU, you will need to specify the
    # number of steps.
    if FLAGS.use_tpu:
        assert len(eval_examples) % FLAGS.eval_batch_size == 0
    eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

    eval_drop_remainder = True if FLAGS.use_tpu else False
    eval_input_fn = file_based_input_fn_builder(
        input_file=eval_file,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=eval_drop_remainder)
    eval_hook = EvaluationHook(eval_steps=eval_steps, train_steps=train_steps, estimator=estimator, eval_input_fn=eval_input_fn)
    """end"""
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps, hooks=[eval_hook])

以上代码的大致意思为:
1.定义Estimator.eval时所需的eval_steps,eval_input_fn,然后同estimator和train_steps(用于判断是否达到一个epoch)一起打包传入hook。
2.将hook传入estimator.train。(hooks=[eval_hook])
注意:尽量使得FLAGS.save_checkpoints_steps步数小于等于train_steps,因为estimator调用eval时是从模型中恢复,如果还没保存的话容易出现eval的值是上一轮的。

参考

tf.train.SessionRunHook()用法

你可能感兴趣的:(tensorflow,Bert,epoch,Hook,estimator)