目录
torch.nn子模块transformer详解
nn.Transformer
Transformer 类描述
Transformer 类的功能和作用
Transformer 类的参数
forward 方法
参数
输出
示例代码
注意事项
nn.TransformerEncoder
TransformerEncoder 类描述
TransformerEncoder 类的功能和作用
TransformerEncoder 类的参数
forward 方法
参数
返回类型
形状
示例代码
nn.TransformerDecoder
TransformerDecoder 类描述
TransformerDecoder 类的功能和作用
TransformerDecoder 类的参数
forward 方法
参数
返回类型
形状
示例代码
nn.TransformerEncoderLayer
TransformerEncoderLayer 类描述
TransformerEncoderLayer 类的功能和作用
TransformerEncoderLayer 类的参数
forward 方法
参数
返回类型
形状
示例代码
nn.TransformerDecoderLayer
TransformerDecoderLayer 类描述
TransformerDecoderLayer 类的功能和作用
TransformerDecoderLayer 类的参数
forward 方法
参数
返回类型
形状
示例代码
总结
torch.nn.Transformer
类是 PyTorch 中实现 Transformer 模型的核心类。基于 2017 年的论文 “Attention Is All You Need”,该类提供了构建 Transformer 模型的完整功能,包括编码器(Encoder)和解码器(Decoder)部分。用户可以根据需要调整各种属性。
forward
方法用于处理带掩码的源/目标序列。
(T, N, E)
或 (N, T, E)
(如果 batch_first=True
),其中 T
是目标序列长度,N
是批次大小,E
是特征数。import torch
import torch.nn as nn
# 创建 Transformer 实例
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
# 输入数据
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
# 前向传播
out = transformer_model(src, tgt)
这段代码展示了如何创建并使用 Transformer 模型。在这个例子中,src
和 tgt
分别是随机生成的编码器和解码器的输入张量。输出 out
是模型的最终输出。
generate_square_subsequent_mask
方法来生成序列的因果掩码。torch.nn.TransformerEncoder
类在 PyTorch 中实现了 Transformer 模型的编码器部分。它是一系列编码器层的堆叠,用户可以通过这个类构建类似于 BERT 的模型。
TransformerEncoderLayer
实例,表示单个编码器层(必需)。forward
方法用于顺序通过编码器层处理输入。
import torch
import torch.nn as nn
# 创建 TransformerEncoderLayer 实例
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
# 创建 TransformerEncoder 实例
transformer_encoder = nn.TransformeEncoder(encoder_layer, num_layers=6)
# 输入数据
src = torch.rand(10, 32, 512) # 随机输入
# 前向传播
out = transformer_encoder(src)
这段代码展示了如何创建并使用 TransformerEncoder
。在这个例子中,src
是随机生成的输入张量,transformer_encoder
是由 6 层编码器层组成的编码器。输出 out
是编码器的最终输出。
torch.nn.TransformerDecoder
类实现了 Transformer 模型的解码器部分。它是由多个解码器层堆叠而成,用于处理编码器的输出并生成最终的输出序列。
TransformerDecoderLayer
实例,表示单个解码器层(必需)。forward
方法用于将输入(及掩码)依次通过解码器层进行处理。
import torch
import torch.nn as nn
# 创建 TransformerDecoderLayer 实例
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
# 创建 TransformerDecoder 实例
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
# 输入数据
memory = torch.rand(10, 32, 512) # 编码器的输出
tgt = torch.rand(20, 32, 512) # 解码器的输入
# 前向传播
out = transformer_decoder(tgt, memory)
这段代码展示了如何创建并使用 TransformerDecoder
。在这个例子中,memory
是编码器的输出,tgt
是解码器的输入。输出 out
是解码器的最终输出。
torch.nn.TransformerEncoderLayer
类构成了 Transformer 编码器的基础单元,每个编码器层包含一个自注意力机制和一个前馈网络。这种标准的编码器层基于论文 "Attention Is All You Need"。
forward
方法用于将输入通过编码器层进行处理。
import torch
import torch.nn as nn
# 创建 TransformerEncoderLayer 实例
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
# 输入数据
src = torch.rand(10, 32, 512) # 随机输入
# 前向传播
out = encoder_layer(src)
或者在 batch_first=True
的情况下:
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
src = torch.rand(32, 10, 512)
out = encoder_layer(src)
这段代码展示了如何创建并使用 TransformerEncoderLayer
。在这个例子中,src
是随机生成的输入张量。输出 out
是编码器层的输出。
torch.nn.TransformerDecoderLayer
类是构成 Transformer 模型解码器的基本单元。这个标准的解码器层基于论文 "Attention Is All You Need"。它由自注意力机制、多头注意力机制和前馈网络组成。
forward
方法用于将输入(及掩码)通过解码器层进行处理。
import torch
import torch.nn as nn
# 创建 TransformerDecoderLayer 实例
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
# 输入数据
memory = torch.rand(10, 32, 512) # 编码器的输出
tgt = torch.rand(20, 32, 512) # 解码器的输入
# 前向传播
out = decoder_layer(tgt, memory)
或者在 batch_first=True
的情况下:
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
memory = torch.rand(32, 10, 512)
tgt = torch.rand(32, 20, 512)
out = decoder_layer(tgt, memory)
这段代码展示了如何创建并使用 TransformerDecoderLayer
。在这个例子中,memory
是编码器的输出,tgt
是解码器的输入。输出 out
是解码器层的输出。
本篇博客深入探讨了 PyTorch 的 torch.nn
子模块中与 Transformer 相关的核心组件。我们详细介绍了 nn.Transformer
及其构成部分 —— 编码器 (nn.TransformerEncoder
) 和解码器 (nn.TransformerDecoder
),以及它们的基础层 —— nn.TransformerEncoderLayer
和 nn.TransformerDecoderLayer
。每个部分的功能、作用、参数配置和实际应用示例都被全面解析。这些组件不仅提供了构建高效、灵活的 NLP 模型的基础,还展示了如何通过自注意力和多头注意力机制来捕捉语言数据中的复杂模式和长期依赖关系。