Pytorch学习-Embedding

官方文档的解释:

embedding方法为一个存储了固定字典和大小的简单的查找表

通常利用indices去检索存储的word embedding,输入是一系列indices,输出对应的word embeddings

Pytorch学习-Embedding_第1张图片

from_pretrained method

根据输入的2d FloatTensor 格式的embeddings参数 生成embedding实例,其中embeddings参数第一维度为词典大小num_embeddings , 第二维度为 词典中每个词的维度embedding_dim
Pytorch学习-Embedding_第2张图片

具体例子

首先设定embeddings参数为(10,3),代表词典中有十个词, 每个词的维度embedding_dim 为3。看到LongTensor 1对应的embedding为 [ 0.7417, 0.2151, -2.0011] 维度为3,扩展到多个LongTensor,输出的embedding的每一行对应输入的每个LongTensor
Pytorch学习-Embedding_第3张图片

需要注意词典大小不能小于输入的LongTensor的个数,否则会报数组越界的错误
Pytorch学习-Embedding_第4张图片

  • pretrained方法
    embeddings参数为(10,3)保持不变,用from_pretrained方法传入预先设定的weight,代表了每个词的embedding,之后的索引操作和上述相同。
    Pytorch学习-Embedding_第5张图片

你可能感兴趣的:(pytorch,深度学习,pytorch)