不正经的理解:ResNet 残差与负反馈

本文分享一些自己在理解ResNet《Deep Residual Learning for Image Recognition》的一些心得,主要在于理解什么是学习残差,即学习输入与输出的差别。(ResNet论文下载)

作为一个电子信息工程专业的学生,初次见到ResNet,我一点也不陌生,我首先就联想到了反馈。在模电中,我们一般用负反馈来展宽频带,提高稳定性。ResNet中的反馈把输出给降低了,那个加法器前面的那些堆叠的卷积网络的输出是H(x)-x。为什么是H(x)-x,我们下面来解释。
不正经的理解:ResNet 残差与负反馈_第1张图片
对于任何一个神经网络,它的目的都是通过学习各个参数的值,找到一个函数H(x),使得y=H(x),从而具有“预测”(泛化)的功能。所有的神经网络,他都是使得输出接近H(x),包括ResNet也是这样(就是那个加法器后面的输出)。但是,和以前的AlexNet,VGG,GoogLeNet不同(如果不明白这些网络结构,可以大胆试试看我之前的博客),它加了一个shortcut的结构,就是那个长长的从输入直接连到加法器的曲线,这样呢,对一个很高层的这样一个结构来说,那个加法器后面的输出是H(x)(先不考虑那个加法器后面的relu操作),而又是谁相加呢,是那几个堆叠的卷积网络(就是么有那根曲线short cut的那几层)的输出F(x),和这一个网络小块的输入x,即H(x)=F(x)+x,那么,这些堆叠的卷积网络就是去学习F(x)=H(x)-x。这就是为什么叫做残差。

解释了这个残差结构之后,我们再分析一下为什么这样的网络能更深更稳定。首先我们知道一个这样的事实,就是就算很深的网络,其中起主要作用的就是中间的一些层,越往后,网络对于结果的贡献就越来越小了, 那么对于一个很高层的网络而言,它的输入x (注意,不是说的是整个ResNet的输入,就单单是这个小模块(a building block)的输入),应当就和他的输出H(x),随着训练的进行,二者会越来越接近,因为高层的网络对结果的输出基本没什么影响了。那么,中间的堆叠的卷积网络学习的 F(x)=H(x)-x,是不是就越来越接近0了,也就是残差越来越小了。 那么为什么这样会更容易使网络的层数加深呢?因为这种网络的方式是类似于去学习一种恒等映射,恒等映射就可以传的深一些,另外这种网络的训练就是推动这些堆叠的卷积网络去学习使得F(x)=0, 显然使参数接近零就可以了,这比推动 F(x)=H(x)更简单,也就使得网络能够更深。

这就是论文中一直提到的恒等映射,这样在深层,每一个这样的shortcut结构,他的输入x和输出H(x)近似相等,也就是所谓的恒等映射。看到这里,我么自然会有一个问题,既然是恒等,那么为什么要加这么几层呢,有什么意义呢,不加不是更好吗?我觉得是这样的,这就像是一个负反馈的网络结构,除了在最高的一些层外,他在更低的层使用这些结构是使得网络更加稳定。

我认为可以用负反馈的观点来解释为什么更加稳定,在不是高层的那些层,他们总是使得堆叠层的输出F(x)=H(x)-x 向着0的方向前进,假如这个小模块的输入x 偏离H(x)严重,就会被拉回来,使其拉回到接近 H(x) 的范围来,这就是负反馈,因而,整个网络结构能够非常稳定,而且更快速收敛。因而,所以不如整个网络都大量使用这样的building block。

为了减少计算量,文中还是用到了1x1卷积来降维减少计算量,如下:
不正经的理解:ResNet 残差与负反馈_第2张图片
再用PyTorch实现一下这个结构,以更深入地理解:

###############  pytorch 实现 ResNet  #########################################
import torch
from torch import nn
def conv3x3(in_planes,out_planes,stride=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)

class BasicBlock(nn.Module):
    def __init__(self,inplanes,planes,stride):
        super(BasicBlock,self).__init__()
        self.conv1=conv3x3(inplanes,planes,stride)
        self.bn1=nn.BatchNorm2d(planes)
        self.relu=nn.ReLU(inplanes=True)
        self.conv2=conv3x3(planes,planes)
        self.bn2=nn.BatchNorm2d(planes)
        self.downsample=downsample
        self.stride=stride
    
    def forward(self,x):
        residual=x
        
        out=self.conv1(x)
        out=self.bn(out)
        out=self.relu(out)

        out=self.conv2(out)
        out=self.bn2(out)
        
        if self.downsample is not None:
            residual=self.downsample(x)
            
        out += residual
        out=self.relu(out)
        
        return out

你可能感兴趣的:(我的读书笔记)