Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour 笔记

作者使用了batch大小:8192,使用了256 GPUs,在一个小时内训练了ResNet-50,并且得到了和256大小的batch同样的训练精度。

2 Large Minibatch SGD

通常来说,我们在训练有监督任务的时候,会最小化loss:

是网络的参数,是训练集,就是损失函数。
minibatch SGD就是在一个batch的训练集上,进行参数的更新:

2.1 Learning Rates for Large Minibatches

论文的目的是在使用非常大的batch的时候能够维持训练的准确性和泛化性能。具体来说,就是在使用多个worker来进行数据并行训练的时候,不会牺牲模型的accuracy。
作者发现,下面的learning rate scaling rule能够适合于很大范围的batch size。

Linear Scaling Rule:当minibatch size乘以一个数,同样learning rate也乘以这个数。
所有其他超参数保持不变,

  • interpretation解释:为什么上面的方法会有效呢?首先考虑一个网络在某一个时刻的参数,和一组个minibatches,每一个minibatch的大小为。我们比较一下每个minibatch单独训练和这个batch一起训练的效果。
  • 第一种情况:在进行了次更新后

  • 第二种情况:训练是在这个batch的合集上进行,batch size大小为

    update rule

显然两个结果不太可能一样,但是假如 并且 那么我们就可以得到。

2.2 Warmup热身

当网络变化很剧烈的时候,上面提出的假设就不会成立,那么Linear Scaling Rule就不会有效果。但是作者发现,这样的情况可以通过一种热身的方式来缓解,具体来说就是,在训练的开始,使用一个更小的learning rate。

  • Constant warmup:一种热身的策略是使用一个小的定值作为初始的学习率,训练几个回合。这种策略对于物体检测,分割,fine-tune等问题在有些时候效果较好,但是当较大也就是batch较大的时候,就不是那么有效了,尤其在热身结束的时候会出现error的峰值。
  • gradual warmup:为了克服constant warmup的不足,作者使用了gradual warmup,就是一点一点地将学习率从小,增大。并且在增大后,回复到原始的learning rate schedule。

2.3 Batch Normalization with Large Minibatches

BN在提高训练效率和精度有很大的效果。但是一个minibatch在计算一些统计量的时候,需要整个minibatch的数据,当分布式或者多卡训练的时候,就会导致非常多的数据需要传输。
当使用BN的时候,每个sample的loss就会和整个batch的统计量相关,我们用表示单个sample的loss。用表示整个batch的loss。那么整个训练集的loss表示为。 表示一个大小为的batch。
当我们改变的大小的时候,就相当于改变loss function。More specifically the mean/variance statics computed by BN with different exhibit different levels of random variation
在分布式和多卡训练的情况下,如果每个worker的batch size大小为,那么总共的batch大小就是,相当于从许多batch中选择了个samples,每个sample就是一个batch。那么之前的公式就变为

batch update

We also note that the BN statics should not be computed across all workers, not only for the sake of reducing communication, but also for maintaining the same underlying loss function being optimized.

Subtleties and Pitfalls of Distributed SGD

Weight decay:weight decay是参数的L2-正则项。加入正则后的更新公式变为

weight decay

最后一项
是原来的loss,它可以通过反向传播来计算得到,
被分别计算,和梯度加起来用于更新。

Remark 1: Scaling the cross-entropy loss is not equivalent to scaling the learning rate.

Momentum correction:带动量的SGD被广泛应用于神经网络的更新中。一种常见的形式如下:

Momentum SGD

代替
得到下面的公式
momentum sgd

需要注意的是,
是和学习率
有关的,当学习率改变的时候,
也应该改变:
其中,作者将
称为momentum correction。作者发现,当
的时候,它对于稳定训练过程非常重要。

remark 2: Apply momentum correction after changing learning rate if using (10)

  • Gradient aggregation:对于每个worker的训练结果,需要将梯度汇聚起来,求平均用于更新参数。

remark 3: Normalize the per-worker loss by total minibatch size , not per-worker size 。

  • Data shuffling

remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all workers.

Communication

对于每一个参数的梯度,都是通过allreduce操作来进行汇聚的。在进行allreduce之前,每个GPU都会计算自己的梯度,在allreduce*之后,每个GPU得到梯度的和。

论文中还讨论了软件和硬件的实现相关,详情可参考论文。

你可能感兴趣的:(Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour 笔记)