如何训练好一个神经网络?

文章目录

    • 参考依据
    • 两个现象
        • 1.神经网络的训练没有想象中简单
        • 2. 神经网络训练的失败往往是悄无声息的
    • 正确的训练方式
        • 1. 数据第一!
        • 2. 制作端到端的训练/验证框架 + 得到baselines
        • 3. 过拟合
        • 4. 正则化
        • 5. 调参
        • 6. 精益求精

参考依据

参考自Andrej Karpathy大佬(特斯拉AI总监,李飞飞学生)的博客:http://karpathy.github.io/2019/04/25/recipe/

在此文章的基础上,结合了自己的理解和想法,写下了这篇博文。

两个现象

1.神经网络的训练没有想象中简单

很多框架,包括Pytorch、Tensorflow在内,都会提供一些可以直接调用神经网络模型接口,这往往会给人一种错觉,就是神经网络模型是即插即用的。甚至有时候,人们对于模型的训练会简化到以下这种地步:

>>> your_data = # plug your awesome dataset here
>>> model = SuperCrossValidator(SuperDuper.fit, your_data, ResNet50, SGDOptimizer)
# conquer world here

说明:
上面这个是Andrej在描述这种现象写的一个示例代码,注释真的是太搞笑了。

虽然我们十分熟悉这种API调用的模式,并且希望达到这种效果,但是神经网络没有我们想的这么容易。神经网络是数据驱动的,不同的数据分布在同一神经网络模型展现的效果不总是同样好的,神经网络有一个著名的理论叫No Free Lunch Theorem:对于基于迭代而产生的最优算法,不存在某种算法对于所有问题都有效。 并且简单地使用反向传播和随机梯度下降不会使得你的模型总是work,Batch Norm也不总是能加速收敛,在一些数据集上,丢掉BatchNorm反而能更好地拟合数据。所以如果你不了解其中的原理,你可能就会失败(我觉得大概率,至少不是最优解)。

2. 神经网络训练的失败往往是悄无声息的

这些错误往往不是显性的语法错误,而是一些内容上的、逻辑上的错误。你有时候会发现你的模型能够work,但实际上并不是这样。比如,我有时候会因为标签标错,然后在训练进行测试的时候才发现这个问题,只通过训练时反馈的信息无法看出任何异常。还有一次我训练分割模型,对数据集的标签进行增广时,由于缩放和旋转采用的方式是双线性插值,导致模型在训练过程中loss越来越大。 所以,如果你的模型在训练时报错,那么你是幸运的,因为在很多时候,它往往是悄无声息的。

总的来说,快速而暴力的训练方式在神经网络的训练中是不起作用的,必须要有耐心,并且循序渐进。

正确的训练方式

1. 数据第一!

训练一个神经网络,第一步绝对不是敲代码,而是检查数据(特别是自己的数据集)!!这个我真的深有体会,当你在魔改网络,加各种骚操作之前,一定要确保你的数据是正确的!! 所以与其在之后花大工夫来检查,然后发现其实是数据问题,还不如一开始就确保数据的正确性。所以先浏览一遍数据是必要的,观察数据的分布特点以及是否存在类别不平衡情况,这将取决我们应该去探索哪种网络模型。例如,局部特征就足够了吗?是否需要全局信息特征?应该采取什么形式的数据增广?空间信息重要吗?还是直接可以平均池化?图像的细节重要吗?我们应该下采样到什么程度?标签是否有噪声?
除了定性的观察一遍数据,也可以编写程序(搜索、过滤、排序等)对数据进行一些定量分析(例如标签类型、标注数量、标注大小等),然后可视化数据分布,并且找出异常值(outliers)。

2. 制作端到端的训练/验证框架 + 得到baselines

下一步就需要建立一个完整的训练/验证框架,并通过一系列的实验来确保正确性。最好先选择一个简单的模型把流程跑通 我们在这个模型上完成训练、可视化loss和一些其他的指标、进行模型预测和进行一系列消融实验。

Tips & Tricks:

  1. 设置随机数种子 设置好随机数种子,确保模型的可复现性。
  2. 简化 不要去抱有一些不必要的想法。例如数据增广,因为它是用来提高模型的泛化能力的操作,在目前这个阶段,是不必要引入的,徒增训练的负荷。
  3. 有效的评估 当绘制loss值时,要以整个test/val数据集的loss为单位,而不是以batch为一个单位。
  4. 在初始阶段验证损失函数 从初始化阶段就要确保损失函数计算的正确性。
  5. 一个好的初始化 正确初始化最后一层的权重。例如你要对一个均值为50的数据集做回归预测,那么最后的logits的bias就可以初始化为50;对于分类任务,如果你的数据不平衡,假设正负样本比例为1:10,就可以让最后的logits的bias初始化为0.1。一个好的初始化能够加速收敛,避免你的网络在前几个迭代过程都是在学习偏差。
  6. human baseline 监控所以的对于人类来说是可解释的或者可检查的指标(比如accuracy、Mean IOU等),将这些指标和人类的指标相对比(比如,如果让你人为地去分类这批数据,你觉得你的准确率会是多少)。
  7. 与输入无关的baseline 训练一个与输入无关的baseline(最简单的方法是将所有的输入变为0),这个可以检测你的模型是否能够获取输入信息。说实话,这个我没太懂,有大佬懂了的话,麻烦指点一二。
  8. 过拟合一个单batch模型 令batch_size = 1(或者2)来训练部分样本,从而得到一个过拟合的模型。这样做有两个好处:首先是可以增加模型的容量(比如增加一些层和卷积核),其次是可以观察现在这个模型能达到的最小loss值。这个可以检验模型的能力,要确保现在选的模型具有足够的拟合能力。
  9. 确保训练loss减少 在所有数据上训练,如果训练集的loss不再减少,需要再次确定你的模型具有足够的拟合能力。
  10. 数据送入网络前先可视化 在将数据送入模型前,先可视化一下数据,确保数据是正确。在y = model(x)之前,进行可视化。
  11. 动态可视化预测 在训练过程中,对固定批次上的样本可视化预测结果。
  12. 使用反向传播来绘制关系依赖图

3. 过拟合

在这一阶段,我们应该对数据集有一个很好的了解,并且有完整的训练/验证流程。对于任何的模型,我们都能计算得出我们需要的指标。现在可以开始迭代一个好模型了。一般为两个阶段:首先是使得模型足够强,能够在训练集上过拟合;然后在使用归一化策略,放弃一些训练loss,从而降低验证loss,达到一个平衡。

Tips & Tricks:

  1. 选择模型 先要为数据选择一个好的、合适的模型。其中重要的一点就是:不要逞英雄!不然一开始尝试复杂的、花里胡哨的模型,然后疯狂地做一些骚操作,先选择一个最简明、最普遍应用的模型。比如,如果是分类任务,直接上ResNet-50。
  2. Adam是保险的 Adam对学习率的设置更宽容,但是SGD的表现性能要更好(学习率的调整范围更窄:需要更精确的学习率)
  3. 一次只复杂化一个 这意味着,当我们有很多可以增加模型复杂度的方法,不要一股脑地全部用上去,一次只使用一个。
  4. 注意学习率衰减 要主要学习率的衰减策略,最开始的可以不用学习率衰减策略,而是使用恒定的学习率。这可以避免你的学习率过低导致模型不够拟合。

4. 正则化

理想情况下,进行到这一步的时候,我们可以得到一个能够拟合训练集的模型了(有可能存在过拟合现象)。现在我们需要加一些正则化操作,是的模型具有更强的泛化能力。
Tips & Tricks:

  1. 更多的数据 提高模型的泛化能力最重要的一条就是尽可能地收集多的真实样本数据。花尽功夫想使得小样本获得一个好的拟合性能是不现实的。
  2. 数据增强 这个不用多说,使用现有的数据来模拟一些数据。
  3. 有创意的增强 如果2.还不能满足要求,可以使用例如GANs的方法用来进行数据增强。
  4. 预训练模型 即使你拥有足够多的数据量,使用预训练模型也是没有坏处的。
  5. 坚持监督学习 目前还没有任何版本的无监督训练模型在计算机视觉领域取得显著成果。
  6. 更小的输入维度 如果图像细节不重要,可以将图像缩放小一些。
  7. 更小的模型 许多情况下,可以给网络加上领域知识限制(Domain Knowledge Constraints),使得模型变小。比如,我们之前都是给分类网络加上全连接层,后来,逐渐被简单的平均池化所代替,大大减少了参数
  8. 减少Batch Size 小的Batch size对于Batch Norm来说,在某种程度上能够增强模型的泛化能力。
  9. Drop 增加Dropout层,但是对于内置Batch Norm的网络来说,Dropout层似乎效果不太好(所以慎用)
  10. 增加weight decay
  11. 提前停止 出现loss不再下降的情况,应该提前停止训练,避免过拟合。
  12. 尝试更大的模型 最后提到这一点,并且在提前停止后才提到。虽然大的模型会过拟合得更厉害,但使用“提前停止”会使得它们比小模型表现得更好。

最后,提到一点,可以对于网络的第一层参数做一个可视化效果,看网络是否能够捕捉到一些有用的边缘信息。如果看起来像一团噪音,那么就要注意模型是否出现问题了。

5. 调参

Tips & Tricks:

  1. 随机网格搜索 这种方法有点费时,但是可以对于某些重要的超参做随机网格搜索。
  2. 超参数优化

6. 精益求精

Tips & Tricks:

  1. ensembles 把几个模型融合在一起,至少可以提高2%。如果算力顶不住,可以尝试使用网络蒸馏(https://arxiv.org/abs/1503.02531)
  2. 让模型’飞’一会 有时候你需要做的只是什么都不管,让模型继续训练下去,保持耐心可能会获得额外的守护喔,静静地让‘子弹飞一会’。

你可能感兴趣的:(深度学习)