经典分类网络 ResNet 论文阅读及PYTORCH示例代码

上一篇说要尝试一下用 se_ResNeXt 来给 WS-DAN 网络提取特征,在此之前需要先搞懂 ResNeXt 的原理,而 ResNeXt 则是在 ResNet 基础上的改进,所以绕了一大圈,还得从 ResNet 开始。说来惭愧,之前只是用过 ResNet 来做分类任务,论文还真没有仔细读过,正好趁这个机会读一读这篇“神作”。

论文地址: https://arxiv.org/pdf/1512.03385.pdf

论文阅读

其实论文的思想在今天看来是不难的,不过在当时 ResNet 提出的时候可是横扫了各大分类任务,这个网络解决了随着网络的加深,分类的准确率不升反降的问题。通过一个名叫“残差”的网络结构(如下图所示),使作者可以只通过简单的网络深度堆叠便可达到提升准确率的目的。

经典分类网络 ResNet 论文阅读及PYTORCH示例代码_第1张图片
残差结构

残差结构的处理过程分成两个部分,左边的 与右边的 ,最后结果为两者相加。其中右边那根线不会对 做任何处理,所以没有可学习的参数; 为网络中 负责学习特征的部分,把整个残差结构看做是一个 函数的话,则 负责学习的部分可以表示为 ,这个结构学习的其实是 输出结果与输入的差值,这也是残差名字的由来。完整的 ResNet 网络由多个上图中所示的残差结构组成,每个结构学习的都是 输出与输入之间的差值,通过步步逼近,达到了比直接学习输入好得多的效果。

文中残差结构的具体实现分为两种,首先介绍 ResNet-18 与 ResNet-34 使用的残差结构称为 Basic Block,如下图所示,图中的结构包含了两个卷积操作用于提取特征。

经典分类网络 ResNet 论文阅读及PYTORCH示例代码_第2张图片
Basic Block

对应到代码中,这是 Pytorch 自带的 ResNet 实现中的一部分,跟上图对应起来看更加好理解,我个人比较喜欢论文与代码结合起来看,因为我除了需要知道原理之外,也要知道如何去使用,而代码更给我一种一目了然的感觉:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

另一种残差结构称为 Bottleneck,就是瓶颈的意思:

经典分类网络 ResNet 论文阅读及PYTORCH示例代码_第3张图片
瓶颈

作者起名字真的很形象,网络结构也正如这瓶颈一样, 首先做一个降维,然后做卷积,然后升维,这样做的好处是可以大大减少计算量,专门用于网络层数较深的的网络,ResNet-50 以上的网络都有这种基础结构构成(不同层级的输入输出维度可能会不一样,但结构类似):
经典分类网络 ResNet 论文阅读及PYTORCH示例代码_第4张图片
Bottleneck

Pytorch 中的代码,注意到上图中为了减少计算量,作者将 256 维的输入缩小了 4 倍变为 64 进入卷积,在升维时需要升到 256 维,对应代码中的 expansion 参数:

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

由上面介绍的基本结构再加上池化以及全连接层,就构成了各种完整的网络:


经典分类网络 ResNet 论文阅读及PYTORCH示例代码_第5张图片
各网络结构

图中的网络在 Pytorch 中都已经集成进去了,而且都是预训练好的,我们可以在预训练好的模型上面训练自己的分类器,大大减少我们的训练时间。下面简单介绍一下如何使用 ResNet。

在 Pytorch 中使用 ResNet

Pytorch 是一个对初学者很友好的深度学习框架,入门的话非常推荐,官方提供了一小时入门教程:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html
在 Pytorch 中使用 ResNet 只需要 4 行代码:

from torch import nn
# torchvision 专用于视觉方面
import torchvision 
  
# pretrained :使用在 ImageNet 数据集上预训练的模型
model = torchvision.models.resnet18(pretrained=True)
# 修改模型的全连接层使其输出为你需要类型数,这里是10
# 由于使用了预训练的模型 而预训练的模型输出为1000类,所以要修改全连接层
# 若不使用预训练的模型可以直接在创建模型时添加参数 num_classes=10 而不需要修改全连接层
model.fc = nn.Linear(model.fc.in_features, 10)

下面你就可以使用这个模型来做分类了,当然到这里还没在自己的数据集上进行训练,关于如何训练可以参考官方教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
如果对代码以及源码有疑问的话可以在下面留言我们一起讨论。

最后,求赞求关注,欢迎关注我的微信公众号[MachineLearning学习之路] ,深度学习 & CV 方向的童鞋不要错过!!

你可能感兴趣的:(经典分类网络 ResNet 论文阅读及PYTORCH示例代码)