词嵌入层,该模块通常用于存储单词嵌入并使用索引检索它们。模块的输入是索引列表,而输出是相应的词嵌入。
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq = False, sparse=False, _weight: Optional[torch.Tensor] = None)
num_embeddings(int):嵌入字典的大小
embedding_dim(int):每个嵌入向量的大小
padding_idx(int,optional):如果给定,则在padding_idx 遇到索引时,将输出嵌入矢量(初始化为零)。
max_norm(float,optional):如果给定,则将范数大于的每个嵌入向量max_norm 重新归一化为norm max_norm。
norm_type(float,optional):为该max_norm选项计算的p范数的p ,默认2。
scale_grad_by_freq(boolean ,可选):默认False。
sparse(bool,可选):默认为False。
import torch
embedding1 = torch.nn.Embedding(10, 3)
input = torch.LongTensor([1, 2, 3, 4])
print(embedding1(input))
embedding2 = torch.nn.Embedding(10, 3, padding_idx=1)
print(embedding2(input))
显示结果:
tensor([[ 1.5469, -0.1033, -0.0621],
[ 1.0687, -0.2366, -0.6160],
[-0.7858, 0.9059, 0.4532],
[-0.3800, 1.0814, -0.9565]], grad_fn=)
tensor([[ 0.0000, 0.0000, 0.0000],
[ 0.4901, 0.0957, 0.5926],
[ 0.0638, -0.3168, -1.0536],
[ 0.1272, 1.1345, -0.5279]], grad_fn=)