ResNet 论文笔记

Deep Residual Learning for Image Recognition

ResNet

论文链接: https://arxiv.org/abs/1512.03385

一、 Problem Statement

神经网络的深度很重要,但比较难训练。因此提出了一个问题: 学习更好的网络就像叠加更多层一样容易吗? 因为堆叠网络层会导致梯度的消失或者爆炸,导致在一开始就阻碍网络的收敛。这个问题很大程度可以由标准初始化和中间标准化层来解决,但是 学习退化(degradation) 的问题就出现了。 随着网络深度的增加,精度达到饱和(这可能并不奇怪),然后迅速下降。 这种degradation并不是由overfitting导致的。所以这种degradation说明了并不是所有的网络结构都是容易优化的。

二、 Direction

作者发现:随着网络深度的增加,精度达到饱和(这可能并不奇怪),然后迅速下降。这一点并不符合常理:如果存在某个 K K K层的网络 f f f是当前最优的网络,那么可以构造一个更深的网络,其最后几层仅是该网络 f f f K K K层输出的恒等映射(Identity Mapping),就可以取得与 f f f一致的结果;也许 K K K还不是所谓“最佳层数”,那么更深的网络就可以取得更好的结果。总而言之,与浅层网络相比,更深的网络的表现不应该更差。因此,一个合理的猜测就是,对神经网络来说,恒等映射并不容易拟合。 因此,提出了residual learning framework。

三、 Method

假设期望映射是 H ( x ) H(x) H(x),让多个堆叠的非线性层拟合另外一个映射 F ( x ) : = H ( x ) − x F(x):=H(x)-x F(x):=H(x)x。那原先期望的映射就是 H ( x ) : = F ( x ) + x H(x):=F(x)+x H(x):=F(x)+x。这里认为,最优化残差映射比最优化原先的映射更容易。 所以就添加了一个"shortcut connection",其包括identity mapping(恒等映射)和与Residual block的输出相加。Residual block的结构图如下:

ResNet 论文笔记_第1张图片

这个结构并不会引入额外的参数量和计算复杂度。如果所添加的网络层可以构建为恒等映射,那么更深的网络模型相对于它更小的模型来说,应该拥有较低training error。所以退化问题意味着:优化器通过多个非线性层来近似恒等映射可能存在困难。而使用了上面的residual learning framework,如果恒等映射是最优化的话,优化器会简单地使得多个非线性层的权重趋向于0来近似恒等映射。

恒等映射的公式可以定义为:

y = F ( x , { W i } ) + x y = F(x, \{W_i\}) + x y=F(x,{ Wi})+x

其中 x x x y y y是网络层的输入和输出。函数 F ( x , { W i } ) F(x,\{W_i\}) F(x,{ Wi})表示residual mapping。 根据上图显示,有两个网络层,所以可以进一步写为:
F = W 2 σ ( W 1 x ) F=W_2\sigma(W_1x) F=W2σ(W1x)
其中 σ \sigma σ是ReLU, F + x F + x F+x表示shortcut connection的element-wise addition,且 x x x F F F的维度必须一致。最后再使用一个ReLU函数。

identity mapping能够有效地解决退化问题。

上图是对于一维的,而对于卷积层来说,residual block的结构如下:
ResNet 论文笔记_第2张图片

最后来看一下ResNet的网络结构:

ResNet 论文笔记_第3张图片

四、 Conclusion

本文出发点并不是解决梯度消失和梯度爆炸的问题,而是解决网络退化的问题。梯度消失可以由Batch Normalization来解决,保证前向的梯度信号是一个非零方差值。作者发现,更深的网络可以得到更好的精度,但是随着网络加深会导致网络退化的问题,因此引入了Residual learning framework。在residual-1202网络结构中,CIFAR-10的training error 小于 0.1%,但testing error和residual-110层差不多,说明对于这个小型的数据来说,存在了过拟合,但表明了网络越深,精度肯定会好。

Reference

  1. https://zhuanlan.zhihu.com/p/80226180

你可能感兴趣的:(网络Backbone,神经网络,深度学习,计算机视觉)