介绍一下我们常用的嵌入函数torch.nn.functional.embedding
,先看一下参数:torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
。我们常使用的就是前两个参数。input是在词向量矩阵中的索引列表,词向量矩阵,行数为最大可能的索引数+1,列数为词向量的维度。那么具体是什么含义呢?
input
参数是我们想要用向量表示的对象(可以是文本、图等)的索引,weight
储存了我们的向量表示,这个函数的目的就是输出一个索引和向量的对应关系。看一个例子:
import numpy as np
import torch.nn.functional as F
import torch
input = torch.tensor([[1,2,4,5],[4,3,2,8]])
embedding_matrix = torch.tensor(
[[0.7330, 0.9718, 0.9023],
[0.7769, 0.7640, 0.3664],
[0.6036, 0.3873, 0.5681],
[0.8422, 0.6275, 0.5400],
[0.0346, 0.5622, 0.2547],
[0.4926, 0.9282, 0.1762],
[0.0037, 0.5831, 0.4443],
[0.2001, 0.1086, 0.0518],
[0.6574, 0.9185, 0.3451]])
print(F.embedding(input, embedding_matrix))
结果:
tensor([[[0.7769, 0.7640, 0.3664],
[0.6036, 0.3873, 0.5681],
[0.0346, 0.5622, 0.2547],
[0.4926, 0.9282, 0.1762]],
[[0.0346, 0.5622, 0.2547],
[0.8422, 0.6275, 0.5400],
[0.6036, 0.3873, 0.5681],
[0.6574, 0.9185, 0.3451]]])
注意:
input
中第一个元素是1,那么在输出中,对应weight中下标为1的向量,就是weight[1]
,其值是[0.7769, 0.7640, 0.3664]
,其他的元素以此类推。weight
的行数是可能的最大检索数+1的原因是,input
的元素需要访问下标为index的数值。在例子中,最大的所以数为8,故weight
的行数至少为9,当然可以大于9.