[深度学习]-Early stopping的作用及代码

1. Early stopping是什么?

EarlyStopping是Callbacks的一种,callbacks用于指定在每个epoch开始和结束的时候进行哪种特定操作。Callbacks中有一些设置好的接口,可以直接使用,如’acc’,’val_acc’,’loss’和’val_loss’等等。

2. Early stopping的作用?

当我们训练深度学习神经网络的时候通常希望能获得最好的泛化性能(generalization performance,即可以很好地拟合数据)。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合:当网络在训练集上表现越来越好,错误率越来越低的时候,实际上在某一刻,它在测试集的表现已经开始变差。
from:https://www.datalearner.com/blog/1051537860479157

常用的防止过拟合的方法是对模型加正则项,如L1、L2,dropout,但深度神经网络希望通过加深网络层次减少优化的参数,同时可以得到更好的优化结果,Early stopping的使用可以通过在模型训练整个过程中截取保存结果最优的参数模型,防止过拟合。

迭代次数增多后,达到一定程度后产生过拟合。从图中可以看出,训练集精度一直在提升,但是test set的精度在上升后下降。若是在early stopping的位置保存模型,则不必反复训练模型,即可找到最优解。
[深度学习]-Early stopping的作用及代码_第1张图片
图片来源:https://deeplearning4j.org/docs/latest/deeplearning4j-nn-early-stopping

缺点:
没有采取不同的方式来解决优化损失函数和降低方差这两个问题,而是用一种方法同时解决两个问题 ,既希望减少cost function,又希望不过拟合,那如果精度不满足要求,又无法继续训练呢?这样使得要考虑的东西变得更复杂。之所以不能独立地处理,因为如果你停止了优化代价函数,你可能会发现代价函数的值不够小,同时你又不希望过拟合。

3. Early stopping如何使用?

停止标准和停止准则参考:
https://www.datalearner.com/blog/1051537860479157

本文以代码形式进行简介:

参数解释:
https://blog.csdn.net/wjlwangluo/article/details/79758661
monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。但是因为笔者用的是5折交叉验证,没有单设验证集,所以只能用’acc’了。

**min_delta:**增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了你的容忍程度。例如笔者的monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。加上观察到训练过程中存在抖动的情况(即先下降后上升),所以适当增大容忍程度,最终设为0.003%。

**patience:**能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。笔者在引入EarlyStopping之前就已经得到可以接受的结果了,EarlyStopping算是锦上添花,所以patience设的比较高,设为抖动epoch number的最大值。

mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。笔者的monitor是’acc’,所以mode=’max’。

代码:https://github.com/pytorch/ignite/blob/master/ignite/handlers/early_stopping.py

class EarlyStopping(object):

    def __init__(self, patience, score_function, trainer):

        if not callable(score_function):
            raise TypeError("Argument score_function should be a function.")

        if patience < 1:
            raise ValueError("Argument patience should be positive integer.")

        if not isinstance(trainer, Engine):
            raise TypeError("Argument trainer should be an instance of Engine.")

        self.score_function = score_function
        self.patience = patience
        self.trainer = trainer
        self.counter = 0
        self.best_score = None
        self._logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
        self._logger.addHandler(logging.NullHandler())

    def __call__(self, engine):
        score = self.score_function(engine)

        if self.best_score is None:
            self.best_score = score
        elif score <= self.best_score:
            self.counter += 1
            self._logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
            if self.counter >= self.patience:
                self._logger.info("EarlyStopping: Stop training")
                self.trainer.terminate()
        else:
            self.best_score = score
            self.counter = 0

你可能感兴趣的:(深度学习)