SE-ResNet的实现

见:D:\pythonCodes\深度学习实验\4.1_经典分类网络\inference代码汇总\models\se_resnet.py

一、SE-ResNet的实现方法

读了senet这篇论文之后,可以知道senet并没有提出一个新的网络,而是提出了一个即插即用的模块。这个模块叫做SE Block(在实现的时候,为了防止与SEBasicBlock这个名字混淆,叫做SELayer)。

本文希望实现se_resnet网络,也就是将SE Block嵌入到ResNet中形成的网络。se_resnet与resnet的差别就是,就是在BasicBlock(resnet18/34使用的是BasicBlock堆叠,而resnet50/101/152使用的是Bottleneck进行堆叠,这里就以BasicBlock举例,Bottleneck完全一样)中增加了SE Block这个操作。

比如下图,上面是BasicBlock的结构,下图就是SEBasic的结构,就是多出来了一个小圈圈。

SE-ResNet的实现_第1张图片

 通过读resnet的源码,我们知道是通过Resnet()这个类来组织成整个网络的。比如:

          resnet34 = ResNet(BasicBlock, [3,4,6,3])

          resnet18 = ResNet(BasicBlock, [2,2,2,2])

          resnet50 = ResNet(Bottleneck, [3,4,6,3])

          resnet101 = ResNet(Bottleneck, [3,4,23,3])

          resnet152 = ResNet(Bottleneck, [3,8,36,3])
 

ResNet()接收两个参数,一个是block,另一个是堆叠的次数layers。只要传入参数,就能组织成一个网络了。比如传入的是BasicBlock,[3,4,6,3]就能得到resnet34了。这个函数就会自动地用3个BasicBlock组成layer1,用4个BasicBlock组成layer2,用6个BasicBlock组成layer3,用3个BasicBlock组成layer3,然后加上头尾等,组成一个网络。

我们可以利用ResNet()函数来构建我们的se_resent网络。只要给ResNet()传入SEBasicBlock和[3,4,6,3]就可以得到se_resnet34了。。

因此最关键的就是实现SEBasicBlock。而SEBasicBlock代码简直就是照抄BasicBlock代码,只要加上SELayer就行了。

(1) SELayer的实现

就是论文中的SE Block,在实现的时候,为了防止与SEBasicBlock这个名字混淆,叫做SELayer。

就是实现下面这个操作:

SE-ResNet的实现_第2张图片

代码:

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)        #全局平均池化,输入BCHW -> 输出 B*C*1*1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),   #可以看到channel得被reduction整除,否则可能出问题
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)     #得到B*C*1*1,然后转成B*C,才能送入到FC层中。
        y = self.fc(y).view(b, c, 1, 1)     #得到B*C的向量,C个值就表示C个通道的权重。把B*C变为B*C*1*1是为了与四维的x运算。
        return x * y.expand_as(x)           #先把B*C*1*1变成B*C*H*W大小,其中每个通道上的H*W个值都相等。*表示对应位置相乘。

(2)SEBasicBlock的写法

有了SELayer之后,就可以很容易地写SEBasicBlock了。就是照抄BasicBlock代码,只要加上SELayer就行了。

class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

就是多了以下两句:

SE-ResNet的实现_第3张图片

 其余的与senet中的basicBlock完全一致。SEBottleneck也是比Bottleneck多了这么一点。

二、完整代码

有了SELayer和SEBasicBlock之后就可以,借用resnet中的ResNet类来搭建SENet网络了。完整的代码如下,包含se_resnet18/34/50/101/152的实现。

# -*- coding: utf-8 -*-
"""
# @file name  : se_resnet.py
# @author     : https://github.com/moskomule/senet.pytorch
# @date       : 2020-08-07
# @brief      : se_resnet 模型搭建
"""
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from torchvision.models import ResNet



# 论文核心 SE Block, 这里称为 SE layer
class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)        #全局平均池化,输入BCHW -> 输出 B*C*1*1
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False), #可以看到channel得被reduction整除,否则可能出问题
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   #得到B*C*1*1,然后转成B*C,才能送入到FC层中。
        y = self.fc(y).view(b, c, 1, 1)   #得到B*C的向量,C个值就表示C个通道的权重。把B*C变为B*C*1*1是为了与四维的x运算。
        return x * y.expand_as(x)         #先把B*C*1*1变成B*C*H*W大小,其中每个通道上的H*W个值都相等。*表示对应位置相乘。


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        # 参数列表里的 * 星号,标志着位置参数的就此终结,之后的那些参数,都只能以关键字形式来指定。
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


def se_resnet18(num_classes=1_000):

    model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet34(num_classes=1_000):

    model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet50(num_classes=1_000, pretrained=False):

    model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    if pretrained:
        model.load_state_dict(load_state_dict_from_url(
            "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl"))
    return model


def se_resnet101(num_classes=1_000):

    model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


def se_resnet152(num_classes=1_000):

    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


if __name__ == "__main__":
    inputs = torch.randn(2, 3, 224, 224)
    model = se_resnet50(pretrained=False)
    outputs = model(inputs)
    print(outputs.size())


运行结果:

你可能感兴趣的:(11,Python/DL/ML,pytorch,深度学习,神经网络)