GAM_attention

A PyTorch implementation of the GAM (Global Attention Mechanism) attention module. Channel attention is computed by a small MLP applied to the feature map permuted to (B, H*W, C); spatial attention is computed by two 7x7 convolutions with a channel-reduction bottleneck. Each attention map re-weights the features by element-wise multiplication.
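For reference, the two-step re-weighting that the forward pass below implements can be written as in the GAM paper ("Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions"); the symbols here follow that paper, not the code:

$$F_2 = M_c(F_1) \otimes F_1, \qquad F_3 = M_s(F_2) \otimes F_2$$

where $M_c$ is the channel-attention MLP, $M_s$ the spatial-attention convolution branch, and $\otimes$ element-wise multiplication. Note that the listing below applies a sigmoid only on the spatial branch.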

import torch
import torch.nn as nn
 
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels, out_channels, rate=4):
        super(GAM_Attention, self).__init__()

        # Channel attention: an MLP with a bottleneck of reduction ratio `rate`,
        # applied per spatial position along the channel dimension.
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )

        # Spatial attention: two 7x7 convolutions with a channel bottleneck,
        # each followed by batch normalization.
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(out_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape

        # Channel attention: (b, c, h, w) -> (b, h*w, c) so the MLP acts on channels.
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  # back to (b, c, h, w)

        x = x * x_channel_att

        # Spatial attention: (b, c, h, w) -> (b, out_channels, h, w), gated by a sigmoid.
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att

        return out
 
 
if __name__ == '__main__':
    x = torch.randn(1, 4, 128, 128)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c, out_channels=c)
    y = net(x)
    print(y.shape)  # torch.Size([1, 4, 128, 128]); output shape matches the input
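Because the module preserves the spatial resolution and (with out_channels == in_channels) the channel count, it can be dropped between existing convolution stages. Below is a minimal usage sketch; the ToyBlock wrapper and its layer sizes are illustrative assumptions, not part of the original code.

# Minimal usage sketch (assumed wrapper, for illustration only):
# a small conv block that applies GAM_Attention between two 3x3 convolutions.
class ToyBlock(nn.Module):
    def __init__(self, channels=32):
        super(ToyBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.att = GAM_Attention(in_channels=channels, out_channels=channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv1(x).relu()
        x = self.att(x)  # channel + spatial re-weighting, shape unchanged
        return self.conv2(x).relu()

# block = ToyBlock(channels=32)
# out = block(torch.randn(2, 32, 64, 64))  # -> torch.Size([2, 32, 64, 64])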
