In natural language processing and sequence modeling, the Transformer architecture is widely used for its strong parallelism and long-sequence modeling ability. However, standard self-attention has quadratic computational complexity (O(n²)) and a large parameter count, which becomes a bottleneck in resource-constrained settings such as mobile devices or real-time inference. This article introduces a novel self-attention decoder architecture that optimizes the attention mechanism, uses a gated feed-forward network, and merges parameters, significantly improving efficiency while aiming to preserve quality.
A learnable `alpha` parameter balances the feed-forward branch against the input, improving training stability. The overall structure:

```
SamOut (overall model)
├── Embedding layer: maps token indices to dense vectors
├── DecoderLayer × num_layers: stacked decoder layers
│   ├── MaxStateSuper: self-attention module
│   └── FeedForward: gated feed-forward network
└── Output head: linear projection to the vocabulary space
```
The QKV dot product in standard self-attention costs O(n²) and must keep every key-value pair in memory. MaxStateSuper replaces it with three ingredients:

- `torch.cummax` computes a running maximum at each position, capturing long-range dependencies at O(n) complexity.
- A single merged projection (`combined = nn.Linear(dim_size, 3 * dim_size)`) replaces the three separate Q/K/V linear layers, reducing the layer count.
- `softmax(out)` produces attention weights, which are combined with a gating term (`out_score + out1`) to dynamically modulate the information flow.

```python
# Core computation in MaxStateSuper
out = torch.cummax(out, dim=2)[0]       # running maximum along the sequence
out_score = torch.softmax(out, dim=1)   # dynamic weights
out = (out_score + out1) * out2 + out1  # gated fusion
```
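As a quick self-contained check (toy tensor shapes, not the article's full module), the cumulative-max step can be exercised in isolation:

```python
import torch

# Toy shapes: batch=2, heads=4, seq_len=5, head_dim=8
out = torch.randn(2, 4, 5, 8)

# Cumulative maximum along the sequence axis (dim=2): each position
# sees the running max over all positions up to itself, in O(n)
cum = torch.cummax(out, dim=2)[0]

# The running max is monotonically non-decreasing along the sequence
assert torch.all(cum[:, :, 1:, :] >= cum[:, :, :-1, :])
# And it dominates the raw input at every position
assert torch.all(cum >= out)
```

Unlike a full attention matrix, this keeps only one running statistic per channel, which is where the O(n) memory and compute claim comes from.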
In a standard feed-forward network (FFN), the unconditioned linear transforms can pass redundant or overloaded information. The gated variant addresses this:

- `gate = Linear(hidden_size, hidden_size // 2)` produces a gating signal that controls how much of the first transform passes through.
- `ReLU` and `Dropout(0.1)` curb overfitting and improve generalization.

```python
# Gated feed-forward computation
x1 = self.ffn1(x)             # first linear transform
x2 = self.relu(self.gate(x))  # gating signal
xx = x1 * x2                  # gated product
x = self.gr(self.ffn2(xx))    # second linear transform + dropout
```
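A minimal standalone sketch of the gated computation, using an arbitrary toy hidden size:

```python
import torch
from torch import nn

hidden_size = 16  # toy value for illustration
ffn1 = nn.Linear(hidden_size, hidden_size // 2)
ffn2 = nn.Linear(hidden_size // 2, hidden_size)
gate = nn.Linear(hidden_size, hidden_size // 2)
relu = nn.ReLU()
drop = nn.Dropout(0.1)

x = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)
x1 = ffn1(x)                        # first linear transform
x2 = relu(gate(x))                  # non-negative gating signal
out = drop(ffn2(x1 * x2))           # gated product, projected back

# The block preserves the input shape, so it can sit inside a residual path
assert out.shape == x.shape
```

Because the ReLU-gated signal is zero wherever the gate is negative, whole channels of `x1` can be switched off per position, which is the "controls the information flow" behavior described above.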
`alpha`: a learnable scalar that balances the feed-forward output against the original input, `x = alpha * FFN(x1) + (1 - alpha) * x`:

```python
# Parameter balancing in DecoderLayer
self.alpha = torch.nn.Parameter(torch.tensor(0.5))
x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)
```
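The blend can be sketched in isolation (stand-in tensors, not the real FFN output):

```python
import torch

alpha = torch.nn.Parameter(torch.tensor(0.5))
x = torch.randn(2, 5, 16)        # original input to the layer
ffn_out = torch.randn(2, 5, 16)  # stand-in for FFN(x1)

# Learnable blend: alpha weights the FFN branch,
# (1 - alpha) preserves the original signal
blended = alpha * ffn_out + (1 - alpha) * x

# At the initial value alpha = 0.5 this is exactly the average of both branches
assert torch.allclose(blended, 0.5 * (ffn_out + x))
```

Since `alpha` is a `Parameter`, gradient descent can shift the balance toward whichever branch helps the loss, instead of fixing the residual mix at design time.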
| Aspect | Standard Transformer | This architecture |
| --- | --- | --- |
| Self-attention complexity | O(n²) | O(n) |
| Parameter count | Higher (separate QKV linear layers) | Lower (merged linear layer) |
| Training speed | Slower | Faster |
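The parameter claim in the table can be made concrete with a rough count. The sketch below is an illustration, not part of the original code: it assumes the conventional 4× FFN expansion for the standard Transformer block and ignores norms, embeddings, and the output head.

```python
from torch import nn

def n_params(*modules):
    """Total parameter count across the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())

d = 512  # hidden size used in the article's example

# Standard Transformer block: separate Q/K/V/output projections plus a 4x FFN
standard = n_params(
    nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d), nn.Linear(d, d),
    nn.Linear(d, 4 * d), nn.Linear(4 * d, d),
)

# This architecture: one merged 3x projection plus the half-width gated FFN
compact = n_params(
    nn.Linear(d, 3 * d),                         # combined projection
    nn.Linear(d, d // 2), nn.Linear(d // 2, d),  # ffn1, ffn2
    nn.Linear(d, d // 2),                        # gate
)

print(f"standard: {standard:,}  compact: {compact:,}")
assert compact < standard
```

Note that the merged 3× projection by itself has exactly as many weights as three separate d→d layers; the per-layer savings come mainly from dropping the output projection and halving the FFN width.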
In the code below, we simulate a full training scenario: building the model, generating random token data, and running a short training loop.
By combining cumulative-max attention, a gated feed-forward network, and a learnable balancing parameter, this architecture significantly reduces computational complexity and parameter count while aiming to preserve quality. Its design offers a fresh direction for lightweight sequence modeling, particularly in resource-constrained settings, and could be pushed further by combining it with additional optimizations such as sparse attention.

We hope this walkthrough clarifies the design principles behind the architecture and encourages readers to apply or extend the model in their own projects.
```python
import time

import torch
from torch import nn, optim


class MaxStateSuper(nn.Module):
    def __init__(self, dim_size, heads):
        super().__init__()
        self.heads = heads
        assert dim_size % heads == 0, "Dimension size must be divisible by head count."
        # Merge the three Q/K/V-style projections into a single linear layer
        self.combined = nn.Linear(dim_size, 3 * dim_size)
        # self.out_proj = nn.Linear(dim_size, dim_size)  # optional output projection, omitted

    def forward(self, x, state=None):
        b, s, d = x.shape
        # One linear transform, then split into three streams
        out, out1, out2 = self.combined(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq_len, head_dim)
        out = out.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
        out1 = out1.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
        out2 = out2.view(b, s, self.heads, -1).permute(0, 2, 1, 3)
        out = torch.cummax(out, dim=2)[0]       # running maximum along the sequence
        out_score = torch.softmax(out, dim=1)   # dynamic weights
        out = (out_score + out1) * out2 + out1  # gated fusion
        # Restore (batch, seq_len, dim)
        out = out.permute(0, 2, 1, 3).contiguous().view(b, s, d)
        # out = self.out_proj(out)
        return out, state
```
```python
class FeedForward(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.ffn1 = nn.Linear(hidden_size, hidden_size // 2)
        self.ffn2 = nn.Linear(hidden_size // 2, hidden_size)
        self.gate = nn.Linear(hidden_size, hidden_size // 2)
        self.relu = nn.ReLU()
        self.gr = nn.Dropout(0.1)

    def forward(self, x):
        x1 = self.ffn1(x)             # first linear transform
        x2 = self.relu(self.gate(x))  # gating signal
        xx = x1 * x2                  # gated product
        x = self.gr(self.ffn2(xx))    # second linear transform + dropout
        return x
```
```python
class DecoderLayer(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.self_attention = MaxStateSuper(hidden_size, num_heads)
        self.ffn = FeedForward(hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        # Learnable balance between the FFN branch and the residual input
        self.alpha = nn.Parameter(torch.tensor(0.5))

    def forward(self, x, state=None):
        x1, state = self.self_attention(x, state)
        x = self.layer_norm(self.alpha * self.ffn(x1) + (1 - self.alpha) * x)
        return x, state
```
```python
class SamOut(nn.Module):
    def __init__(self, voc_size, hidden_size, num_heads, num_layers):
        super().__init__()
        self.em = nn.Embedding(voc_size, hidden_size, padding_idx=3)
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)]
        )
        self.head = nn.Linear(hidden_size, voc_size, bias=False)

    def forward(self, x, state=None):
        x = self.em(x)
        if state is None:
            state = [None] * len(self.decoder_layers)
        for i, decoder_layer in enumerate(self.decoder_layers):
            x1, state[i] = decoder_layer(x, state[i])
            x = x1 + x  # residual connection
        x = self.head(x)
        return x, state
```
```python
if __name__ == '__main__':
    # Hyperparameters
    voc_size = 10000   # vocabulary size
    hidden_size = 512  # hidden dimension
    num_heads = 8      # number of attention heads
    num_layers = 6     # number of decoder layers
    learning_rate = 0.001
    batch_size = 32
    num_epochs = 10

    # Build the model
    model = SamOut(voc_size=voc_size, hidden_size=hidden_size,
                   num_heads=num_heads, num_layers=num_layers)

    # Loss and optimizer; index 3 is the padding token and is ignored by the loss
    criterion = nn.CrossEntropyLoss(ignore_index=3)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Simulated training data (use a real dataset in practice); sequence length 50
    data = torch.randint(low=0, high=voc_size, size=(batch_size, 50))
    input_tensor = data[:, :-1]              # inputs
    target_tensor = data[:, 1:].reshape(-1)  # next-token targets, flattened

    # Training loop
    start_time = time.time()
    for epoch in range(num_epochs):
        # Forward pass
        output, _ = model(input_tensor)
        # Flatten for CrossEntropyLoss: (batch * seq_len, voc_size)
        output = output.reshape(-1, voc_size)
        loss = criterion(output, target_tensor)

        optimizer.zero_grad()  # clear stale gradients
        loss.backward()        # backpropagation
        optimizer.step()       # parameter update
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')
    print(f"Training complete in {time.time() - start_time:.2f}s.")
```