pytorch之torch.nn.Embedding介绍

pytorch之torch.nn.Embedding介绍

  • 简介
  • 参数详解
  • 代码

简介

词嵌入层,该模块通常用于存储单词嵌入并使用索引检索它们。模块的输入是索引列表,而输出是相应的词嵌入。

参数详解

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx: Optional[int] = None, max_norm: Optional[float] = None, norm_type: float = 2.0, scale_grad_by_freq = False, sparse=False, _weight: Optional[torch.Tensor] = None)

  • num_embeddings(int):嵌入字典的大小

  • embedding_dim(int):每个嵌入向量的大小

  • padding_idx(int,optional):如果给定,则在padding_idx 遇到索引时,将输出嵌入矢量(初始化为零)。

  • max_norm(float,optional):如果给定,则将范数大于的每个嵌入向量max_norm 重新归一化为norm max_norm。

  • norm_type(float,optional):为该max_norm选项计算的p范数的p ,默认2。

  • scale_grad_by_freq(boolean ,可选):默认False。

  • sparse(bool,可选):默认为False。

代码

import torch


embedding1 = torch.nn.Embedding(10, 3)
input = torch.LongTensor([1, 2, 3, 4])
print(embedding1(input))

embedding2 = torch.nn.Embedding(10, 3, padding_idx=1)
print(embedding2(input))

显示结果:
tensor([[ 1.5469, -0.1033, -0.0621],
[ 1.0687, -0.2366, -0.6160],
[-0.7858, 0.9059, 0.4532],
[-0.3800, 1.0814, -0.9565]], grad_fn=)

tensor([[ 0.0000, 0.0000, 0.0000],
[ 0.4901, 0.0957, 0.5926],
[ 0.0638, -0.3168, -1.0536],
[ 0.1272, 1.1345, -0.5279]], grad_fn=)

你可能感兴趣的:(pytorch)