Note: if you spot any errors, please point them out so they can be corrected.
With the rapid development of deep learning in natural language processing and computer vision, attention mechanisms have become a key technique for improving model performance. This post records the function, source, and core code of plug-and-play attention modules.
Function: adaptively learns channel weights to strengthen important channel features.
Source: SENet (Squeeze-and-Excitation Networks, Hu et al., 2018)
# SE Block (PyTorch)
import torch
import torch.nn as nn
class SEBlock(nn.Module):
def __init__(self, channel, ratio=16):
super().__init__()
self.gap = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel//ratio, bias=False),
nn.ReLU(),
nn.Linear(channel//ratio, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.gap(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
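As a quick plug-and-play sanity check, a minimal usage sketch (shapes are illustrative, not from the paper): the block preserves the input shape, so it can be dropped in after any convolutional stage.
# Usage sketch: SEBlock keeps the feature map shape unchanged
feat = torch.randn(2, 64, 32, 32)
se = SEBlock(channel=64, ratio=16)
out = se(feat)   # (2, 64, 32, 32)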
Function: combines channel and spatial attention to improve feature selectivity.
Source: CBAM (Convolutional Block Attention Module, Woo et al., 2018)
# CBAM (PyTorch)
class CBAMBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        # Channel attention: shared MLP applied to avg- and max-pooled descriptors
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel//reduction),
            nn.ReLU(),
            nn.Linear(channel//reduction, channel)
        )
        # Spatial attention: 7x7 conv over channel-wise avg and max maps
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)
    def forward(self, x):
        b, c, _, _ = x.size()
        # Channel attention (sigmoid applied after summing the two branches, as in the paper)
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        channel_out = torch.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        x = x * channel_out
        # Spatial attention
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        spatial_out = torch.cat([avg_out, max_out], dim=1)
        spatial_out = self.conv(spatial_out).sigmoid()
        return x * spatial_out
Function: lightweight channel attention that replaces the fully connected layers with a 1D convolution.
Source: ECA-Net (Efficient Channel Attention, Wang et al., 2020)
# ECA Block (PyTorch)
import math
class ECABlock(nn.Module):
def __init__(self, channel, b=1, gamma=2):
super().__init__()
kernel_size = int(abs((math.log(channel, 2) + b) / gamma)) | 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size, padding=kernel_size//2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.conv(y.unsqueeze(1)).squeeze(1)
y = self.sigmoid(y).view(b, c, 1, 1)
return x * y
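As a quick sanity check of the adaptive kernel size (an illustrative calculation, not from the paper): with the defaults b=1 and gamma=2, a 256-channel feature map gives a 1D kernel of size 5.
# Worked example of the adaptive kernel size (the "| 1" forces an odd size)
k = int(abs((math.log(256, 2) + 1) / 2)) | 1   # int(4.5) | 1 = 4 | 1 = 5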
Function: standard self-attention that computes global feature dependencies.
Source: Attention Is All You Need (Vaswani et al., 2017)
# Self-Attention (PyTorch)
class SelfAttention(nn.Module):
def __init__(self, embed_size):
super().__init__()
self.query = nn.Linear(embed_size, embed_size)
self.key = nn.Linear(embed_size, embed_size)
self.value = nn.Linear(embed_size, embed_size)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
Q = self.query(x)
K = self.key(x)
V = self.value(x)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (x.size(-1)**0.5)
attn = self.softmax(scores)
return torch.matmul(attn, V)
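A minimal usage sketch for the sequence form above (shapes are illustrative): the module expects token embeddings of shape (batch, seq_len, embed_size).
# Usage sketch: self-attention over a length-10 sequence of 64-d embeddings
tokens = torch.randn(2, 10, 64)
sa = SelfAttention(embed_size=64)
out = sa(tokens)   # (2, 10, 64); the attention matrix inside is (2, 10, 10)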
# Self-Attention on 2D feature maps (PyTorch)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, in_channels):
super(SelfAttention, self).__init__()
self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def forward(self, x):
query = self.query(x).view(x.size(0), x.size(1), -1)
key = self.key(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
value = self.value(x).view(x.size(0), x.size(1), -1)
attention = F.softmax(torch.bmm(query, key), dim=-1)
out = torch.bmm(attention, value)
out = out.view(x.size(0), x.size(1), x.size(2), x.size(3))
return self.out(out)
Function: multiple heads capture features from different representation subspaces.
Source: Transformer (Attention Is All You Need, Vaswani et al., 2017)
# Multi-Head Attention (PyTorch)
class MultiHeadAttention(nn.Module):
def __init__(self, heads, embed_size):
super().__init__()
self.heads = heads
self.head_dim = embed_size // heads
self.Wq = nn.Linear(embed_size, embed_size)
self.Wk = nn.Linear(embed_size, embed_size)
self.Wv = nn.Linear(embed_size, embed_size)
self.fc = nn.Linear(embed_size, embed_size)
def forward(self, x):
batch = x.size(0)
Q = self.Wq(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
K = self.Wk(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
V = self.Wv(x).view(batch, -1, self.heads, self.head_dim).transpose(1,2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim**0.5)
attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, V).transpose(1,2).contiguous().view(batch, -1, self.heads * self.head_dim)
return self.fc(out)
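A minimal usage sketch (illustrative shapes); embed_size must be divisible by heads for the head split to work.
# Usage sketch for MultiHeadAttention
mha = MultiHeadAttention(heads=8, embed_size=64)
out = mha(torch.randn(2, 10, 64))   # (2, 10, 64)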
Function: fuses local contextual information to strengthen the attention.
Source: CoTNet (Contextual Transformer Networks for Visual Recognition, Li et al., 2021)
# CoTAttention (PyTorch)
class CoTAttention(nn.Module):
def __init__(self, dim=512, kernel_size=3):
super().__init__()
self.dim = dim
self.key_embed = nn.Sequential(
nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=4, bias=False),
nn.BatchNorm2d(dim),
nn.ReLU()
)
self.value_embed = nn.Conv2d(dim, dim, 1, bias=False)
self.attention_embed = nn.Sequential(
nn.Conv2d(2*dim, 2*dim//4, 1, bias=False),
nn.BatchNorm2d(2*dim//4),
nn.ReLU(),
nn.Conv2d(2*dim//4, kernel_size**2 * dim, 1)
)
def forward(self, x):
bs, c, h, w = x.shape
k1 = self.key_embed(x)
v = self.value_embed(x).view(bs, c, -1)
y = torch.cat([k1, x], dim=1)
att = self.attention_embed(y)
att = att.reshape(bs, c, -1, h, w)
att = att.mean(2).view(bs, c, -1)
k2 = F.softmax(att, dim=-1) * v
k2 = k2.view(bs, c, h, w)
return k1 + k2
Function: attention-free formulation with linear memory complexity, suited to long sequences.
Source: AFT (An Attention Free Transformer, Zhai et al., 2021)
# AFT (PyTorch)
class AFTFull(nn.Module):
    def __init__(self, d_model, n=49):
        super().__init__()
        # n is the maximum sequence length covered by the learned pairwise position biases w
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.w = nn.Parameter(torch.zeros(n, n))
    def forward(self, x):
        B, T, C = x.shape
        q = torch.sigmoid(self.to_q(x))                      # (B, T, C)
        k = self.to_k(x)                                     # (B, T, C)
        v = self.to_v(x)                                     # (B, T, C)
        w = self.w[:T, :T].unsqueeze(0).unsqueeze(-1)        # (1, T, T, 1)
        # AFT-Full: Y_t = sigmoid(Q_t) * sum_t' softmax(K_t' + w_{t,t'}) * V_t'
        weights = torch.softmax(w + k.unsqueeze(1), dim=2)   # (B, T, T, C), normalized over t'
        y = (weights * v.unsqueeze(1)).sum(dim=2)            # (B, T, C)
        return q * y
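A minimal usage sketch (illustrative shapes); the sequence length must not exceed n, the size of the learned position-bias table.
# Usage sketch for AFTFull
aft = AFTFull(d_model=64, n=49)
out = aft(torch.randn(2, 49, 64))   # (2, 49, 64)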
Function: cross-dimension interaction that strengthens spatial and channel relationships.
Source: Triplet Attention (Rotate to Attend: Convolutional Triplet Attention Module, Misra et al., 2021)
# Triplet Attention (PyTorch)
class SpatialGate(nn.Module):
    # Z-pool (channel-wise max and mean) -> 7x7 conv -> sigmoid gate
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.bn = nn.BatchNorm2d(1)
    def forward(self, x):
        y = torch.cat([x.max(dim=1, keepdim=True)[0], x.mean(dim=1, keepdim=True)], dim=1)
        return x * torch.sigmoid(self.bn(self.conv(y)))
class TripletAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.cw = SpatialGate()   # channel-height branch
        self.hc = SpatialGate()   # channel-width branch
        self.hw = SpatialGate()   # plain spatial branch
    def forward(self, x):
        # Rotate so that C interacts with H, gate, rotate back
        x_out1 = self.cw(x.permute(0, 2, 1, 3).contiguous()).permute(0, 2, 1, 3).contiguous()
        # Rotate so that C interacts with W, gate, rotate back
        x_out2 = self.hc(x.permute(0, 3, 2, 1).contiguous()).permute(0, 3, 2, 1).contiguous()
        # Ordinary spatial gate on the unrotated input
        x_out3 = self.hw(x)
        return (x_out1 + x_out2 + x_out3) / 3.0
Function: criss-cross attention captures full-image context along each pixel's row and column.
Source: CCNet (Criss-Cross Attention for Semantic Segmentation, Huang et al., 2019)
# Criss-Cross Attention (PyTorch)
def INF(B, H, W, device):
    # -inf on the diagonal so a pixel is not counted twice (it already attends to itself in the horizontal pass)
    return -torch.diag(torch.tensor(float("inf"), device=device).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)
class CrissCrossAttention(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.key_conv = nn.Conv2d(in_dim, in_dim//8, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        b, _, h, w = x.size()
        q, k, v = self.query_conv(x), self.key_conv(x), self.value_conv(x)
        # Vertical (column) attention: each pixel attends to the pixels in its column
        q_h = q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)
        k_h = k.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        v_h = v.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)
        energy_h = (torch.bmm(q_h, k_h) + INF(b, h, w, x.device)).view(b, w, h, h).permute(0, 2, 1, 3)
        # Horizontal (row) attention: each pixel attends to the pixels in its row
        q_w = q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)
        k_w = k.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)
        v_w = v.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)
        energy_w = torch.bmm(q_w, k_w).view(b, h, w, w)
        # Joint softmax over the criss-cross set (H + W positions per pixel)
        attn = F.softmax(torch.cat([energy_h, energy_w], dim=3), dim=3)
        att_h = attn[:, :, :, 0:h].permute(0, 2, 1, 3).contiguous().view(b * w, h, h)
        att_w = attn[:, :, :, h:h + w].contiguous().view(b * h, w, w)
        out_h = torch.bmm(v_h, att_h.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
        out_w = torch.bmm(v_w, att_w.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
        return self.gamma * (out_h + out_w) + x
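A minimal usage sketch (illustrative shapes). Because gamma is initialized to zero, the block starts out as an identity mapping; in CCNet it is typically applied recurrently (e.g. twice) so context propagates over the whole image.
# Usage sketch for CrissCrossAttention
cca = CrissCrossAttention(in_dim=64)
out = cca(torch.randn(2, 64, 16, 16))   # (2, 64, 16, 16)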
Function: encodes position along the height and width coordinates to strengthen spatial awareness.
Source: Coordinate Attention (Coordinate Attention for Efficient Mobile Network Design, Hou et al., 2021)
# Coordinate Attention (PyTorch)
class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super().__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None,1))
self.pool_w = nn.AdaptiveAvgPool2d((1,None))
mip = max(8, inp//reduction)
self.conv1 = nn.Conv2d(inp, mip, 1, bias=False)
self.bn1 = nn.BatchNorm2d(mip)
self.conv_h = nn.Conv2d(mip, oup, 1)
self.conv_w = nn.Conv2d(mip, oup, 1)
def forward(self, x):
identity = x
n,c,h,w = x.size()
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0,1,3,2)
y = torch.cat([x_h, x_w], dim=2)
y = self.conv1(y)
y = self.bn1(y)
y = F.relu(y)
x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0,1,3,2)
a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()
return identity * a_w * a_h
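A minimal usage sketch (illustrative shapes); oup is normally set equal to inp so that the two gates broadcast against the identity.
# Usage sketch for CoordAtt
ca = CoordAtt(inp=64, oup=64, reduction=32)
out = ca(torch.randn(2, 64, 32, 32))   # (2, 64, 32, 32)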
Function: word- and sentence-level hierarchical attention for document classification.
Source: Hierarchical Attention Networks (Yang et al., 2016)
# Hierarchical Attention (TensorFlow)
import tensorflow as tf
class HierarchicalAttention(tf.keras.Model):
    def __init__(self, hidden_size):
        super().__init__()
        self.gru = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(hidden_size, return_sequences=True))
        # Simplified attention: score each timestep, then softmax over the sequence
        self.attention = tf.keras.layers.Dense(1, activation='tanh')
def call(self, inputs):
outputs = self.gru(inputs)
attention_weights = tf.nn.softmax(self.attention(outputs), axis=1)
context_vector = tf.reduce_sum(attention_weights * outputs, axis=1)
return context_vector
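A minimal usage sketch for the word-level encoder (illustrative shapes): embedded word sequences go in, one attended vector per sentence comes out; the sentence-level encoder reuses the same structure on top.
# Usage sketch for HierarchicalAttention (word level)
han = HierarchicalAttention(hidden_size=50)
sentence_vec = han(tf.random.normal([8, 20, 100]))   # (8, 100)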
Function: global filtering in the frequency domain to model long-range dependencies.
Source: GFNet (Global Filter Networks for Image Classification, Rao et al., 2021)
# GFNet Block (PyTorch)
class GFNetBlock(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable complex-valued global filter; w should be W//2 + 1 because rfft2 keeps half the spectrum
        self.filter = nn.Parameter(torch.randn(dim, h, w, dtype=torch.cfloat) * 0.02)
    def forward(self, x):
        B, C, H, W = x.shape
        x_freq = torch.fft.rfft2(x, dim=(-2, -1), norm='ortho')   # (B, C, H, W//2 + 1), complex
        x_freq = x_freq * self.filter                              # element-wise filtering = global circular convolution
        x = torch.fft.irfft2(x_freq, s=(H, W), dim=(-2, -1), norm='ortho')
        return x
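A minimal usage sketch (illustrative shapes); the filter arguments h and w must match the input's frequency grid, i.e. h = H and w = W//2 + 1.
# Usage sketch for GFNetBlock on 32x32 feature maps
gf = GFNetBlock(dim=64, h=32, w=17)
out = gf(torch.randn(2, 64, 32, 32))   # (2, 64, 32, 32)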
Function: controls information flow through a gating mechanism and combines it with attention for effective information extraction.
Source: Gated Attention Network (Jia et al., 2018)
import torch
import torch.nn as nn
class GatedAttention(nn.Module):
def __init__(self, in_channels):
super(GatedAttention, self).__init__()
self.gate = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def forward(self, x):
gate = torch.sigmoid(self.gate(x))
query = self.query(x)
key = self.key(x)
attention = torch.sigmoid(query * key)
return attention * gate * x
Function: improves the network's representational power with channel and spatial attention computed in parallel at network bottlenecks; closely related to CBAM (by largely the same authors).
Source: BAM (Bottleneck Attention Module, Park et al., 2018)
import torch
import torch.nn as nn
class BAM(nn.Module):
    def __init__(self, in_channels, reduction=16, dilation=4):
        super(BAM, self).__init__()
        # Channel branch (GAP + bottleneck MLP) and spatial branch (bottleneck + dilated conv), applied in parallel
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1))
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels // reduction, 3, padding=dilation, dilation=dilation),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, 1, 1))
    def forward(self, x):
        att = torch.sigmoid(self.channel(x) + self.spatial(x))   # broadcast: (B,C,1,1) + (B,1,H,W)
        return x * (1 + att)                                     # residual gating, as in the paper
Function: uses global context information to increase the weight of important features; effective in tasks such as image classification.
Source: Global Context Attention for Visual Recognition (Gao et al., 2019)
import torch
import torch.nn as nn
import torch.nn.functional as F
class GlobalContextAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(GlobalContextAttention, self).__init__()
        self.context = nn.Conv2d(in_channels, 1, kernel_size=1)   # spatial weights for context pooling
        self.transform = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels // reduction, in_channels, 1))
    def forward(self, x):
        b, c, h, w = x.size()
        # Context modeling: softmax-weighted pooling over all spatial positions
        weights = torch.softmax(self.context(x).view(b, 1, h * w), dim=-1)
        context = torch.bmm(x.view(b, c, h * w), weights.transpose(1, 2)).view(b, c, 1, 1)
        # Per-channel gate computed from the global context
        attention = torch.sigmoid(self.transform(context))
        return x * attention
Function: adaptively adjusts the attention weights, selecting dynamically according to the characteristics of the input.
Source: Adaptive Attention (Zhang et al., 2020)
import torch
import torch.nn as nn
import torch.nn.functional as F
class AdaptiveAttention(nn.Module):
def __init__(self, in_channels):
super(AdaptiveAttention, self).__init__()
self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def forward(self, x):
query = self.query(x).view(x.size(0), x.size(1), -1)
key = self.key(x).view(x.size(0), x.size(1), -1).permute(0, 2, 1)
value = self.value(x).view(x.size(0), x.size(1), -1)
attn = torch.bmm(query, key)
attn = torch.softmax(attn, dim=-1)
return torch.bmm(attn, value).view(x.size(0), x.size(1), x.size(2), x.size(3))
Function: models spatial relationships in image or graph data using the Laplacian operator.
Source: Laplacian Attention Networks for Graph-Based Learning (Lee et al., 2020)
import torch
import torch.nn as nn
import torch.nn.functional as F
class LaplacianAttention(nn.Module):
    def __init__(self, in_channels, kernel_size=3):
        super(LaplacianAttention, self).__init__()
        # Fixed discrete Laplacian kernel applied depthwise, then a learnable refinement conv
        lap = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])
        self.register_buffer('lap_kernel', lap.view(1, 1, 3, 3).repeat(in_channels, 1, 1, 1))
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2)
    def forward(self, x):
        # High-frequency (Laplacian) response used as a residual to highlight edges and details
        laplacian = F.conv2d(x, self.lap_kernel, padding=1, groups=x.size(1))
        return x + self.conv(laplacian)
Function: assigns a different weight to each pixel, emphasizing the most informative regions of the image.
Source: Pixel Attention for Image Enhancement (Liu et al., 2019)
import torch
import torch.nn as nn
import torch.nn.functional as F
class PixelAttention(nn.Module):
def __init__(self, in_channels):
super(PixelAttention, self).__init__()
self.query = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.key = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.value = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
    def forward(self, x):
        b, c, h, w = x.size()
        # Attention is computed between pixels (HW x HW), so each spatial position gets its own weights
        query = self.query(x).view(b, c, -1).permute(0, 2, 1)   # (B, HW, C)
        key = self.key(x).view(b, c, -1)                        # (B, C, HW)
        value = self.value(x).view(b, c, -1)                    # (B, C, HW)
        attn = torch.softmax(torch.bmm(query, key), dim=-1)     # (B, HW, HW)
        out = torch.bmm(value, attn.permute(0, 2, 1))           # (B, C, HW)
        return out.view(b, c, h, w)
Function: for multimodal tasks; fuses information from different modalities through cross-modal attention.
Source: Cross-Modal Attention Networks (Lu et al., 2019)
import torch
import torch.nn as nn
class CrossModalAttention(nn.Module):
def __init__(self, in_channels):
super(CrossModalAttention, self).__init__()
self.query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def forward(self, x, y):
query = self.query(x).view(x.size(0), x.size(1), -1)
key = self.key(y).view(y.size(0), y.size(1), -1).permute(0, 2, 1)
value = self.value(y).view(y.size(0), y.size(1), -1)
attn = torch.bmm(query, key)
attn = torch.softmax(attn, dim=-1)
return torch.bmm(attn, value).view(x.size(0), x.size(1), x.size(2), x.size(3))
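A minimal usage sketch (illustrative shapes); the two modalities must share the channel count, and here they also share the spatial size so the output can be reshaped to x's shape.
# Usage sketch for CrossModalAttention: queries come from x, keys and values from y
cma = CrossModalAttention(in_channels=64)
out = cma(torch.randn(2, 64, 16, 16), torch.randn(2, 64, 16, 16))   # (2, 64, 16, 16)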
Function: axial attention improves efficiency on 2D data such as images and video by reducing the spatial complexity of the attention computation.
Source: Axial Attention (Ho et al., 2019)
import torch
import torch.nn as nn
class AxialAttention(nn.Module):
def __init__(self, in_channels):
super(AxialAttention, self).__init__()
self.attn = nn.Conv2d(in_channels, in_channels, kernel_size=1)
    def forward(self, x):
        # Normalize one shared projection along the height axis and the width axis separately,
        # then gate the input with both axial maps (a lightweight stand-in for full axial self-attention)
        attn = self.attn(x)
        attn_h = torch.softmax(attn, dim=2)
        attn_w = torch.softmax(attn, dim=3)
        return x * attn_h * attn_w
Function: sparsifies the attention matrix to reduce the memory cost of computing attention, making longer sequences tractable.
Source: Sparse Attention (Child et al., 2019)
import torch
import torch.nn as nn
class SparseAttention(nn.Module):
def __init__(self, in_channels):
super(SparseAttention, self).__init__()
self.attn = nn.Conv2d(in_channels, in_channels, kernel_size=1)
    def forward(self, x):
        attn = torch.sigmoid(self.attn(x))
        # Random binary mask as a simple stand-in for the fixed sparse patterns used in the paper
        attn = attn * (torch.rand_like(attn) > 0.5).float()
        return x * attn