class Transformer(nn.Module):
该实现继承PyTorch的nn.Module
基类,采用面向对象设计模式。核心架构包含编码器-解码器双塔结构,通过参数配置实现NLP任务的通用处理能力。
章节跳转: 多头注意力机制(MultiHeadAttention)
def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx,
enc_voc_size, dec_voc_size, d_model, n_head,
max_len, ffn_hidden, n_layers, drop_prob, device):
参数矩阵说明:
参数名称 | 类型 | 作用域 | 典型值示例 |
---|---|---|---|
d_model | int | 模型维度 | 512 |
n_head | int | 注意力头数 | 8 |
ffn_hidden | int | 前馈层维度 | 2048 |
max_len | int | 序列最大长度 | 100 |
drop_prob | float | 正则化概率 | 0.1 |
self.src_pad_idx = src_pad_idx # 源序列填充符索引(如)
self.trg_pad_idx = trg_pad_idx # 目标序列填充符索引
self.trg_sos_idx = trg_sos_idx # 目标序列起始符索引(如)
self.encoder = Encoder(d_model=d_model,
n_head=n_head,
max_len=max_len,
ffn_hidden=ffn_hidden,
enc_voc_size=enc_voc_size,
drop_prob=drop_prob,
n_layers=n_layers,
device=device)
编码器初始化流程图:
章节跳转: 编码器模块Encoder
self.decoder = Decoder(d_model=d_model,
n_head=n_head,
max_len=max_len,
ffn_hidden=ffn_hidden,
dec_voc_size=dec_voc_size,
drop_prob=drop_prob,
n_layers=n_layers,
device=device)
解码器特有结构:
章节跳转: 解码器模块Decoder
def make_src_mask(self, src):
src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
return src_mask
张量形状演变示例:
原始输入:batch_size × seq_len
unsqueeze(1):batch_size × 1 × seq_len
unsqueeze(2):batch_size × 1 × 1 × seq_len
最终掩码:用于广播到多头注意力计算
def make_trg_mask(self, trg):
# 为目标序列生成填充掩码,填充索引位置的值为False,其余为True
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
# 获取目标序列的长度
trg_len = trg.shape[1]
# 生成一个下三角矩阵作为序列掩码,确保解码时只能看到当前位置及之前的信息
trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
# 将填充掩码和序列掩码进行逻辑与操作,得到最终的目标序列掩码
trg_mask = trg_pad_mask & trg_sub_mask
return trg_mask
组合掩码可视化:
填充掩码:[[1,1,0], [1,0,0]] → 三维扩展
序列掩码:[[1,0,0], [1,1,0]] → 下三角矩阵
最终掩码:按位与操作后的4维张量
enc_src = self.encoder(src, src_mask)
output = self.decoder(trg, enc_src, trg_mask, src_mask)
def forward(self, src, trg):
# 为源序列生成掩码
src_mask = self.make_src_mask(src)
# 为目标序列生成掩码
trg_mask = self.make_trg_mask(trg)
# 通过编码器处理源序列
enc_src = self.encoder(src, src_mask)
# 通过解码器处理目标序列,同时利用编码器的输出和掩码
output = self.decoder(trg, enc_src, trg_mask, src_mask)
# 返回解码器的输出
return output
数据流动示意图:
self.device = device
trg_sub_mask = ...to(self.device)
1. 张量预分配: 提前创建掩码模板
2. 内存复用: 编码器输出直接用于解码
3. 并行计算: 利用矩阵运算加速注意力计算
4. 缓存机制: 重复使用位置编码
1. 机器翻译: 端到端序列转换
2. 文本摘要: 长文本压缩生成
3. 对话系统: 上下文感知响应
4. 代码生成: 结构化文本合成
假设输入:
源掩码生成过程:
原始比较:[[True,True,False], [True,False,False]]
维度扩展:2x1x1x3 → 广播到注意力头
最终掩码:在softmax前将False位置设为-1e9
目标掩码组合演示:
填充掩码:2x1x3x1 → 第三维扩展
序列掩码:3x3的下三角矩阵
组合结果:每个时间步仅可见之前位置的有效token
plt.matshow()
展示矩阵结构torch.autograd.gradcheck()
验证反向传播tensor.device
属性检查assert src_mask.dim() == 4, "源掩码应为4维张量"
该实现完整保留了原始Transformer的核心设计思想,通过模块化封装实现了高效的序列到序列建模。开发者在实际应用中可根据任务需求调整编码器/解码器层数、注意力头数等关键参数,平衡模型容量与计算效率。
八、完整代码实现
(此处完整保留原始代码,按功能模块添加注释说明)
"""
@author : Hyunwoong
@when : 2019-12-18
@homepage : https://github.com/gusdnd852
"""
import torch
from torch import nn
# 从其他模块导入Encoder和Decoder类
from models.model.decoder import Decoder
from models.model.encoder import Encoder
class Transformer(nn.Module):
# 初始化方法(参数配置中心)
def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx,
enc_voc_size, dec_voc_size, d_model, n_head,
max_len, ffn_hidden, n_layers, drop_prob, device):
super().__init__()
# 关键索引注册(实现序列控制)
self.src_pad_idx = src_pad_idx # 源填充索引
self.trg_pad_idx = trg_pad_idx # 目标填充索引
self.trg_sos_idx = trg_sos_idx # 起始符索引
self.device = device # 硬件环境绑定
# 编码器构造(特征提取核心)
self.encoder = Encoder(
d_model=d_model, # 隐层维度
n_head=n_head, # 并行注意力头数
max_len=max_len, # 最大序列长度
ffn_hidden=ffn_hidden, # 前馈层扩展维度
enc_voc_size=enc_voc_size, # 源词汇表大小
drop_prob=drop_prob, # 正则化比率
n_layers=n_layers, # 堆叠层数
device=device) # 设备一致性
# 解码器构造(自回归生成核心)
self.decoder = Decoder(
d_model=d_model,
n_head=n_head,
max_len=max_len,
ffn_hidden=ffn_hidden,
dec_voc_size=dec_voc_size, # 目标词汇表大小
drop_prob=drop_prob,
n_layers=n_layers,
device=device)
# 前向传播管道(计算图入口)
def forward(self, src, trg):
# 动态掩码生成(适配不同输入)
src_mask = self.make_src_mask(src) # 源序列掩码
trg_mask = self.make_trg_mask(trg) # 目标序列掩码
# 编码-解码信息流
enc_src = self.encoder(src, src_mask) # 上下文编码
# 通过解码器处理目标序列,同时利用编码器的输出和掩码
output = self.decoder(trg, enc_src, trg_mask, src_mask) # 自回归解码
return output # 输出概率分布
# 源掩码生成器(处理变长输入)
def make_src_mask(self, src):
# 布尔掩码生成(True表示有效token)
mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
# 维度扩展:batch_size × 1 × 1 × seq_len
return mask # 适配多头注意力计算
# 目标掩码生成器(防止信息泄露)
def make_trg_mask(self, trg):
# 填充掩码(基础有效性判断)
trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
# 序列掩码(限制未来信息)
trg_len = trg.shape[1]
# 生成一个下三角矩阵作为序列掩码,确保解码时只能看到当前位置及之前的信息
trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
# 组合掩码(逻辑与运算)
trg_mask = trg_pad_mask & trg_sub_mask
return trg_mask