深度学习炼丹师的养成之路之——Batch size/Epoch/Learning Rate的设置和学习策略

这个名字好长啊。。但是,考虑到每一次训练都要耗费长达数日的GPU时间,每次启动训练前,细致而缜密的前期准备工作其实非常必要而且至关重要,这直接影响着数日之后的loss和最终的performance。

首先推荐的一个文章是前几日看到的,知乎上谭旭的一个回答,谈到了最近facebook的training ImageNet in one hour,比较详细地阐释了batch size的大小对收敛性的影响,以及batch size和lr的搭配问题,怎么调整lr以减小大batch size造成的performance下降,训练初期的warm-up的概念等,看完很有启发。

然后,这两天和Liuyongcheng师兄有一些交流,他大概提到了,一开始训练的话,大概到10个epoch可以看一下loss,太早的话看不出什么有意义的信息。我后来感觉确实如此,因为用ResNet50-SSD跑的时候,前面的iteration看起来没有VGG-SSD的loss降得那么快,结果隔天一早一看,发现loss还是降下去了。这说明深度学习炼丹确实还是要耐心的。

然后我训练ResNet50-SSD使用的caffe中的MultiStep的learing policy,基本上每20多个epoch左右把lr降低到原先的1/10,和Liuyongcheng师兄交流,说到一般10个epoch降一下

我这个多训练了10个epoch我想问题应该不会太大,大不了浪费一些计算资源,何况我看了一下,10个epoch的时候loss还没有flat,所以训20个应该也可以。另外,10个epoch其实也不好界定,因为毕竟做了各种data augmentation,所以数据最终一个epoch有多少也是不好说的。

然后他大概说,一般训练40-50个epoch就差不多了,太多话可能会过拟合。

还有一个有关Keras中选择合适lr的帖子,大概也是讲到了两种随时间调整lr的方法

  • 一个是每个epoch都降低,把lr当做epoch的一个函数,用epoch加上一个decay的参数作为分母。
  • 另一个就是像caffe中的multistep的策略一样,每隔多少个epoch降一下

然后讲到建议initial lr建议大一点(毕竟后面会逐步降),然后学习过程中要用比较大的momentum,这样即便lr变小后,也可以让梯度沿着正确的方向继续前进;再就是要多试验,各种花哨的lr schedule都是要靠试验找到更好的(想起了DeepLab用的poly learning rate。。)

  • Increase the initial learning rate. Because the learning rate will very likely decrease, start with a larger value to decrease from. A larger learning rate will result in a lot larger changes to the weights, at least in the beginning, allowing you to benefit from the fine tuning later.
  • Use a large momentum. Using a larger momentum value will help the optimization algorithm to continue to make updates in the right direction when your learning rate shrinks to small values.
  • Experiment with different schedules. It will not be clear which learning rate schedule to use so try a few with different configuration options and see what works best on your problem. Also try schedules that change exponentially and even schedules that respond to the accuracy of your model on the training or test datasets.

而且作者说到,如果在adam中(先前讲的是用在SGD中),那么这些可能就有点多余了,因为adam本身自己就有很强的自适应性,不过最终还是要看数据,看实验,data wins.

你可能感兴趣的:(deep-learning)