网络中加入注意力机制SE模块

        SENet是由自动驾驶公司Momenta在2017年公布的一种全新的图像识别结构,它通过对特征通道间的相关性进行建模,把重要的特征进行强化来提升准确率。SENet 是2017 ILSVR竞赛的冠军。

论文:Squeeze-and-Excitation Networks

网络中加入注意力机制SE模块_第1张图片

SE block的基本结构 

  1. 给定一个输入 ,其特征通道数为C ,通过一系列卷积等一般变换后得到一个特征通道数为C的特征。
  2. Squeeze:顺着空间维度进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。
  3. Excitation:基于特征通道间的相关性,每个特征通道生成一个权重,用来代表特征通道的重要程度。
  4. Reweight:将Excitation输出的权重看做每个特征通道的重要性,然后通过乘法逐通道加权到之前的特征上,完成在通道维度上的对原始特征的重标定。

代码:

import torch
import torch.nn as nn
import math
from torchvision import models
class se_block(nn.Module):
    def __init__(self, channel, ratio=16):
        super(se_block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // ratio, bias=False),
                nn.ReLU(inplace=True),
                nn.Linear(channel // ratio, channel, bias=False),
                nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class Mobilenet_v2(nn.Module):
    def __init__(self):
        super(Mobilenet_v2, self).__init__()
        model = models.mobilenet_v2(pretrained=True)
        # Remove linear and pool layers (since we're not doing classification)
        modules = list(model.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.pool = nn.AvgPool2d(kernel_size=7)
        self.fc = nn.Linear(1280, 16)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)
        self.attention = se_block(1280) # 1280 为上层输出通道

    def forward(self, images):
        x = self.resnet(images)  # [N, 1280, 1, 1]
        x=self.attention(x)  # 此处加入se—block
        x = self.pool(x)
        x = x.view(-1, 1280)  # [N, 1280]
        x = self.fc(x)
        return x


if __name__=="__main__":
    input = torch.rand(2, 3, 224, 224)
    mode = Mobilenet_v2()
    out = mode(input)
    print(out.size())

小结: 

        1、SE网络可以通过堆叠SE模块得到。

        2、SE模块也可以嵌入到现在几乎所有的网络结构中。

你可能感兴趣的:(模型设计,网络,深度学习,人工智能)