21 Plug-and-Play Attention Mechanisms

Note: if you spot any mistakes, please point them out so they can be corrected.

Abstract

With the rapid development of deep learning, and of natural language processing in particular, the attention mechanism has become a key technique for improving model performance. This article records the function, source, and core code of a set of plug-and-play attention modules.
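
Unless noted otherwise, the PyTorch snippets below assume the following imports are in scope (the single TensorFlow example in section 11 additionally needs import tensorflow as tf):

# Common imports assumed by the PyTorch snippets
import math

import torch
import torch.nn as nn
import torch.nn.functional as F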

1、SE Block (Squeeze-and-Excitation)

Function: adaptively learns per-channel weights to emphasize informative channel features.
Source: SENet (Squeeze-and-Excitation Networks, Hu et al., 2018)

# SE Block (PyTorch)
class SEBlock(nn.Module):
    def __init__(self, channel, ratio=16):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel//ratio, bias=False),
            nn.ReLU(),
            nn.Linear(channel//ratio, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.gap(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
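
As a quick plug-and-play check (a minimal sketch; the sizes are arbitrary), the block can be dropped after any convolutional stage with a matching channel count:

# Hypothetical usage: re-weight the channels of a 64-channel feature map
se = SEBlock(channel=64, ratio=16)
feat = torch.randn(2, 64, 32, 32)   # (B, C, H, W)
out = se(feat)                      # same shape, channels re-scaled by weights in [0, 1]
print(out.shape)                    # torch.Size([2, 64, 32, 32])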

2、CBAM (Convolutional Block Attention Module)

Function: combines channel and spatial attention to improve feature selectivity.
Source: CBAM (Woo et al., 2018)

# CBAM (PyTorch)
class CBAMBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Channel Attention
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel//reduction),
            nn.ReLU(),
            nn.Linear(channel//reduction, channel)
        )
        # Spatial Attention
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
    def forward(self, x):
        # Channel: sigmoid over the sum of the avg- and max-pooled MLP outputs
        avg_out = self.fc(self.avg_pool(x).flatten(1))
        max_out = self.fc(self.max_pool(x).flatten(1))
        channel_out = torch.sigmoid(avg_out + max_out).unsqueeze(-1).unsqueeze(-1)
        x = x * channel_out
        # Spatial
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_out = torch.cat([avg_out, max_out], dim=1)
        spatial_out = self.conv(spatial_out).sigmoid()
        return x * spatial_out

3、ECA-Net (Efficient Channel Attention)

Function: lightweight channel attention that replaces the fully connected layers with a 1D convolution whose kernel size adapts to the channel count.
Source: ECA-Net (Wang et al., 2020)

# ECA Block (PyTorch)
class ECABlock(nn.Module):
    def __init__(self, channel, b=1, gamma=2):
        super().__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma)) | 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.conv(y.unsqueeze(1)).squeeze(1)
        y = self.sigmoid(y).view(b, c, 1, 1)
        return x * y
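
For example, with channel = 256 the adaptive rule above gives kernel_size = int(|(log2 256 + 1) / 2|) | 1 = int(4.5) | 1 = 5, so layers with more channels automatically interact over a wider 1D neighborhood.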

4、Scaled Dot-Product Self-Attention

Function: standard self-attention that models global dependencies across all positions.
Source: Attention Is All You Need (Vaswani et al., 2017)

# Self-Attention (PyTorch)
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1)**0.5)
        attn = self.softmax(scores)
        return torch.matmul(attn, V)

# Conv2d-based variant for 2D feature maps (B, C, H, W); the affinity here is
# computed between channels rather than between sequence positions.
class SelfAttention2d(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        query = self.query(x).view(x.size(0), x.size(1), -1)
        key = self.key(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        value = self.value(x).view(x.size(0), x.size(1), -1)
        attention = F.softmax(torch.bmm(query, key), dim=-1)
        out = torch.bmm(attention, value)
        out = out.view(x.size(0), x.size(1), x.size(2), x.size(3))
        return self.out(out)
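
The two variants expect different input layouts; a minimal sketch (shapes chosen arbitrarily):

# Hypothetical shapes: token sequence for SelfAttention, feature map for SelfAttention2d
seq = torch.randn(2, 16, 64)                          # (B, T, embed_size)
print(SelfAttention(embed_size=64)(seq).shape)        # torch.Size([2, 16, 64])
fmap = torch.randn(2, 64, 32, 32)                     # (B, C, H, W)
print(SelfAttention2d(in_channels=64)(fmap).shape)    # torch.Size([2, 64, 32, 32])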

5、Multi-Head Attention

Function: multiple heads attend to information from different representation subspaces in parallel.
Source: Transformer (Vaswani et al., 2017)

# Multi-Head Attention (PyTorch)
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, embed_size):
        super().__init__()
        self.heads = heads
        self.head_dim = embed_size // heads
        self.embed_size = embed_size
        self.Wq = nn.Linear(embed_size, embed_size)
        self.Wk = nn.Linear(embed_size, embed_size)
        self.Wv = nn.Linear(embed_size, embed_size)
        self.fc = nn.Linear(embed_size, embed_size)
        
    def forward(self, x):
        batch = x.size(0)
        Q = self.Wq(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
        K = self.Wk(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
        V = self.Wv(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.embed_size)
        return self.fc(out)
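
A quick shape check (arbitrary sizes; embed_size must be divisible by heads):

# Hypothetical usage: 8 heads over 128-dim embeddings
mha = MultiHeadAttention(heads=8, embed_size=128)
tokens = torch.randn(2, 20, 128)   # (B, T, embed_size)
print(mha(tokens).shape)           # torch.Size([2, 20, 128])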

6、CoTAttention (Contextual Transformer)

Function: fuses static local context (from a grouped convolution over the keys) with dynamic attention.
Source: Contextual Transformer Networks for Visual Recognition (Li et al., 2021)

# CoTAttention (PyTorch)
class CoTAttention(nn.Module):
    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Conv2d(dim, dim, 1, bias=False)
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2*dim, 2*dim//4, 1, bias=False),
            nn.BatchNorm2d(2*dim//4),
            nn.ReLU(),
            nn.Conv2d(2*dim//4, kernel_size**2 * dim, 1)
        )
    
    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)
        v = self.value_embed(x).view(bs, c, -1)
        y = torch.cat([k1, x], dim=1)
        att = self.attention_embed(y)
        att = att.reshape(bs, c, -1, h, w)
        att = att.mean(2).view(bs, c, -1)
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)
        return k1 + k2

7、Efficient Self-Attention (AFT)

Function: attention-free transformer layer that replaces dot-product attention with element-wise operations and learned pairwise position biases; its local/simple variants scale linearly with sequence length.
Source: AFT (An Attention Free Transformer, Zhai et al., 2021)

# AFT-Full (PyTorch) -- simplified sketch of the Attention Free Transformer
class AFTFull(nn.Module):
    def __init__(self, d_model, n=49):
        super().__init__()
        # n is the maximum sequence length covered by the learned position biases
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.w = nn.Parameter(torch.zeros(n, n))   # pairwise position biases w_{t,t'}

    def forward(self, x):
        B, T, _ = x.shape
        q = torch.sigmoid(self.to_q(x))                    # (B, T, C)
        k = self.to_k(x)
        v = self.to_v(x)
        w = self.w[:T, :T].unsqueeze(0).unsqueeze(-1)      # (1, T, T, 1)
        weights = torch.exp(w + k.unsqueeze(1))            # (B, T, T, C)
        y = (weights * v.unsqueeze(1)).sum(dim=2) / weights.sum(dim=2)
        return q * y                                       # (B, T, C)
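
A minimal shape check (batch size and sequence length chosen arbitrarily; the sequence length must not exceed n):

# Hypothetical usage: sequences of length 49 with 64-dim embeddings
aft = AFTFull(d_model=64, n=49)
tokens = torch.randn(2, 49, 64)    # (B, T, d_model)
print(aft(tokens).shape)           # torch.Size([2, 49, 64])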

8、Triplet Attention

Function: cross-dimension interaction that captures channel-spatial dependencies along three rotated branches.
Source: Triplet Attention (Misra et al., 2021)

# Triplet Attention (PyTorch)
class SpatialGate(nn.Module):
    """Z-pool (channel max + mean) -> 7x7 conv -> sigmoid gate."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        pooled = torch.cat([x.max(1, keepdim=True)[0], x.mean(1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.bn(self.conv(pooled)))

class TripletAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.cw = SpatialGate()   # (C, W) interaction: H rotated into the channel axis
        self.hc = SpatialGate()   # (H, C) interaction: W rotated into the channel axis
        self.hw = SpatialGate()   # plain spatial branch over (H, W)

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()                 # (B, H, C, W)
        x_out1 = self.cw(x_perm1).permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()                 # (B, W, H, C)
        x_out2 = self.hc(x_perm2).permute(0, 3, 2, 1).contiguous()
        x_out3 = self.hw(x)
        return (x_out1 + x_out2 + x_out3) / 3.0

9、 Criss-Cross Attention

Function: criss-cross attention that aggregates full-image context along each position's row and column.
Source: CCNet (Huang et al., 2019)

# Criss-Cross Attention (PyTorch)
def INF(B, H, W, device):
    # -inf on the diagonal so each position is not counted twice (row and column overlap)
    return -torch.diag(torch.full((H,), float('inf'), device=device)).unsqueeze(0).repeat(B * W, 1, 1)

class CrissCrossAttention(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, _, H, W = x.size()
        q, k, v = self.query_conv(x), self.key_conv(x), self.value_conv(x)
        # treat every column (length H) and every row (length W) as an independent sequence
        q_h = q.permute(0, 3, 1, 2).contiguous().view(B*W, -1, H).permute(0, 2, 1)
        q_w = q.permute(0, 2, 1, 3).contiguous().view(B*H, -1, W).permute(0, 2, 1)
        k_h = k.permute(0, 3, 1, 2).contiguous().view(B*W, -1, H)
        k_w = k.permute(0, 2, 1, 3).contiguous().view(B*H, -1, W)
        v_h = v.permute(0, 3, 1, 2).contiguous().view(B*W, -1, H)
        v_w = v.permute(0, 2, 1, 3).contiguous().view(B*H, -1, W)
        energy_h = (torch.bmm(q_h, k_h) + INF(B, H, W, x.device)).view(B, W, H, H).permute(0, 2, 1, 3)
        energy_w = torch.bmm(q_w, k_w).view(B, H, W, W)
        attn = F.softmax(torch.cat([energy_h, energy_w], dim=3), dim=3)   # softmax over the criss-cross path
        att_h = attn[:, :, :, 0:H].permute(0, 2, 1, 3).contiguous().view(B*W, H, H)
        att_w = attn[:, :, :, H:H+W].contiguous().view(B*H, W, W)
        out_h = torch.bmm(v_h, att_h.permute(0, 2, 1)).view(B, W, -1, H).permute(0, 2, 3, 1)
        out_w = torch.bmm(v_w, att_w.permute(0, 2, 1)).view(B, H, -1, W).permute(0, 2, 1, 3)
        return self.gamma * (out_h + out_w) + x

10、Coordinate Attention

Function: factorized pooling along height and width embeds positional information into channel attention, improving spatial awareness.
Source: Coordinate Attention (Hou et al., 2021)

# Coordinate Attention (PyTorch)
class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None,1))
        self.pool_w = nn.AdaptiveAvgPool2d((1,None))
        mip = max(8, inp//reduction)
        self.conv1 = nn.Conv2d(inp, mip, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(mip)
        self.conv_h = nn.Conv2d(mip, oup, 1)
        self.conv_w = nn.Conv2d(mip, oup, 1)
    
    def forward(self, x):
        identity = x
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0,1,3,2)
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = F.relu(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0,1,3,2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        return identity * a_w * a_h
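
Because the output gates multiply the input directly, oup is normally set equal to inp when the block is used as a drop-in module; a minimal sketch:

# Hypothetical usage: coordinate attention on a 64-channel feature map
ca = CoordAtt(inp=64, oup=64)
feat = torch.randn(2, 64, 32, 32)
print(ca(feat).shape)              # torch.Size([2, 64, 32, 32])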

11、Hierarchical Attention

Function: word- and sentence-level attention for document classification; the snippet below sketches a single attention level, while the full model stacks a second, sentence-level encoder on top.
Source: Hierarchical Attention Networks (Yang et al., 2016)

# Hierarchical Attention (TensorFlow)
class HierarchicalAttention(tf.keras.Model):
    def __init__(self, hidden_size):
        super().__init__()
        self.gru = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(hidden_size, return_sequences=True))
        self.attention = tf.keras.layers.Dense(1, activation='tanh')
    
    def call(self, inputs):
        outputs = self.gru(inputs)
        attention_weights = tf.nn.softmax(self.attention(outputs), axis=1)
        context_vector = tf.reduce_sum(attention_weights * outputs, axis=1)
        return context_vector

12、Global Filter Networks (GFNet)

Function: global filtering in the frequency domain to model long-range dependencies.
Source: GFNet (Rao et al., 2021)

# GFNet Block (PyTorch) -- simplified sketch of frequency-domain global filtering
class GFNetBlock(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # one learnable complex filter per channel; (h, w) must equal the
        # frequency-domain size of the feature map, i.e. h = H and w = W//2 + 1
        self.filter = nn.Parameter(torch.randn(dim, h, w, dtype=torch.cfloat) * 0.02)

    def forward(self, x):
        B, C, H, W = x.shape
        x_freq = torch.fft.rfft2(x, dim=(-2, -1), norm='ortho')   # (B, C, H, W//2+1)
        x_freq = x_freq * self.filter                              # element-wise global filtering
        x = torch.fft.irfft2(x_freq, s=(H, W), dim=(-2, -1), norm='ortho')
        return x
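
Because the learned filter is tied to a fixed spatial resolution, the block must be constructed to match the feature map it will see; a hypothetical shape check:

# Hypothetical usage: 64-channel, 14x14 feature map -> rfft2 width is 14//2 + 1 = 8
gf = GFNetBlock(dim=64, h=14, w=8)
feat = torch.randn(2, 64, 14, 14)
print(gf(feat).shape)              # torch.Size([2, 64, 14, 14])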

13、Gated Attention

Function: controls the flow of information through a gating mechanism, combining several attention signals for more selective feature extraction.
Source: Gated Attention Network (Jia et al., 2018)

import torch
import torch.nn as nn

class GatedAttention(nn.Module):
    def __init__(self, in_channels):
        super(GatedAttention, self).__init__()
        self.gate = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        gate = torch.sigmoid(self.gate(x))
        query = self.query(x)
        key = self.key(x)
        attention = torch.sigmoid(query * key)
        return attention * gate * x

14、Bottleneck Attention Module (BAM)

Function: improves the network's representational power with channel and spatial attention placed at bottleneck positions; a companion design to CBAM from the same group. The snippet below simply chains the SE and CBAM blocks defined above as an approximation.
Source: Bottleneck Attention Module (Park et al., 2018)

import torch
import torch.nn as nn
import torch.nn.functional as F

class BAM(nn.Module):
    def __init__(self, in_channels):
        super(BAM, self).__init__()
        # approximation: chain the SE and CBAM blocks defined above
        self.se = SEBlock(in_channels)
        self.cbam = CBAMBlock(in_channels)

    def forward(self, x):
        return self.cbam(self.se(x))

15、Global Context Attention (GCA)

Function: uses global context to re-weight informative features, effective in tasks such as image classification.
Source: Global Context Attention for Visual Recognition (Gao et al., 2019)

import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalContextAttention(nn.Module):
    def __init__(self, in_channels):
        super(GlobalContextAttention, self).__init__()
        self.fc = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        # Global average pooling gives a per-sample context vector
        context = torch.mean(x, dim=[2, 3])                 # (B, C)
        attention = torch.sigmoid(self.fc(context))         # per-channel gate from global context
        return x * attention.view(x.size(0), x.size(1), 1, 1)

16、Adaptive Attention

Function: adapts the attention weights to the characteristics of the input rather than using a fixed weighting scheme.
Source: Adaptive Attention (Zhang et al., 2020)

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveAttention(nn.Module):
    def __init__(self, in_channels):
        super(AdaptiveAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        query = self.query(x).view(x.size(0), x.size(1), -1)
        key = self.key(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        value = self.value(x).view(x.size(0), x.size(1), -1)
        attn = torch.bmm(query, key)
        attn = torch.softmax(attn, dim=-1)
        return torch.bmm(attn, value).view(x.size(0), x.size(1), x.size(2), x.size(3))

17、Laplacian Attention

Function: models spatial relationships with Laplacian-style difference operators, applicable to image or graph-structured data; the snippet below approximates this with first-order gradients.
Source: Laplacian Attention Networks for Graph-Based Learning (Lee et al., 2020)

import torch
import torch.nn as nn
import torch.nn.functional as F

class LaplacianAttention(nn.Module):
    def __init__(self, in_channels, kernel_size=3):
        super(LaplacianAttention, self).__init__()
        # fuse the horizontal and vertical gradient maps back to in_channels
        self.conv = nn.Conv2d(2 * in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2)

    def forward(self, x):
        # first-order differences along H and W, zero-padded to keep the spatial size
        grad_h = F.pad(x[:, :, 1:, :] - x[:, :, :-1, :], (0, 0, 0, 1))
        grad_w = F.pad(x[:, :, :, 1:] - x[:, :, :, :-1], (0, 1, 0, 0))
        laplacian = self.conv(torch.cat([grad_h, grad_w], dim=1))
        return x + laplacian

18、Pixel Attention

Function: assigns a different weight to each pixel, emphasizing the most informative regions of the image.
Source: Pixel Attention for Image Enhancement (Liu et al., 2019)

import torch
import torch.nn as nn
import torch.nn.functional as F

class PixelAttention(nn.Module):
    def __init__(self, in_channels):
        super(PixelAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        query = self.query(x).view(b, c, -1).permute(0, 2, 1)    # (B, HW, C)
        key = self.key(x).view(b, c, -1)                         # (B, C, HW)
        value = self.value(x).view(b, c, -1)                     # (B, C, HW)
        attn = torch.softmax(torch.bmm(query, key), dim=-1)      # (B, HW, HW): pixel-to-pixel weights
        out = torch.bmm(value, attn.permute(0, 2, 1))            # (B, C, HW)
        return out.view(b, c, h, w)

19、Cross-Modal Attention

Function: fuses information across modalities, attending from one modality's features to another's in multi-modal tasks.
Source: Cross-Modal Attention Networks (Lu et al., 2019)

import torch
import torch.nn as nn

class CrossModalAttention(nn.Module):
    def __init__(self, in_channels):
        super(CrossModalAttention, self).__init__()
        self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x, y):
        query = self.query(x).view(x.size(0), x.size(1), -1)
        key = self.key(y).view(y.size(0), y.size(1), -1).permute(0, 2, 1)
        value = self.value(y).view(y.size(0), y.size(1), -1)
        attn = torch.bmm(query, key)
        attn = torch.softmax(attn, dim=-1)
        return torch.bmm(attn, value).view(x.size(0), x.size(1), x.size(2), x.size(3))
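
A minimal sketch of how the module might be called (hypothetical shapes; the code assumes both modalities have already been projected to the same channel count and spatial size):

# Hypothetical usage: two aligned 64-channel feature maps from different modalities
cma = CrossModalAttention(in_channels=64)
visual = torch.randn(2, 64, 16, 16)
textual = torch.randn(2, 64, 16, 16)
print(cma(visual, textual).shape)  # torch.Size([2, 64, 16, 16])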

20、Axial Attention

Function: attends along the height and width axes separately, which is far cheaper in computation and memory than full 2D self-attention on images and videos.
Source: Axial Attention (Ho et al., 2019)

import torch
import torch.nn as nn

class AxialAttention(nn.Module):
    def __init__(self, in_channels, heads=4):
        super(AxialAttention, self).__init__()
        # simplified sketch: multi-head self-attention applied first along H, then along W
        # (in_channels must be divisible by heads)
        self.attn_h = nn.MultiheadAttention(in_channels, heads, batch_first=True)
        self.attn_w = nn.MultiheadAttention(in_channels, heads, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape
        # height axis: every column is an independent sequence of length H
        xh = x.permute(0, 3, 2, 1).reshape(B * W, H, C)
        xh, _ = self.attn_h(xh, xh, xh)
        x = xh.reshape(B, W, H, C).permute(0, 3, 2, 1)
        # width axis: every row is an independent sequence of length W
        xw = x.permute(0, 2, 3, 1).reshape(B * H, W, C)
        xw, _ = self.attn_w(xw, xw, xw)
        return xw.reshape(B, H, W, C).permute(0, 3, 1, 2)

21、Sparse Attention

Function: sparsifies the attention map to cut the memory cost of attention, making long inputs tractable. The snippet below is a simplified sketch that keeps only the strongest gate responses (the sparsity ratio is an illustrative parameter), not the fixed strided patterns of the paper.
Source: Sparse Transformers (Child et al., 2019)

import torch
import torch.nn as nn

class SparseAttention(nn.Module):
    def __init__(self, in_channels, sparsity=0.5):
        super(SparseAttention, self).__init__()
        self.attn = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.sparsity = sparsity   # illustrative parameter: fraction of positions to zero out

    def forward(self, x):
        attn = torch.sigmoid(self.attn(x))
        B, C, H, W = attn.shape
        flat = attn.view(B, C, -1)
        k = max(1, int((1.0 - self.sparsity) * H * W))
        # keep only the k strongest responses per channel, zero the rest (deterministic sparsification)
        thresh = flat.topk(k, dim=-1).values[..., -1:]
        mask = (flat >= thresh).float().view(B, C, H, W)
        return x * attn * mask
