CBAM注意力机制——pytorch实现

论文传送门:CBAM: Convolutional Block Attention Module

CBAM的目的:

为网络添加注意力机制

CBAM的结构:

通道注意力机制(Channel attention module):输入特征分别经过全局最大池化和全局平均池化,池化结果经过一个权值共享的MLP,得到的权重相加,最后经过sigmoid激活函数得到通道注意力权重 M c M_c Mc
空间注意力机制(Spatial attention module):输入特征在通道维度上分别进行最大池化和平均池化,得到(2,H,W)的特征层,经过7x7的卷积,输出单通道特征层,最后经过sigmoid激活函数得到空间注意力权重 M s M_s Ms
CBAM注意力机制——pytorch实现_第1张图片
二者串联:作者将二者串联搭建,且通道注意力模块在前,空间注意力模块在后。
CBAM注意力机制——pytorch实现_第2张图片
经过实验,作者发现:串联搭建比并联搭建效果好,先进行通道注意力比先空间注意力效果好。
CBAM注意力机制——pytorch实现_第3张图片

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):  # Channel attention module
    def __init__(self, channels, ratio=16):  # r: reduction ratio=16
        super(ChannelAttention, self).__init__()

        hidden_channels = channels // ratio
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # global avg pool
        self.maxpool = nn.AdaptiveMaxPool2d(1)  # global max pool
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden_channels, 1, 1, 0, bias=False),  # 1x1conv代替全连接,根据原文公式没有偏置项
            nn.ReLU(inplace=True),  # relu
            nn.Conv2d(hidden_channels, channels, 1, 1, 0, bias=False)  # 1x1conv代替全连接,根据原文公式没有偏置项
        )
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = self.avgpool(x)
        x_max = self.maxpool(x)
        return self.sigmoid(
            self.mlp(x_avg) + self.mlp(x_max)
        )  # Mc(F) = σ(MLP(AvgPool(F))+MLP(MaxPool(F)))= σ(W1(W0(Fcavg))+W1(W0(Fcmax))),对应原文公式(2)


class SpatialAttention(nn.Module):  # Spatial attention module
    def __init__(self):
        super(SpatialAttention, self).__init__()

        self.conv = nn.Conv2d(2, 1, 7, 1, 3, bias=False)  # 7x7conv
        self.sigmoid = nn.Sigmoid()  # sigmoid

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)  # 在通道维度上进行avgpool,(B,C,H,W)->(B,1,H,W)
        x_max = torch.max(x, dim=1, keepdim=True)[0]  # 在通道维度上进行maxpool,(B,C,H,W)->(B,1,H,W)
        return self.sigmoid(
            self.conv(torch.cat([x_avg, x_max],dim=1))
        )  # Ms(F) = σ(f7×7([AvgP ool(F);MaxPool(F)])) = σ(f7×7([Fsavg;Fsmax])),对应原文公式(3)


class CBAM(nn.Module):  # Convolutional Block Attention Module
    def __init__(self, channels, ratio=16):
        super(CBAM, self).__init__()

        self.channel_attention = ChannelAttention(channels, ratio)  # Channel attention module
        self.spatial_attention = SpatialAttention()  # Spatial attention module

    def forward(self, x):
        f1 = self.channel_attention(x) * x  # F0 = Mc(F)⊗F,对应原文公式(1)
        f2 = self.spatial_attention(f1) * f1  # F00 = Ms(F0)⊗F0,对应原文公式(1)
        return f2

你可能感兴趣的:(pytorch,深度学习,人工智能)