CNN入门之常见过拟合解决方法汇总

1. 引言

不管是在训练机器学习或是深度学习的模型,想必大家都有遇过欠拟合与过拟合的状况,而其中又以模型过拟合最让人头疼。

CNN入门之常见过拟合解决方法汇总_第1张图片

上图给了我们一个很好的例子,左边的图描了模型欠拟合的状况,中间的图描述了良好的模型该有的划分曲线,右边的图则是典型的过拟合的示例。

在正式踏入更深入的CNN模型之前,我们势必要了解过拟合这个名词,到底过拟合是什么?而我们又该如何避免这种情形的发生呢?

2. 欠拟合定义

在正式介绍过拟合(Overfitting)之前,我们不妨先来认识一下另一个常见的问题欠拟合(Underfitting)。当我们在训练一个模型时,发现不管是在训练集或是测试集数据都无法达到一定的准度时,就可能是遇到了Underfitting的状况。
通常造成Underfitting的主要原因包含 :

  • 训练时间不足
  • 模型复杂度不足

针对上述两种状况,在深度学习领域,一般来说可以通过过增加训练迭代的次数来解决训练时间不足的问题;同时可以通过调整隐含层的数量、模型的深度等策略来解决模型复杂度不足的问题。如果上述策略仍然无法改善模型效果,可能就要考虑是训练数据本身的问题。

3. 过拟合定义

相较于欠拟合,过拟合Overfitting通常是让我们更为之头疼的问题,而在深度学习中我们该如何观察Overfitting的状况呢?
CNN入门之常见过拟合解决方法汇总_第2张图片

参考上图,随着模型训练时间的增长、迭代的次数增加,我们训练集与测试集的Error都会逐步的下降,但当我们观察到训练集与测试集的Error开始分道扬镳时,就可能是Ovefitting的状况发生的临界点。
我们可以用一句话描述上述情况:

模型过度去学习、硬背训练数据

4.导致过拟合的原因

在深度学习领域,通常导致过拟合的原因如下:

  • 训练数据不足

模型的训练数据不足是大家最常会遇到的问题之一,当我们的训练集太小时,模型找不到数据泛化的特征,因此就顷向去学习死记硬背所拥有的信息。

  • 迭代次数过多

当我们不断让模型去学习训练数据集时,到了最后模型就会试图去硬背特征,而大大降低了泛化能力。

  • 模型复杂度太高
    此时刚好与Underfitting相反,当我们使用一个很强大的模型去学习较小的数据集时,模型较容易发生过拟合的状况。

5.解决过拟合的策略

5.1 解决训练数据不足

针对训练数据不足所导致的过拟合的状况,最直接的办法就是去收集或爬取更场景更丰富的训练数据,然而数据的获取往往是最困难的部分。但是在CV领域,数据增强就能派上用场,简单来说数据增强就是基于有限的数据生成更多等价(同样有效)的数据,丰富训练数据的分布,使通过训练集得到的模型泛化能力更强。

5.2 解决迭代次数过多

针对迭代次数过多、训练时间过长导致过拟合问题,我们可以透过设置Early Stopping来解决,Early Stopping其实就是
通过观察测试数据集Loss的变化来提前停止训练,并通过调整patience决定容忍度,不妨假设我们设置patience=10,也就是说当测试数据集的Loss在10个epoch后都没有下降,此时就要停止训练,我们可以参考以下代码来简单通过手动设置Early Stopping。

patience=10
## te_loss数据类型为list,存储每个epoch后的testset loss
if testing_loss > np.min(te_loss):
      n_patience += 1
else:
      n_patience = 0

if n_patience >= patience:
    print("The model didn't improve for %i rounds, break it!" % patience)
    break

5.3 Dropout

Dropout为深度学习中常用的技巧之一,即随机关闭NN层中的一定比例神经元(让其值为0),借此降低模型对各个神经元的依赖性。但过高比例的Dropout会影响模型的收敛,所以大家使用上也要特别注意。
CNN入门之常见过拟合解决方法汇总_第3张图片

5.4 正则化

正则化不管是在ML/DL中都时常会被使用,其原理是在loss 函数后加上惩罚项,数学公式如下:
CNN入门之常见过拟合解决方法汇总_第4张图片
希望大家不要被数学式子吓到了,我们可以这样想像,模型在收敛过程中,当发现某个特征比较有用时,最直接的作法就是提高这个特征的权重,也就是上图中的Wj,而为了避免model过度依赖这个特征,我们就在Loss函数上加上这个特征的权重(也就是我们常说的正则项),当模型试图提高这个权重时,也会因此提高Loss,借此来抑制这个weigh的起伏变化。

正则化又可以分为L1、L2正则化,通常来说我们会使用L2 正则化,因为其表现较稳定也比较不会造成权重归0的状况。

5.5 BatchNorm

虽然Batch Normalization主要用途不是避免Overfitting,但其仍然能起到一定Regularize(正则化)的作用,由于每次mini-batch中的mean跟variance都会不一样,借此影响到scale & shift参数的调整。其原理如下图所示:
CNN入门之常见过拟合解决方法汇总_第5张图片

6. 总结

本文重点介绍了在训练模型时常见的过拟合的现象,并针对造成过拟合的不同原因分别介绍了常见的解决方法。

您学废了嘛?

在这里插入图片描述
关注公众号《AI算法之道》,获取更多AI算法资讯。

参考链接: 戳我

你可能感兴趣的:(深度学习,cnn,计算机视觉,机器学习)