神经网络不学习的原因

Neural Network Check List

声明:

  1. 译自Reasons why your Neural Network is not working
  2. 翻译的虽不标准,也算靠谱

文章目录

  • Neural Network Check List
    • 如何使用这个指南
    • 数据问题
      • 1 检查输入数据
      • 2 试一下随机输入
      • 3 检查数据加载单元
      • 4 确保输入和输出是一一对应的
      • 5 输入和输出之间的关系映射的随机性会不会太强了?
      • 6 数据集中的噪声
      • 7 打乱数据集
      • 8 控制类别不平衡现象
      • 9 会不会是训练数据不足?
      • 10 尽量你的训练批次中的样本标签多样化
      • 11 减小batch size
      • 12 尝试经典数据集(MNIST,CIFAR10)
    • 数据归一化/正则化
      • 13 特征向量的标准化
      • 14 过量的正则化
      • 15 使用预训练网络时注意数据处理方式
      • 16 检查训练/验证/测试集的预处理
    • 网络的构建问题
      • 17 先尝试解决当前问题的简单版本
      • 18 搞定损失函数
      • 19 检查损失函数
      • 20 核查损失函数的输入
      • 21 调整损失权重
      • 22 注意其他的模型评价标准
      • 23 测试自定义层
      • 24 检查frozen层和变量
      • 25 增加网络规模
      • 26 检测隐藏层的维度错误
      • 27 检查梯度
    • 训练问题
      • 28 先处理小数据集
      • 29 检查权重初始化
      • 30 调整超参数
      • 31 减少正则化
      • 32 多多迭代
      • 33 注意训练和测试模式之间的切换
      • 34 可视化训练过程
      • 35 试一下不同的优化算法
      • 36 梯度爆炸和梯度弥散
      • 37 调整学习率
      • 空值问题

有一个训练了12个小时的神经网络,各方面看起来都不错:梯度缓慢下降、损失也在逐渐降低,但是预测结果却不好:输出全是0值(全都预测为背景),没有检测出任何标签。“到底是什么地方出错了?”——叫天天不应叫地地不灵╮(╯▽╰)╭

对于上述情况,或者另一种垃圾输出的情况——预测值只是所有标签的平均值,再或者更差的情况,模型准确率非常低…我们应该从什么地方开始检查模型呢?

如何使用这个指南

网络训练效果差,问题可能出在很多很多地方,但有些地方出问题的概率较大。所以,通常我会从以下几点开始:

  1. 首先,选择一个简单的经典模型(比如,用VGG识别图像),选择一个标准的损失函数。
  2. 关闭调优单元,如正则化和数据生成。
  3. 在最后调优(finetuning)模型之前,再次检查模型,保证其与原始模型一致。
  4. 保证输入数据无误。
  5. 先使用小数据进行训练(2-20个样本),使网络模型过拟合数据。然后慢慢增大数据集。
  6. 开启调优单元,自定义损失函数,尝试更复杂的模型

数据问题

神经网络不学习的原因_第1张图片

1 检查输入数据

保证你提供给网络的输入数据是有意义的。比如:我有时候弄混了图片的高和宽;或者不小心用全零的数据来训练网络;或者一直使用了同一个batch来不停的迭代训练。所以,你应该打印一组(输入,输出)并确保它们没有问题。

2 试一下随机输入

如果网络对随机输入的预测结果跟之前真实数据的预测结果差不多,那多半是网络某一层除了问题。这时,你需要一层一层的debug自己的网络。

3 检查数据加载单元

可能你的数据集没有问题,但是在数据的读取和预处理中除了问题,所以要检查网络第一层的输入。

4 确保输入和输出是一一对应的

也就是说,训练集中的输入不能(或者尽可能少的)有错误的标定。同时记得在打乱数据集的时候,输入和输出的打乱方式相同。

5 输入和输出之间的关系映射的随机性会不会太强了?

用机器学习里的术语来说就是, y = f ( x ) y=f(x) y=f(x) f f f的假设空间(模型空间)太大了。

比如说股票数据,想学习这个太难了。因为股票的随机性太大了,虽然通过数据获得的股票走势有一定的参考意义,但因为随机性大于数据内部的规律,光凭数据很难做出正确的预测。

6 数据集中的噪声

数据集不是完美的,比如对于MNIST手写数字识别数据集来说,可能有50%的样本时正确的可以学习的,而另外的50%的样本乱标定的噪声,那么这些乱标定的噪声会干扰网络的学习,应该实现被剔除。

当用网上爬来的数据进行机器学习或深度学习的时候,很容易出现这种问题。

7 打乱数据集

不打乱数据集的话,会导致网络在学习过程中的偏向问题。

8 控制类别不平衡现象

对于类别不平衡的分类问题,常规武器有:过采样、欠采样(不太好用)、调整损失函数等。

9 会不会是训练数据不足?

小样本学习时深度学习应用的一个重要问题。常规武器有迁移学习和生成网络(如GAN、VAE)法。如果是因为训练数据不足造成的网络学习效果不好,那将需要很大的精力来解决这个问题。

10 尽量你的训练批次中的样本标签多样化

如果一个训练批次中只有一类的样本,网络将很难收敛到最优。一般来说只要随机打乱的训练集就不会出现这个问题。再保险一点可以将batch_size搞大一点,比如128,但不要太大。

我在实验中,就经常会出现用同一数据集和网络得到的准确率忽高忽低,比如70%-90%。所以,时间允许的话,建议每次实验过程中多训练几次。

11 减小batch size

(・◇・)?上面刚说batch size可以搞大一点,这里又要调小???
研究表明:太大的batch size会降低模型的泛化能力。详见论文——戳我跳转

12 尝试经典数据集(MNIST,CIFAR10)

当使用一种新型的网络结构式,应该先用经典数据集测试一下,而不是直接应用于自己的真实数据。因为这些经典数据集都有参考标准(baseline,或者说是准确率的最低要求),而且没有数据方面的问题(如噪声、不平衡、随机性过大导致难以学习的问题等等)

数据归一化/正则化

神经网络不学习的原因_第2张图片

13 特征向量的标准化

记得以下二选一:

  • 标准化:均值为0方差为1
  • 归一化:数据大小位于0~1之间

14 过量的正则化

正则化可以防止模型过拟合,但过量的正则化会导致欠拟合。

15 使用预训练网络时注意数据处理方式

预处理的方式要跟你加载的预训练的网络一致,比如图片像素的大小是[0, 1], [-1, 1]还是[0, 255]?

16 检查训练/验证/测试集的预处理

引用斯坦福课程CS231n:

… any preprocessing statistics (e.g. the data mean) must only be computed on the training data, and then applied to the validation/test data. E.g. computing the mean and subtracting it from every image across the entire dataset and then splitting the data into train/val/test splits would be a mistake.

在数据预处理时,一些预处理需要的统计参数(如均值、标准差等)必须先从训练集中获得,再应用到验证集和测试集中。谨慎检查这一点,避免数据预处理中的错误。

网络的构建问题

神经网络不学习的原因_第3张图片

17 先尝试解决当前问题的简单版本

比如说,如果目标输出是object class and coordinates, 先试一下解决object class

18 搞定损失函数

引用斯坦福课程CS231n:

Initialize with small parameters, without regularization. For example, if we have 10 classes, at chance means we will get the correct class 10% of the time, and the Softmax loss is the negative log probability of the correct class so: -ln(0.1) = 2.302.

损失函数一般包括两部分:误分类的惩罚项和正则化项。我们应该先根据问题选择合适的惩罚项,然后在尝试正则化项

19 检查损失函数

如果你自定义了损失函数,一定要检查它,还要进行单元测试。因为自定义的损失函数经常会有细小的错误,导致网络出现学习问题,而且这通常很难发现。

20 核查损失函数的输入

使用神经网络框架中的损失函数时,一定要注意损失函数的输入。比如,在PyTorch中,我经常会弄混NLLLossCrossENtropyLoss,前者的输入是一个softmax输入(即在0到1之间),而后者不是。

21 调整损失权重

如果你的损失函数由多个子损失加权而成,那就要注意它们的关联关系和权重。

22 注意其他的模型评价标准

有时候,损失函数的值并不是评价网络训练好坏的最好方式。如果可以的话,可以使用其他的评价标准,如准确率。

23 测试自定义层

如果网络中有你自定义的网络层,多检查几遍…

24 检查frozen层和变量

可能你无意间关闭了某些层的权重更新。

25 增加网络规模

有时候网络训练效果不好是因为网络的容量太小,增加全连接层数或隐藏层神经元数目。

26 检测隐藏层的维度错误

如果你的输入数据维度(k, H, W) = (64, 64, 64),那确实很容易在维度上弄混。如果不放心维度上的问题,可义用几个容易分辨的数值试一下,检查一下它们在各个网络层中传递时的变化。不过这里一般不会出问题。

27 检查梯度

如果反向传递的求导是你自己算的,确实需要反复检查。

训练问题

28 先处理小数据集

从当前数据集中选择极其少量的数据,用自己的网络模型过拟合这些数据,确定网络没有问题。

例如,先用2个样本训练,观察你的网络能否学到这两个样本之间的不同,然后逐步扩展到更多数据。

29 检查权重初始化

保险起见,可以使用XavierHe初始化。有时候,不好的初始化确实会是网络的学习陷入一个局部最优解,所以也可以试一下其他的初始化方法,看看是否有用。

30 调整超参数

超参数的最优值可以通过经验或多次实验获得。如果时间允许的话,还可以通过交叉验证选择最优超参数。

31 减少正则化

过多的正则化会导致网络欠拟合。Dropout, batchnormalization, L2正则化作为正则化手段,当网络欠拟合时候应该先去掉这些方法。

32 多多迭代

可能你的网络需要更多的迭代次数才能获得有意义的预测。如果网络的损失函数值还在下降,务必让它继续迭代下去;如果网络的损失函数不再下降,也应该继续观察几轮确保训练没有卡在局部最小点。

33 注意训练和测试模式之间的切换

有些网络层如BatchNormalizationDropout层在训练模式和测试模式下是不一样的。

34 可视化训练过程

  • 监视每一层的权重和输出,确保激活函数和每一层的权重更新正常。
  • 可以使用TensorboardCrayon等可视化库
  • 如果有一层激活函数后的平均值远大于0,可以尝试BatchNormalization或ELU激活函数
  • 权重和偏执的数值分布:权重应该近似服从标准高斯分布,偏执应该从0开始逐渐过渡到近似高斯分布(除了个别网络外,如LSTM)。也就是说权重的正负分布应该相似,偏执一般不会过大,要不然可能就是网络的训练除了问题。

For weights, these histograms should have an approximately Gaussian (normal) distribution, after some time. For biases, these histograms will generally start at 0, and will usually end up being approximately Gaussian (One exception to this is for LSTM). Keep an eye out for parameters that are diverging to +/- infinity. Keep an eye out for biases that become very large. This can sometimes occur in the output layer for classification if the distribution of classes is very imbalanced

35 试一下不同的优化算法

一般来说优化器的选择不会导致网络训练的结果太差,除非你选择的优化器超参数太糟了。当然,合适的优化器可以是网络训练的更快。常用的优化器有AdamSGDRMSprop

36 梯度爆炸和梯度弥散

梯度爆炸和梯度弥散产生的根本原因是深度学习中多层梯度累积。如1.1的n次方无穷大,0.9的n次方无穷小。

  • 网络中某些层过大的输出会造成梯度爆炸,此时应该为该输出取一个上界。
  • 激活函数的梯度应该在0.5到2之间。

37 调整学习率

学习率决定了网络训练的速度,但学习率不是越大越好,当网络趋近于收敛时应该选择较小的学习率来保证找到更好的最优点。

一般学习率的调整是乘以/除以10的倍数。

空值问题

首先训练集中不能有空值!!其次,网络训练过程中也不能出现空值,几点建议:

  • 降低学习率,尤其是在前100次迭代中出现了NAN
  • 分母上的空值或者非正数的对数会导致NAN出现
  • 一层一层的检查空值是否出现,比如通过Python中的assert
**欢迎补充~**

你可能感兴趣的:(机器学习)