Transformer 代码剖析15 - Transformer模型代码 (pytorch实现)

一、模型架构全景解析

1.1 类定义与继承关系

class Transformer(nn.Module):

该实现继承PyTorch的nn.Module基类,采用面向对象设计模式。核心架构包含编码器-解码器双塔结构,通过参数配置实现NLP任务的通用处理能力。

Transformer
Encoder
Decoder
Multi-Head Attention
Feed Forward
Masked Multi-Head Attention
Cross-Multi-Head Attention

章节跳转: 多头注意力机制(MultiHeadAttention)

1.2 构造函数参数详解

def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, 
             enc_voc_size, dec_voc_size, d_model, n_head, 
             max_len, ffn_hidden, n_layers, drop_prob, device):

参数矩阵说明:

参数名称 类型 作用域 典型值示例
d_model int 模型维度 512
n_head int 注意力头数 8
ffn_hidden int 前馈层维度 2048
max_len int 序列最大长度 100
drop_prob float 正则化概率 0.1

二、核心组件初始化剖析

2.1 掩码索引初始化

self.src_pad_idx = src_pad_idx  # 源序列填充符索引(如
self.trg_pad_idx = trg_pad_idx  # 目标序列填充符索引
self.trg_sos_idx = trg_sos_idx  # 目标序列起始符索引(如

2.2 编码器构建流程

self.encoder = Encoder(d_model=d_model,
                      n_head=n_head,
                      max_len=max_len,
                      ffn_hidden=ffn_hidden,
                      enc_voc_size=enc_voc_size,
                      drop_prob=drop_prob,
                      n_layers=n_layers,
                      device=device)

编码器初始化流程图:

词汇表大小
词嵌入层
位置编码
注意力机制
前馈网络
层归一化
多头注意力
残差连接

章节跳转: 编码器模块Encoder

2.3 解码器构建流程

self.decoder = Decoder(d_model=d_model,
                       n_head=n_head,
                       max_len=max_len,
                       ffn_hidden=ffn_hidden,
                       dec_voc_size=dec_voc_size,
                       drop_prob=drop_prob,
                       n_layers=n_layers,
                       device=device)

解码器特有结构:

  • 带掩码的多头注意力: 防止信息泄露
  • 交叉注意力机制: 连接编码器输出
  • 三阶段残差连接: 保持梯度流动

章节跳转: 解码器模块Decoder

三、前向传播过程详解

3.1 掩码生成机制

3.1.1 源序列掩码生成

def make_src_mask(self, src):
    src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
    return src_mask

张量形状演变示例:

原始输入:batch_size × seq_len
unsqueeze(1):batch_size × 1 × seq_len
unsqueeze(2):batch_size × 1 × 1 × seq_len
最终掩码:用于广播到多头注意力计算

3.1.2 目标序列掩码生成

def make_trg_mask(self, trg):
    # 为目标序列生成填充掩码,填充索引位置的值为False,其余为True
    trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
    # 获取目标序列的长度
    trg_len = trg.shape[1]
    # 生成一个下三角矩阵作为序列掩码,确保解码时只能看到当前位置及之前的信息
    trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
    # 将填充掩码和序列掩码进行逻辑与操作,得到最终的目标序列掩码
    trg_mask = trg_pad_mask & trg_sub_mask
    return trg_mask

组合掩码可视化:

填充掩码:[[1,1,0], [1,0,0]] → 三维扩展
序列掩码:[[1,0,0], [1,1,0]] → 下三角矩阵
最终掩码:按位与操作后的4维张量

3.2 编码-解码流程

enc_src = self.encoder(src, src_mask)
output = self.decoder(trg, enc_src, trg_mask, src_mask)

3.3 前向传播

def forward(self, src, trg):
    # 为源序列生成掩码
    src_mask = self.make_src_mask(src)
    # 为目标序列生成掩码
    trg_mask = self.make_trg_mask(trg)
    # 通过编码器处理源序列
    enc_src = self.encoder(src, src_mask)
    # 通过解码器处理目标序列,同时利用编码器的输出和掩码
    output = self.decoder(trg, enc_src, trg_mask, src_mask)
    # 返回解码器的输出
    return output

数据流动示意图:

源序列 编码器 解码器 目标序列 Output 输入序列+源掩码 上下文向量 目标序列+双掩码 预测分布 源序列 编码器 解码器 目标序列 Output

四、关键技术创新点解析

4.1 动态掩码生成

  • 自适应序列长度: 自动处理变长输入
  • 硬件加速设计: 利用广播机制减少显存占用
  • 类型安全处理: ByteTensor类型避免计算误差

4.2 设备感知设计

self.device = device
trg_sub_mask = ...to(self.device)
  • 自动适配CPU/GPU环境
  • 统一设备内存管理
  • 分布式训练兼容性

五、性能优化策略

1. 张量预分配: 提前创建掩码模板
2. 内存复用: 编码器输出直接用于解码
3. 并行计算: 利用矩阵运算加速注意力计算
4. 缓存机制: 重复使用位置编码

六、扩展应用场景

1. 机器翻译: 端到端序列转换
2. 文本摘要: 长文本压缩生成
3. 对话系统: 上下文感知响应
4. 代码生成: 结构化文本合成

七、典型张量运算示例

假设输入:

  • 源序列: [[2,5,0], [3,0,0]](0为填充)
  • 目标序列: [[1,4,6], [1,0,0]]

源掩码生成过程:

原始比较:[[True,True,False], [True,False,False]]
维度扩展:2x1x1x3 → 广播到注意力头
最终掩码:在softmax前将False位置设为-1e9

目标掩码组合演示:

填充掩码:2x1x3x1 → 第三维扩展
序列掩码:3x3的下三角矩阵
组合结果:每个时间步仅可见之前位置的有效token

八、进阶调试技巧

  1. 掩码可视化:使用plt.matshow()展示矩阵结构
  2. 梯度检查:torch.autograd.gradcheck()验证反向传播
  3. 设备一致性验证:tensor.device属性检查
  4. 形状断言:在关键步骤插入维度检查
assert src_mask.dim() == 4, "源掩码应为4维张量"

该实现完整保留了原始Transformer的核心设计思想,通过模块化封装实现了高效的序列到序列建模。开发者在实际应用中可根据任务需求调整编码器/解码器层数、注意力头数等关键参数,平衡模型容量与计算效率。


原项目代码+注释(附)

八、完整代码实现
(此处完整保留原始代码,按功能模块添加注释说明)

"""
@author : Hyunwoong
@when : 2019-12-18
@homepage : https://github.com/gusdnd852
"""
import torch
from torch import nn

# 从其他模块导入Encoder和Decoder类
from models.model.decoder import Decoder
from models.model.encoder import Encoder


class Transformer(nn.Module):
    # 初始化方法(参数配置中心)
    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, 
                enc_voc_size, dec_voc_size, d_model, n_head,
                max_len, ffn_hidden, n_layers, drop_prob, device):
        super().__init__()
        # 关键索引注册(实现序列控制)
        self.src_pad_idx = src_pad_idx  # 源填充索引
        self.trg_pad_idx = trg_pad_idx  # 目标填充索引
        self.trg_sos_idx = trg_sos_idx  # 起始符索引
        self.device = device  # 硬件环境绑定

        # 编码器构造(特征提取核心)
        self.encoder = Encoder(
            d_model=d_model,  # 隐层维度
            n_head=n_head,    # 并行注意力头数
            max_len=max_len,  # 最大序列长度
            ffn_hidden=ffn_hidden,  # 前馈层扩展维度
            enc_voc_size=enc_voc_size,  # 源词汇表大小
            drop_prob=drop_prob,  # 正则化比率
            n_layers=n_layers,  # 堆叠层数
            device=device)  # 设备一致性

        # 解码器构造(自回归生成核心)
        self.decoder = Decoder(
            d_model=d_model,
            n_head=n_head,
            max_len=max_len,
            ffn_hidden=ffn_hidden,
            dec_voc_size=dec_voc_size,  # 目标词汇表大小
            drop_prob=drop_prob,
            n_layers=n_layers,
            device=device)

    # 前向传播管道(计算图入口)
    def forward(self, src, trg):
        # 动态掩码生成(适配不同输入)
        src_mask = self.make_src_mask(src)  # 源序列掩码
        trg_mask = self.make_trg_mask(trg)  # 目标序列掩码

        # 编码-解码信息流
        enc_src = self.encoder(src, src_mask)  # 上下文编码
        # 通过解码器处理目标序列,同时利用编码器的输出和掩码
        output = self.decoder(trg, enc_src, trg_mask, src_mask)  # 自回归解码
        return output  # 输出概率分布

    # 源掩码生成器(处理变长输入)
    def make_src_mask(self, src):
        # 布尔掩码生成(True表示有效token)
        mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        # 维度扩展:batch_size × 1 × 1 × seq_len
        return mask  # 适配多头注意力计算

    # 目标掩码生成器(防止信息泄露)
    def make_trg_mask(self, trg):
        # 填充掩码(基础有效性判断)
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)
        # 序列掩码(限制未来信息)
        trg_len = trg.shape[1]
        # 生成一个下三角矩阵作为序列掩码,确保解码时只能看到当前位置及之前的信息
        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor).to(self.device)
        # 组合掩码(逻辑与运算)
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask

你可能感兴趣的:(Transformer代码剖析,transformer,pytorch,深度学习,embedding,人工智能,python)