torch.nn.Embedding

文章目录

    • 方法介绍
    • 例子

方法介绍

CLASS torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
  • 是一个存储固定字典和大小的 lookup table
  • 常用来存储word embedding,并使用索引来检索他们,输入是一个索引列表,输出是对应的word embeddings

主要参数介绍:

  • num_embeddings: embedding字典大小
  • embedding_dim:每个embedding向量的大小、
  • pddding_idx(optional):如果设置了 padding_idx,其对应的 entries 就不会对梯度有贡献。
    • 因此,在 padding_idx 位置的 embedding 在训练的时候不会更新,是一个固定的 pad
    • 对一个新创建的 Embedding,padding_idx处的embedding向量是全0,也可以用其他值来代替

例子

不含 padding_idx

input = torch.tensor([[0, 1, 2], [1, 2, 3]])
embedding = nn.Embedding(4, 3)
print(embedding(input))
>tensor([[[ 0.5663,  1.6136, -0.7869],
         [ 0.1662,  0.1890,  1.2086],
         [-1.1227, -0.0549, -0.8810]],

        [[ 0.1662,  0.1890,  1.2086],
         [-1.1227, -0.0549, -0.8810],
         [ 1.1553, -0.9264, -2.2916]]], grad_fn=<EmbeddingBackward0>)
  • 可以看到,Embedding 创建了一个 4x3 的矩阵,一共可以表示4个索引,每个表示都是一个包含3个数字的向量

包含 padding_idx

input = torch.tensor([[0, 1, 2], [1, 2, 3]])
embedding = nn.Embedding(4, 3, padding_idx=1)
print(embedding(input))
>tensor([[[-1.8803, -0.3651,  0.0161],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.6608, -0.0528,  1.3493]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 1.6608, -0.0528,  1.3493],
         [ 2.0451,  0.5492,  0.5848]]], grad_fn=<EmbeddingBackward0>)
  • padding_idx=1 的embedding vector为全 0

padding_idx在backward gradient时会被忽略,参考链接

你可能感兴趣的:(PyTorch,人工智能,深度学习)