混合域由两部分组成:前半部分是通道域(Channel Attention Module),后半部分是空间域(Spatial Attention Module)。
其具体结构如下图所示:
# This is CBAM block code
import torch
from torch import nn
from torch.nn import functional as F
## 定义通道域类
class ChannelAttention(nn.Module):
    """Channel attention branch of CBAM.

    The input is reduced to a 1x1 spatial size twice, once with adaptive
    average pooling and once with adaptive max pooling. Both pooled tensors go
    through the same two-layer 1x1-convolution bottleneck, the two results are
    summed, and a sigmoid turns the sum into per-channel weights.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Reduction ratio of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1 so that inputs with
            fewer channels than ``ratio`` still get a usable bottleneck.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Without the clamp, in_planes < ratio gives a zero-channel hidden
        # layer and the block cannot produce meaningful weights.
        hidden_planes = max(1, in_planes // ratio)

        # One MLP shared by both pooled descriptors (weights are tied).
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel weights for ``x``.

        For an ``(N, C, H, W)`` input the result has shape ``(N, C, 1, 1)``,
        so it broadcasts over the spatial dimensions when multiplied with
        ``x``.
        """
        avg_out = self.shared_MLP(self.avg_pool(x))
        max_out = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
## 定义空间域类
class SpatialAttention(nn.Module):
    """Spatial attention branch of CBAM.

    The input is collapsed along the channel dimension with both a mean and a
    max, the two resulting maps are concatenated, and a single convolution
    followed by a sigmoid turns them into a one-channel weight map.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``padding = kernel_size // 2`` keeps
            the spatial size unchanged (3 -> padding 1, 7 -> padding 3).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O`` and a bad size would then get a wrong padding.
        if (
            not isinstance(kernel_size, int)
            or kernel_size < 1
            or kernel_size % 2 == 0
        ):
            raise ValueError('kernel size must be a positive odd integer')

        padding = kernel_size // 2
        # Two input channels: the channel-wise mean map and max map.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return spatial weights for ``x``.

        For an ``(N, C, H, W)`` input the result has shape ``(N, 1, H, W)``,
        so it broadcasts over the channel dimension when multiplied with
        ``x``.
        """
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        pooled = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(pooled))
## CBAM将通道域与空间域串联起来形成混合域
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel then spatial attention.

    Args:
        planes: Number of channels of the input feature map.
        ratio: Reduction ratio forwarded to ``ChannelAttention``.
        kernel_size: Convolution kernel size forwarded to
            ``SpatialAttention``.
    """

    def __init__(self, planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(planes, ratio=ratio)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        """Apply both attention stages in sequence and return the result.

        The output is produced by two element-wise multiplications with the
        input, so it keeps the input's shape.
        """
        # Channel attention: re-weight each channel of the input.
        x = self.ca(x) * x
        # Spatial attention: re-weight each spatial position of the
        # channel-refined features.
        x = self.sa(x) * x
        return x
if __name__ == '__main__':
    # Smoke test: batch size 16, 32 channels, 20x20 feature maps.
    batch_size, channels, height, width = 16, 32, 20, 20
    img = torch.randn(batch_size, channels, height, width)

    net = CBAM(channels)
    print(net)

    out = net(img)
    # Expected: torch.Size([16, 32, 20, 20]), the same shape as the input.
    print(out.size())
output>>>
CBAM(
(ca): ChannelAttention(
(avg_pool): AdaptiveAvgPool2d(output_size=1)
(max_pool): AdaptiveMaxPool2d(output_size=1)
(shared_MLP): Sequential(
(0): Conv2d(32, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): ReLU()
(2): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
(sigmoid): Sigmoid()
)
(sa): SpatialAttention(
(conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
(sigmoid): Sigmoid()
)
)
torch.Size([16, 32, 20, 20])
混合域的本质是先为输入特征的各个通道赋予权重,再为各个空间位置赋予权重。