Implementing early stopping with TensorFlow Estimator

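To avoid overfitting, you will often want early stopping: ending training ahead of schedule once the loss has stopped improving, that is, once it is close to converging.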

Prerequisites

A detailed introduction to tensorflow estimator for efficient model training

Reading and writing data efficiently with tfrecord in tensorflow

API overview

tf.estimator.experimental.stop_if_no_increase_hook(
    estimator, metric_name, max_steps_without_increase, eval_dir=None, min_steps=0,
    run_every_secs=60, run_every_steps=None
)
Args
    estimator: A tf.estimator.Estimator instance.
    metric_name: str, metric to track. "loss", "accuracy", etc.
    max_steps_without_increase: int, maximum number of training steps with no increase in the given metric.
    eval_dir: If set, directory containing summary files with eval metrics. By default, estimator.eval_dir() will be used.
    min_steps: int, stop is never requested if global step is less than this value. Defaults to 0.
    run_every_secs: If specified, calls should_stop_fn at an interval of run_every_secs seconds. Defaults to 60 seconds. Either this or run_every_steps must be set.
    run_every_steps: If specified, calls should_stop_fn every run_every_steps steps. Either this or run_every_secs must be set.
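  1. estimator: defines your model structure together with the training (train), evaluation (evaluate), and prediction (predict) procedures
  2. metric_name: the metric used to decide whether to stop early
  3. max_steps_without_increase: the maximum number of steps the tracked metric (for example accuracy) may go without increasing before training is stopped early; for a metric that should go down, such as loss, stop_if_no_decrease_hook takes max_steps_without_decrease instead
  4. min_steps: the minimum number of training steps before early stopping is considered at all
  5. run_every_steps: run the early-stopping check every n steps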

Notes

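  1. First, one point needs to be clear: the early-stopping decision is based on the results of evaluation (evaluate), so use the hook together with tf.estimator.train_and_evaluate. The metric named by metric_name is computed on the eval set, not the training set;
  2. The overall flow is: train --> save checkpoint --> evaluate --> decide whether to stop early --> train
  3. So how often evaluation runs, and therefore how often the early-stopping decision can change, is in practice determined by how often the model is saved
  4. As the argument descriptions above state, max_steps_without_increase and run_every_steps are both counted in training (global) steps. The tracked metric only gets a new value when an evaluation runs, though, so the decision can only change after an evaluation. With a small value such as 1, a single evaluation in which the metric fails to improve is enough to trigger early stopping
  5. The evaluation can also report a metric such as accuracy, which should go up rather than down. That case is handled by tf.estimator.experimental.stop_if_no_increase_hook, shown above; for a metric that should go down, such as loss, use tf.estimator.experimental.stop_if_no_decrease_hook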
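The decision described in these notes can be illustrated with a small pure-Python sketch. It is a simplified illustration, not TensorFlow's actual implementation: the function name should_stop, the higher_is_better flag, and the in-memory history dictionary are hypothetical stand-ins for the hook and for the eval summaries it reads. The sketch assumes, following the argument descriptions above, that steps are measured in global training steps.

```python
def should_stop(eval_results, metric_name, max_steps_without_improvement,
                min_steps=0, higher_is_better=False):
    """Return True if the metric has gone too many global steps without improving.

    eval_results maps global_step -> {metric_name: value}, one entry per
    evaluation run (a hypothetical stand-in for the eval summary files).
    """
    best_value = None
    best_step = None
    for step in sorted(eval_results):
        if step < min_steps:
            # Evaluations before min_steps are ignored entirely.
            continue
        value = eval_results[step][metric_name]
        if best_value is None:
            improved = True
        elif higher_is_better:
            improved = value > best_value   # e.g. accuracy
        else:
            improved = value < best_value   # e.g. loss
        if improved:
            best_value, best_step = value, step
        # Stop once the best value is too many global steps in the past.
        if step - best_step >= max_steps_without_improvement:
            return True
    return False


# Assumed setup: a checkpoint (and an evaluation) every 10 training steps.
history = {10: {'loss': 0.90}, 20: {'loss': 0.70}, 30: {'loss': 0.75}}
print(should_stop(history, 'loss', max_steps_without_improvement=1))   # True
print(should_stop(history, 'loss', max_steps_without_improvement=20))  # False
```

In the first call the loss at step 30 is worse than at step 20, and the 10-step gap already exceeds the limit of 1, so a single evaluation without improvement requests a stop. In the second call the limit of 20 steps tolerates that evaluation.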

Code

import tensorflow as tf

from estimator import model_fn, input_fn_bulider

# Save a checkpoint every this many training steps
runConfig = tf.estimator.RunConfig(save_checkpoints_steps=10)

estimator = tf.estimator.Estimator(model_fn,
                                   model_dir='your_save_path',
                                   config=runConfig,
                                   params={'lr': 0.01})

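# Define the early-stopping hook here.
# The decision uses the results of evaluation, so metric_name refers to a metric on the eval dataset.
# max_steps_without_decrease: stop once the loss has gone this many global steps without decreasing.
# A new loss value only arrives with each evaluation, so with a value of 1 a single
# evaluation in which the loss does not decrease triggers the stop.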
early_stop = tf.estimator.experimental.stop_if_no_decrease_hook(estimator,
                                                                metric_name='loss',
                                                                max_steps_without_decrease=1,
                                                                run_every_steps=1,
                                                                run_every_secs=None)

logging_hook = tf.train.LoggingTensorHook(every_n_iter=1,
                                          tensors={'loss': 'loss:0'})

# Input function for training (train)
train_input_fn = input_fn_bulider('train.tfrecord', batch_size=1, is_training=True)
# Input function for evaluation (eval)
eval_input_fn = input_fn_bulider('eval.tfrecord', batch_size=1, is_training=False)

# Create a TrainSpec instance; the early-stopping hook is passed as a training hook
train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=100,
                                    hooks=[logging_hook, early_stop])
# Create an EvalSpec instance
eval_spec = tf.estimator.EvalSpec(eval_input_fn)

# Flow: train --> save checkpoint --> evaluate --> decide whether to stop early --> train
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
