后面的一堆东西,用过的都知道好像都没啥用是不是,那就别管它了。
其实一般就是前用两个参数
第2个参数 embedding_dim 就是嵌入向量的维度,即用embedding_dim值的维数来表示一个基本单位。
第1个参数 num_embeddings 就是生成num_embeddings个嵌入向量。
比如下面的代码,生成6个嵌入向量,每个嵌入向量的维度是2,
然后通过 .weight 看到生成的都是随机数:
embedding = nn.Embedding(6, 2)
print(embedding.weight)
Parameter containing:
tensor([[-0.7965, -0.2459],
[ 1.1508, -0.7320],
[ 0.0154, -0.2846],
[ 0.2236, -0.0293],
[ 0.5198, -1.1245],
[ 0.7478, 0.3544]], requires_grad=True)
我们知道嵌入向量的每个值是随机数,那么这些随机数服从什么分布呢?
embedding = nn.Embedding(1, 3000)
m = embedding.weight[0]
print('均值:',torch.mean(m),'方差:',torch.var(m))
from scipy.stats import shapiro
print('结果:',shapiro(m.detach().numpy()))
均值: tensor(0.0049, grad_fn=<MeanBackward0>) 方差: tensor(0.9697, grad_fn=<VarBackward>)
结果: ShapiroResult(statistic=0.9994555711746216, pvalue=0.5714033246040344)
通过上面的代码我们可以看到数据的均值为0,方差为1。并且同时scipy.stats.shapiro可以检测数据是否符合正太分布,输出结果中第一个为统计量,统计量越接近1越代表数据和正态分布拟合的好,可以看到第一个值接近于1。所以我们可以知道:嵌入向量中的值是服从标准正态分布的。
那么我们是如何使用这些向量表示我们的词呢?
是索引。
看下面的代码:
embedding = nn.Embedding(6, 2)
x=torch.rand((3,2))*10
embedding(x)
这样的结果是报错,因为float不能作为索引:
要求输入必须是整数型。
embedding = nn.Embedding(6, 2)
print('embedding.weight:',embedding.weight)
t = torch.ones((100,100)).to(int)
print('embedding(t):',embedding(t))
embedding.weight: Parameter containing:
tensor([[ 0.2666, 1.6938],
[-0.2191, 0.2008],
[ 0.0308, 0.3599],
[ 0.6266, -1.1199],
[ 2.0277, -0.3861],
[-0.1880, 0.6900]], requires_grad=True)
embedding(t): tensor([[[-0.2191, 0.2008],
[-0.2191, 0.2008],
[-0.2191, 0.2008],
...,
[-0.2191, 0.2008],
[-0.2191, 0.2008],
[-0.2191, 0.2008]],
[[-0.2191, 0.2008],
[-0.2191, 0.2008],
[-0.2191, 0.2008],
输出结果很长,所以这里只展示一部分。我们先生成了一个 t , t 相当于我们的文本了,一共有100个句子,每个句子100个词。而我们对 t 的编码结果是每个词都是[-0.2191, 0.2008]
,我们发现这和embedding.weight[1]
是一样的。因为Embedding就是根据索引值编码的。Embedding生成的每个嵌入向量都通过索引存储起来,以供后期编码使用。
又比如:
embedding = nn.Embedding(6, 2)
print('embedding.weight:',embedding.weight)
x=(torch.rand((3,2))*10).to(int)
print('x:',x)
print('embedding(x):',embedding(x))
embedding.weight: Parameter containing:
tensor([[-0.8326, -0.4983],
[ 2.1820, -0.1443],
[-0.4806, -0.1235],
[ 0.3978, -0.2895],
[ 1.2748, -1.1211],
[-0.4095, -0.5246]], requires_grad=True)
x: tensor([[2, 6],
[4, 5],
[4, 7]])
以及报错:
我们可以看到报错信息索引超过了范围,因为Embedding的最大索引为5,而x中的6和7超过了这个值,所以就报错了。
并且nn.Embedding()是可以训练的,参考这个:
torch.nn.Embedding是否有梯度,是否会被训练
注:其实这个和nn.Parameter()的工作原理是一样的,都是先生成一些嵌入向量,然后调用它对自己的输入进行编码时,通过索引来取值编码。