目录
PyTorch子模块Sparse functions详解
embedding
参数
输出形状
示例
带有 padding_idx 的示例
embedding_bag
参数
输出形状
示例
使用 padding_idx 的示例
one_hot
参数
返回
示例
总结
torch.nn.functional.embedding
是 PyTorch 中的一个函数,用于从固定字典和大小的简单查找表中检索嵌入(embeddings)。这个函数通常用于使用索引检索词嵌入。其输入是一个索引列表和嵌入矩阵,输出是相应的词嵌入。
input
(LongTensor):包含嵌入矩阵索引的张量。weight
(Tensor):嵌入矩阵,行数等于最大可能索引 + 1,列数等于嵌入大小。padding_idx
(int, 可选):如果指定,padding_idx 处的条目不会对梯度产生贡献;因此,在训练期间,padding_idx 处的嵌入向量不会更新,即它保持为固定的“填充”。max_norm
(float, 可选):如果给定,每个嵌入向量的范数大于 max_norm 时将被重新规范化为 max_norm。注意:这将就地修改 weight。norm_type
(float, 可选):用于计算 max_norm 选项的 p-范数的 p。默认为 2。scale_grad_by_freq
(bool, 可选):如果给定,将按照小批量中单词频率的倒数来缩放梯度。默认为 False。sparse
(bool, 可选):如果为 True,weight 相对于的梯度将是一个稀疏张量。有关稀疏梯度的更多细节,请参阅 torch.nn.Embedding
。import torch
import torch.nn.functional as F
# 两个样本的批次,每个样本有 4 个索引
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
# 包含 10 个大小为 3 的张量的嵌入矩阵
embedding_matrix = torch.rand(10, 3)
# 使用 F.embedding 获取嵌入
output = F.embedding(input, embedding_matrix)
此例中,input
包含两个样本的索引,embedding_matrix
是一个随机初始化的嵌入矩阵。F.embedding
函数返回了这些索引对应的嵌入向量。
weights = torch.rand(10, 3)
weights[0, :].zero_() # 将索引 0 的嵌入向量设置为零
embedding_matrix = weights
input = torch.tensor([[0, 2, 0, 5]])
# 使用 padding_idx
output = F.embedding(input, embedding_matrix, padding_idx=0)
在这个例子中,索引 0 被用作填充索引(padding_idx),因此它的嵌入向量在训练过程中不会更新,且初始化为零。这对于处理可变长度的序列数据特别有用,其中某些位置可能需要被忽略。
torch.nn.functional.embedding_bag
是 PyTorch 中的一个函数,它计算嵌入向量的“包”(bag)的和、均值或最大值,而无需实例化中间的嵌入向量。这个函数对于处理文本数据特别有用,特别是在处理变长序列或者需要聚合嵌入表示时。
input
(LongTensor):包含嵌入矩阵索引的包的张量。weight
(Tensor):嵌入矩阵,行数等于最大可能索引 + 1,列数等于嵌入大小。offsets
(LongTensor, 可选):仅当输入是1D时使用。offsets确定每个包(序列)在输入中的起始索引位置。max_norm
(float, 可选):如果给定,每个嵌入向量的范数大于 max_norm 时将被重新规范化为 max_norm。注意:这将就地修改 weight。norm_type
(float, 可选):用于计算 max_norm 选项的 p-范数的 p。默认为 2。scale_grad_by_freq
(bool, 可选):如果给定,将按照小批量中单词频率的倒数来缩放梯度。默认为 False。mode
(str, 可选):可选 "sum", "mean" 或 "max"。指定聚合包的方式。默认为 "mean"。sparse
(bool, 可选):如果为 True,weight 相对于的梯度将是一个稀疏张量。per_sample_weights
(Tensor, 可选):浮点/双精度权重的张量,或 None 表示所有权重应视为 1。如果指定,per_sample_weights 的形状必须与 input 完全相同。include_last_offset
(bool, 可选):如果为 True,offsets 的大小等于包的数量 + 1。最后一个元素是输入的大小,或最后一个包(序列)的结束索引位置。padding_idx
(int, 可选):如果指定,padding_idx 处的条目不会对梯度产生贡献;因此,训练期间不会更新 padding_idx 处的嵌入向量,即它保持为固定的“填充”。输入:LongTensor,和可选的 offsets (LongTensor)
权重 (Tensor):可学习的模块权重,形状为 (num_embeddings, embedding_dim)。
per_sample_weights (Tensor, 可选):具有与输入相同形状的张量。
输出:聚合后的嵌入值,形状为 (B, embedding_dim)。
import torch
import torch.nn.functional as F
# 包含 10 个大小为 3 的张量的嵌入矩阵
embedding_matrix = torch.rand(10, 3)
# 一个样本的批次,包含 4 个索引
input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
# 指定每个包(序列)的开始位置
offsets = torch.tensor([0, 4])
# 使用 F.embedding_bag 获取嵌入
output = F.embedding_bag(input, embedding_matrix, offsets)
在此示例中,input
包含 8 个索引,表示两个序列(或“包”),由 offsets
指定其开始位置。embedding_matrix
是一个随机初始化的嵌入矩阵。F.embedding_bag
函数将返回这些索引对应的嵌入向量的聚合(默认为均值)。
embedding_matrix = torch.rand(10, 3)
input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])
# 使用 padding_idx
output = F.embedding_bag(input, embedding_matrix, offsets, padding_idx=2, mode='sum')
在这个例子中,索引 2 被用作填充索引(padding_idx
),这意味着在聚合时,索引为 2 的嵌入向量不会对结果产生贡献,并且在训练过程中不会更新这个嵌入向量。这在处理变长序列时特别有用,其中某些位置可能需要被忽略。示例中使用的 mode='sum'
表示对每个包内的嵌入向量进行求和操作。
embedding_bag
函数的这种处理方式比标准的 embedding
函数更高效,因为它避免了创建大量中间嵌入向量的步骤,特别是在处理包含许多短序列的大批量数据时。此外,它允许不同长度的序列共存于同一个批次中,这在处理自然语言处理任务时尤其有价值。
torch.nn.functional.one_hot
是 PyTorch 中的一个函数,它用于将长整型张量(LongTensor)转换为一种称为 one-hot 编码的形式。在 one-hot 编码中,每个类别的索引将被转换为一个向量,该向量中除了对应类别索引处的值为 1 外,其余位置均为 0。
tensor
(LongTensor):任何形状的类别值。num_classes
(int):总类别数。如果设置为 -1,则类别数将推断为输入张量中最大类别值加1。import torch
import torch.nn.functional as F
# 示例 1:基本用法
output = F.one_hot(torch.arange(0, 5) % 3)
print(output)
# tensor([[1, 0, 0],
# [0, 1, 0],
# [0, 0, 1],
# [1, 0, 0],
# [0, 1, 0]])
# 示例 2:指定 num_classes
output = F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
print(output)
# tensor([[1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0],
# [1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0]])
# 示例 3:使用多维张量
output = F.one_hot(torch.arange(0, 6).view(3, 2) % 3)
print(output)
# tensor([[[1, 0, 0],
# [0, 1, 0]],
# [[0, 0, 1],
# [1, 0, 0]],
# [[0, 1, 0],
# [0, 0, 1]]])
在这些示例中,torch.arange(0, 5) % 3
生成一个周期为 3 的序列,然后 F.one_hot
将这些值转换为 one-hot 编码形式。在第二个示例中,通过指定 num_classes=5
,可以控制 one-hot 编码向量的长度。在第三个示例中,展示了如何对多维张量进行 one-hot 编码。
本篇博客探讨了 PyTorch 框架中几个关键的稀疏函数,包括 embedding
、embedding_bag
和 one_hot
。这些函数在处理自然语言处理(NLP)任务和其他需要高效、灵活处理大量类别或序列数据的应用中至关重要。embedding
函数用于从预定义的嵌入矩阵中检索指定索引的嵌入向量,支持自定义嵌入矩阵大小、填充索引和范数限制。embedding_bag
提供了一种高效的方法来处理变长序列,通过聚合(如求和、均值或最大值)嵌入向量,而无需单独处理每个序列。one_hot
函数则用于将类别标签转换为 one-hot 编码形式,适用于处理分类任务中的标签数据。这些函数的灵活性和高效性使它们成为深度学习模型设计和实现中的重要工具。