ResNet详解:ResNet到底在解决什么问题?

原作者开源代码:https://github.com/KaimingHe/deep-residual-networks

论文:https://arxiv.org/pdf/1512.03385.pdf

1、网络退化问题

在ResNet诞生之前,AlexNet、VGG等这些比较主流的网络都是简单的堆叠层,比较明显的现象是,网络层数越深,识别效果越好。但事实上,当网络层数达到一定深度的时候,准确率就会达到饱和,然后迅速下降。

ResNet详解:ResNet到底在解决什么问题?_第1张图片

2、网络退化的原因

由于反向传播算法中的链式法则,如果层层之间的梯度均在(0,1)之间,层层缩小,那么就会出现梯度消失。反之,如果层层传递的梯度大于1,那么经过层层扩大,就会出现梯度爆炸。所以,简单的堆叠层将不可避免的出现网络退化的现象。

虽然梯度消失/爆炸是网络隐藏层太深所导致的,但是在论文中,已经说了这问题主要通过标准化初始化和中间规范化层来解决。所以网络退化并不是因为梯度消失/爆炸导致的,那网络退化问题到底是由什么导致的呢?另一篇论文给出了答案:The Shattered Gradients Problem: If resnets are the answer, then what is the question?

大意是神经网络越来越深的时候,反传回来的梯度之间的相关性会越来越差,最后接近白噪声。因为我们知道图像是具备局部相关性的,那其实可以认为梯度也应该具备类似的相关性,这样更新的梯度才有意义,如果梯度接近白噪声,那梯度更新可能根本就是在做随机扰动。

3、残差网络

基于网络退化问题,论文的作者提出了残差网络的概念。一个残差块的数学模型如下图所示。残差网络和之前的网络最大的不同就是多了一条identity的捷径分支。而因为这一条分支的存在,使得网络在反向传播时,损失可以通过这条捷径将梯度直接传向更前的网络,从而减缓了网络退化的问题。

在第二节分析网络退化的原因时,我们了解到梯度之间是有相关性的。我们在有了梯度相关性这个指标之后,作者分析了一系列的结构和激活函数,发现resnet在保持梯度相关性方面很优秀(相关性衰减从 1 2 L \frac{1}{\sqrt{2^L}} 2L 1 1 L \frac{1}{\sqrt{L}} L 1了。这一点其实也很好理解,从梯度流来看,有一路梯度是保持原样不动地往回传,这部分的相关性是非常强的。

ResNet详解:ResNet到底在解决什么问题?_第2张图片
除此之外,残差网络并没有增加新的参数,只是多了一步加法。而在GPU的加速下,这一点额外的计算量几乎可以忽略不计。

不过我们可以看到,因为残差块最后是 F ( x ) + x F(x) + x F(x)+x的操作,那么意味着 F ( x ) F(x) F(x) x x x的shape必须一致。但在实际的网络搭建中,还可以利用1x1的卷积改变通道数目,其中上图左边是ResNet-34所用到的结构,右图这种类似瓶颈一样的结构就是ResNet-50/101/152所用到的结构。

ResNet详解:ResNet到底在解决什么问题?_第3张图片
而右边这样做有效减少了参数量,两者计算量对比:

  • 左边的参数量为:3x3x64x64+3x3x64x64 = 73,728
  • 右边的参数量为:1x1x256x64+3x3x64x64+1x1x64x256 = 69,632

可以看到,我们在一个残差块上就减少了2个数量级的参数,而在ResNet的一系列网络搭建过程中,是将这些结构大量堆叠起来。

4、实验结果

ResNet详解:ResNet到底在解决什么问题?_第4张图片

ResNet推荐参数如上图所示,作者还用全局平均池化替代了全连接层,一方面减少了参数量,另一方面全连接层易于过拟合并且严重依赖于 dropout 正则化,而全局平均池化本身就是起到了正则化作用,其本身防止整体结构的过拟合。此外,全局平均池汇总了空间信息,因此对输入的空间转换更加健壮。

ResNet详解:ResNet到底在解决什么问题?_第5张图片

最后的实验对比结果也是非常的明显,ResNet-34有效的减缓了梯度消失/爆炸的现象。而对于更深层网络的探索,ResNet甚至将网络层数堆叠到了1000层,虽然在工业上应用不多,但在学术理论上却是有很大的意义。

5、总结

而最后,来回答一下我们提出的问题:“ResNet到底在解决什么问题?”,我们重新来看一下Res Block的结构。
ResNet详解:ResNet到底在解决什么问题?_第6张图片
现在假设 x = 5 , H ( x ) = 5.1 x=5,H(x) = 5.1 x=5H(x)=5.1

  • 如果是非残差的结构,那么网络映射为: F ( 5 ) ′ = 5.1 F(5)' = 5.1 F(5)=5.1
  • 如果是残差结构,网络映射为: F ( 5 ) + 0.1 = 5.1 F(5) + 0.1 = 5.1 F(5)+0.1=5.1

这里的 F ′ F' F F F F都表示网络参数映射,引入残差后的映射对输出的变化更敏感。比如原来是从5.1到5.2,映射 F ′ F' F的输出增加了1/51=2%,而对于残差结构从5.1到5.2,映射 F F F是从0.1到0.2,增加了100%。明显后者输出变化对权重的调整作用更大,所以效果更好。(转自:resnet(残差网络)的F(x)究竟长什么样子?)后续的实验也是证明了假设的,残差网络比plain网络更好训练。因此,ResNet解决的是更好地训练网络的问题。

最后放上笔者用Keras和tf2实现的ResNet。

  • Keras-ResNet

你可能感兴趣的:(图像分类,人工智能,深度学习,算法)