torch.nn.functional.embedding的参数理解,尤其是weight

介绍一下我们常用的嵌入函数torch.nn.functional.embedding,先看一下参数:torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)。我们常使用的就是前两个参数。input是在词向量矩阵中的索引列表,词向量矩阵,行数为最大可能的索引数+1,列数为词向量的维度。那么具体是什么含义呢?
input参数是我们想要用向量表示的对象(可以是文本、图等)的索引,weight储存了我们的向量表示,这个函数的目的就是输出一个索引和向量的对应关系。看一个例子:

import numpy as np
import torch.nn.functional as F
import torch

input = torch.tensor([[1,2,4,5],[4,3,2,8]])
embedding_matrix = torch.tensor(
        [[0.7330, 0.9718, 0.9023],
        [0.7769, 0.7640, 0.3664],
        [0.6036, 0.3873, 0.5681],
        [0.8422, 0.6275, 0.5400],
        [0.0346, 0.5622, 0.2547],
        [0.4926, 0.9282, 0.1762],
        [0.0037, 0.5831, 0.4443],
        [0.2001, 0.1086, 0.0518],
        [0.6574, 0.9185, 0.3451]])
print(F.embedding(input, embedding_matrix))

结果:

tensor([[[0.7769, 0.7640, 0.3664],
         [0.6036, 0.3873, 0.5681],
         [0.0346, 0.5622, 0.2547],
         [0.4926, 0.9282, 0.1762]],

        [[0.0346, 0.5622, 0.2547],
         [0.8422, 0.6275, 0.5400],
         [0.6036, 0.3873, 0.5681],
         [0.6574, 0.9185, 0.3451]]])

注意:

  1. 在变量input中第一个元素是1,那么在输出中,对应weight中下标为1的向量,就是weight[1],其值是[0.7769, 0.7640, 0.3664],其他的元素以此类推。
  2. weight的行数是可能的最大检索数+1的原因是,input的元素需要访问下标为index的数值。在例子中,最大的所以数为8,故weight 的行数至少为9,当然可以大于9.

你可能感兴趣的:(pytorch,机器学习,python)