过拟合是训练神经网络中常见的问题,本文讨论了产生过拟合的原因,如何发现过拟合,以及简单的解决方法。
发现过拟合问题
在训练神经网络时,我们常常有训练集、测试集和验证集三种数据集。
有时候我们会发现,训练出来的神经网络在训练集上表现很好(准确率很高),但在测试集上的准确率比较差。这种现象一般被认为是过拟合,也就是过度学习了训练集上的特征,导致泛化能力较差。
hold out 方法
那么如何发现是否存在过拟合方法呢?一种简单的思路就是把训练集分为训练集和验证集,其中训练集用来训练数据,验证集用来检测准确率。
我们在每个迭代期的最后都计算在验证集上的分类准确率,一旦分类准确率已经饱和,就停止训练。这个策略被称为提前停止。
示例
以MNIST数据集为例,这里使用1000个样本作为训练集,迭代周期为400,使用交叉熵代价函数,随机梯度下降,我们可以画出其损失值与准确率。
训练集上的损失值和准确率:
验证集上的损失值和准确率:
对比测试集与验证集的准确率:
可以发现:训练集上的损失值越来越小,正确率已经达到了100%,而验证集上的损失会突然增大,正确率没有提升。这就产生了过拟合问题。
增大训练量
一个最直观,也是最有效的方式就是增大训练量。有了足够的训练数据,就算是一个规模很大的网络也不太容易过拟合。
例如,如果我们将MNIST的训练数据增大到50000(扩大了50倍),则可以发现训练集和测试集的正确率差距不大,且一直在增加(这里只迭代了30次):
但很不幸,一般来说,训练数据时有限的,这种方法不太实际。
人为扩展训练数据
当我们缺乏训练数据时,可以使用一种巧妙的方式人为构造数据。
例如,对于MNIST手写数字数据集,我们可以将每幅图像左右旋转15°。这应该还是被识别成同样的数字,但对于我们的神经网络来说(像素级),这就是完全不同的输入。
因此,将这些样本加入到训练数据中很可能帮助我们的网络学习更多如何分类数字。
这个想法很强大并且已经被广泛应用了,更多讨论可以查看这篇论文。
再举个例子,当我们训练神经网络进行语音识别时,我们可以对这些语音随机加上一些噪音--加速或减速。
规范化(regularization)
除了增大训练样本,另一种能减轻过拟合的方法是降低网络的规模。但往往大规模的神经网络有更强的潜力,因此我们想使用另外的技术。
规范化是神经网络中常用的方法,虽然没有足够的理论,但规范化的神经网络往往能够比非规范化的泛化能力更强。
一般来说,我们只需要对进行规范化,而几乎不对进行规范化。
L2规范化
学习规则
最常用的规范化手段,也称为权重衰减(weight decay)。
L2规范化的想法是增加一个额外的项到代价函数上,这个项被称为规范化项。例如,对于规范化的交叉熵:
对于其他形式的代价函数,都可以写成:
由于我们的目的是使得代价函数越小越好,因此直觉的看,规范化的效果是让网络倾向于学习小一点的权重。
换言之,规范化可以当做一种寻找小的权重和最小化原始代价函数之间的折中。
现在,我们再对和求偏导:
因此,我们计算规范化的代价函数的梯度是很简单的:仅仅需要反向传播,然后加上得到所有权重的偏导数。而偏置的偏导数不需要变化。所以权重的学习规则为:
这里也表明,我们倾向于使得权重更小一点。
那这样,是否会让权重不断下降变为0呢?但实际上不是这样的,因为如果在原始代价函数中的下降会造成其他项使得权重增加。
示例
我们依然来看MNIST的例子。这里,我使用的规范化项进行学习。
训练集上的准确率和损失值和之前一样:
测试集上的损失值不断减少,准确率不断提高,符合预期:
L1规范化
学习规则
这个方法是在未规范化的代价函数上加一个权重绝对值的和:
对其进行求偏导得:
其中就是的正负号。
与L2规范化的联系
我们将L1规范化与L2规范化进行对比:
两种规范化都惩罚大的权重,但权重缩小的方式不同。
在L1规范化中,权重通过一个常量向0进行收缩;
而L2规范化中,权重通过一个和成比例的量进行收缩。
所以,当一个特定的权重绝对值很大时,L1规范化的权重缩小远比L2小很多;而当很小时,L1规范化的缩小又比L2大很多。
因此,L1规范化倾向于聚集网络的权值在相对少量的高重要连接上,而其他权重就会被趋向于0。
Dropout
Dropout是一种相当激进的技术,和之前的规范化技术不同,它不改变网络本身,而是会随机地删除网络中的一般隐藏的神经元,并且让输入层和输出层的神经元保持不变。
我们每次使用梯度下降时,只使用随机的一般神经元进行更新权值和偏置,因此我们的神经网络时再一半隐藏神经元被丢弃的情况下学习的。
而当我们运行整个网络时,是两倍的神经元会被激活。因此,我们将从隐藏神经元的权重减半。
这种技术的直观理解为:当我们Dropout不同的神经元集合时,有点像我们在训练不同的神经网络。而不同的神经网络会以不同的方式过拟合,所以Dropout就类似于不同的神经网络以投票的方式降低过拟合。
对于不同的技术,其实都可以理解为:我们在训练网络的健壮性。无论是L1、L2规范化倾向于学习小的权重,还是Dropout强制学习在神经元子集中更加健壮的特征,都是让网络对丢失个体连接的场景更加健壮。
参考
- What is regularization in machine learning?
- Improving neural networks by preventing co-adaptation of feature detectors
- best practices for convolutional neural networks applied to visual