Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions (GAM)

PyTorch code:

import torch
import torch.nn as nn


class GAM_Attention(nn.Module):  
    def __init__(self, in_channels, out_channels, rate=4):  
        super(GAM_Attention, self).__init__()  

        self.channel_attention = nn.Sequential(  
            nn.Linear(in_channels, int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Linear(int(in_channels / rate), in_channels)  # channel attention: two-layer MLP with reduction ratio `rate`
        )  
      
        self.spatial_attention = nn.Sequential(  
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),  
            nn.BatchNorm2d(int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),  # spatial attention: two 7x7 convolutions
            nn.BatchNorm2d(out_channels)  
        )  

    def forward(self, x):  
        b, c, h, w = x.shape
        print("Input size:", x.shape)

        # Channel attention submodule: flatten the spatial dims so the
        # two-layer MLP operates on the channel dimension.
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)  # (b, h*w, c)
        print("After permute/view for the channel MLP:", x_permute.shape)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)  # (b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  # back to (b, c, h, w)
        print("Channel attention map restored to (b, c, h, w):", x_channel_att.shape)

        x = x * x_channel_att  # apply channel attention
        print("After applying channel attention:", x.shape)

        x_spatial_att = self.spatial_attention(x).sigmoid()
        print("Spatial attention map:", x_spatial_att.shape)
        out = x * x_spatial_att  # element-wise product yields the final GAM output
        print("Final GAM output:", out.shape)

        return out

  

if __name__ == '__main__':  
    x = torch.randn(1, 64, 32, 48)  
    b, c, h, w = x.shape  
    net = GAM_Attention(in_channels=c, out_channels=c)  
    y = net(x)  

Code output:
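The original post showed the console output as a screenshot; reconstructed from the script above, it reads:

Input size: torch.Size([1, 64, 32, 48])
After permute/view for the channel MLP: torch.Size([1, 1536, 64])
Channel attention map restored to (b, c, h, w): torch.Size([1, 64, 32, 48])
After applying channel attention: torch.Size([1, 64, 32, 48])
Spatial attention map: torch.Size([1, 64, 32, 48])
Final GAM output: torch.Size([1, 64, 32, 48])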

Title and authors:
[Image: paper title page: "Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions", by Yichao Liu, Zongru Shao, and Nico Hoffmann]
Paper address:
https://arxiv.org/pdf/2112.05561v1.pdf

Overview of GAM:
[Image: overview of GAM: the input feature passes through the channel attention submodule, then the spatial attention submodule]
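In the paper's notation, given an input feature map F1, the channel submodule Mc and the spatial submodule Ms are applied in sequence, with ⊗ denoting element-wise multiplication:

F2 = Mc(F1) ⊗ F1
F3 = Ms(F2) ⊗ F2

This is exactly what forward above computes: x = x * x_channel_att followed by out = x * x_spatial_att.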
Channel and spatial attention submodules:
[Image: the channel attention submodule (3D permutation followed by a two-layer MLP) and the spatial attention submodule (two 7x7 convolutions)]
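The channel submodule retains information by permuting dimensions instead of pooling them away. A minimal sanity check of that permute/view round trip (hypothetical shapes, mirroring the forward pass above):

import torch

x = torch.randn(2, 8, 4, 4)                           # (b, c, h, w)
b, c, h, w = x.shape
x_perm = x.permute(0, 2, 3, 1).view(b, -1, c)         # (b, h*w, c): the Linear layers act on c
assert x_perm.shape == (2, 16, 8)
x_back = x_perm.view(b, h, w, c).permute(0, 3, 1, 2)  # restore (b, c, h, w)
assert torch.equal(x_back, x)                         # the round trip is lossless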
Experiments on ImageNet-1K:
[Image: ImageNet-1K classification results comparing GAM with other attention mechanisms]
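To plug GAM into a network, instantiate it with in_channels equal to the number of channels of the feature map it receives. A minimal usage sketch (the ConvBlockWithGAM wrapper below is hypothetical, not from the paper; it assumes the GAM_Attention class defined above):

import torch
import torch.nn as nn

class ConvBlockWithGAM(nn.Module):
    # Hypothetical wrapper: conv -> BN -> ReLU -> GAM, for illustration only.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
        self.att = GAM_Attention(in_channels=out_ch, out_channels=out_ch)

    def forward(self, x):
        return self.att(self.act(self.bn(self.conv(x))))

block = ConvBlockWithGAM(3, 64)
y = block(torch.randn(1, 3, 32, 48))
print(y.shape)  # torch.Size([1, 64, 32, 48])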
