[Deep Learning] Attention Mechanisms -- Mixed-Domain CBAM -- Code Implementation

[Figure 1: CBAM schematic]

The mixed domain consists of two parts applied in sequence: the channel domain (Channel Attention Module) comes first, followed by the spatial domain (Spatial Attention Module).
The detailed structure of the two modules is shown in the figure below:
[Figure 2: structure of the Channel Attention Module and the Spatial Attention Module]
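For reference, the original CBAM paper (Woo et al., "CBAM: Convolutional Block Attention Module", ECCV 2018) writes the two stages as follows, where sigma is the sigmoid function, [.;.] is channel-wise concatenation, and * is element-wise multiplication with broadcasting:

M_c(F) = sigma( MLP(AvgPool(F)) + MLP(MaxPool(F)) )    # channel attention map, shape (N, C, 1, 1)
M_s(F) = sigma( Conv7x7( [AvgPool(F); MaxPool(F)] ) )  # spatial attention map, shape (N, 1, H, W)
F'  = M_c(F) * F
F'' = M_s(F') * F'

The code below implements these two stages directly.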

Code implementation

# This is the CBAM block code
import torch
from torch import nn

## Define the channel attention class
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP implemented with 1x1 convolutions: squeeze the channel
        # dimension by `ratio`, then restore it
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Feed both pooled descriptors through the same shared MLP
        avg_out = self.shared_MLP(self.avg_pool(x))
        max_out = self.shared_MLP(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)
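# Note: the channel attention map has shape (N, C, 1, 1), so multiplying it
# with the input broadcasts one weight over each channel. Illustrative check
# (the variable name `ca` is only for demonstration):
#   ca = ChannelAttention(32)
#   ca(torch.randn(16, 32, 20, 20)).shape  # -> torch.Size([16, 32, 1, 1])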

## Define the spatial attention class
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool along the channel dimension, giving two (N, 1, H, W) maps
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)  # concatenate to (N, 2, H, W)
        x = self.conv1(x)
        return self.sigmoid(x)
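# Note: the spatial attention map has shape (N, 1, H, W), i.e. one weight per
# spatial location, shared by all channels. Illustrative check:
#   sa = SpatialAttention()
#   sa(torch.randn(16, 32, 20, 20)).shape  # -> torch.Size([16, 1, 20, 20])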

## CBAM chains the channel module and the spatial module to form the mixed domain
class CBAM(nn.Module):
    def __init__(self, planes):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()

    def forward(self, x):
        x = self.ca(x) * x  # apply channel attention: weight each channel
        x = self.sa(x) * x  # apply spatial attention: weight each spatial location
        return x

if __name__ == '__main__':
    # Assume the input is batchsize=16, channels=32, feature map size 20x20
    img = torch.randn(16, 32, 20, 20)
    net = CBAM(32)
    print(net)
    out = net(img)
    print(out.size())  # prints torch.Size([16, 32, 20, 20])
Output:
CBAM(
  (ca): ChannelAttention(
    (avg_pool): AdaptiveAvgPool2d(output_size=1)
    (max_pool): AdaptiveMaxPool2d(output_size=1)
    (shared_MLP): Sequential(
      (0): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): ReLU()
      (2): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )
    (sigmoid): Sigmoid()
  )
  (sa): SpatialAttention(
    (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (sigmoid): Sigmoid()
  )
)
torch.Size([16, 32, 20, 20])
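
The two attention maps can also be inspected on their own. The lines below continue the `__main__` example (reusing its `net` and `img`); the exact values vary with the random initialization, but the shapes, and the (0, 1) range enforced by the Sigmoid, do not:

ca_map = net.ca(img)   # channel weights, torch.Size([16, 32, 1, 1])
sa_map = net.sa(img)   # spatial weights, torch.Size([16, 1, 20, 20])
print(ca_map.shape, sa_map.shape)
print(float(ca_map.min()), float(ca_map.max()))  # both lie inside (0, 1)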

In essence, the mixed domain still reweights the input feature map: the channel module assigns one weight per channel, and the spatial module then assigns one weight per spatial location.
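
In practice CBAM is rarely used on its own: the paper inserts it into every residual block of a ResNet. The sketch below shows one way to do this with the `CBAM` class defined above; the class name `BasicBlockCBAM` and the exact layer layout are illustrative assumptions, not part of the original post.

class BasicBlockCBAM(nn.Module):
    def __init__(self, planes):
        super(BasicBlockCBAM, self).__init__()
        self.conv1 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(planes)  # refine the block output before the residual add

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)       # channel attention, then spatial attention
        return self.relu(out + x)  # residual connection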
