PyTorch code:
import torch
import torch.nn as nn


class GAM_Attention(nn.Module):
    def __init__(self, in_channels, out_channels, rate=4):
        super(GAM_Attention, self).__init__()

        # Channel attention sub-module: a two-layer MLP over the channel dimension
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )

        # Spatial attention sub-module: two 7x7 convolutions with channel reduction
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        print("Input size:", x.shape)

        # (b, c, h, w) -> (b, h*w, c) so the MLP acts on the channel dimension
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        print("Reshaped for the channel MLP:", x_permute.shape)

        # Channel attention, restored to (b, c, h, w) so it can rescale x
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        print("Channel attention map restored to the input layout:", x_channel_att.shape)

        # Apply channel attention
        x = x * x_channel_att
        print("Channel-refined feature:", x.shape)

        # Spatial attention computed on the channel-refined feature
        x_spatial_att = self.spatial_attention(x).sigmoid()
        print("Spatial attention map:", x_spatial_att.shape)

        # Element-wise product of the channel-refined feature and the spatial
        # attention map gives the final GAM output
        out = x * x_spatial_att
        print("Final GAM output:", out.shape)
        return out


if __name__ == '__main__':
    x = torch.randn(1, 64, 32, 48)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c, out_channels=c)
    y = net(x)
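Running the script prints the intermediate shapes, which makes the permute/view bookkeeping easy to verify: for the example input of shape (1, 64, 32, 48), the tensor fed to the channel MLP has shape (1, 1536, 64) (1536 = 32 * 48), while the channel attention map, the spatial attention map, and the final output all come back as (1, 64, 32, 48).

As a minimal usage sketch (the surrounding layers and channel sizes here are illustrative assumptions, not part of the original code), GAM_Attention can be inserted between convolutional stages like any other nn.Module:

import torch
import torch.nn as nn

# Hypothetical stage: conv -> BN -> ReLU, then GAM refines the features.
# Assumes GAM_Attention is defined as above.
block = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    GAM_Attention(in_channels=64, out_channels=64),
)

y = block(torch.randn(2, 3, 32, 32))
print(y.shape)  # torch.Size([2, 64, 32, 32])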
Title and authors:
Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions, by Yichao Liu, Zongru Shao, and Nico Hoffmann.
Paper address:
https://arxiv.org/pdf/2112.05561v1.pdf
Overview of GAM:
GAM consists of a channel attention sub-module followed by a spatial attention sub-module. The channel sub-module permutes the feature map and runs a two-layer MLP across the channel dimension to retain cross-dimension information; the spatial sub-module uses two 7x7 convolutions to produce a spatial attention map. The paper evaluates GAM on ImageNet-1K classification.
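As a compact summary of the computation implemented in the code above, using the paper's notation where F1 is the input feature, F2 the channel-refined feature, F3 the output, and ⊗ element-wise multiplication:

\[
\mathbf{F}_2 = M_c(\mathbf{F}_1) \otimes \mathbf{F}_1, \qquad
\mathbf{F}_3 = M_s(\mathbf{F}_2) \otimes \mathbf{F}_2
\]

Here M_c is the channel attention map produced by the MLP and M_s is the spatial attention map produced by the 7x7 convolutions, matching the two multiplications in the forward pass above.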