简介
skip-gram(跳元模型)
负采样
算法结构
取样方法
本文参考了李沐老师在《动手学深度学习》中的代码,加入了自己的理解,希望能让各位更理解负采样在skip-gram中的应用。
跳元模型是一种用中心词来预测上下文词的方法。但预测上下文词的水平究竟怎样不是我们最关心的事情,我们关心的是其嵌入模型究竟训练的如何。
假设一个训练句子为:The dog is barking.
我们需要设置一个超参数-上下文窗口大小,即以当前词为中心词可以向前向后找多远的上下文词,具体来说,设定上下文窗口大小为2。则训练样本(中心词,上下文词)为
(The,dog) (The,is)
(dog,The) (dog,is) (dog,barking)
(is,dog) (is,The) (is,barking)
(barking,is) (barking,dog)
共10个样本
网络结构如下图所示:
输入为中心词的one-hot向量,向量长度为词汇表的长度V,输出为一个长度为V的向量,每个元素代表词汇表中相应单词是上下文词的预测概率,计算公式为。
代表该词作为上下文词
的嵌入向量,
代表中心词
的嵌入向量。
负采样本身是一种机器学习中用来从负样本中进行采样,从而减少负样本比例的方法。在skip-gram中的应用主要在于削减中的大量负样本的计算成本,假如我们词汇表里有100k个词汇,那么在每个样本输入都要计算一次100k的指数运算,当词汇表内词汇量更大时,该计算量会变得更大。
负采样的思路是:对于每个中心词,选择一个真正的上下文词当作正样本,随机选择k个词典中的词作为负样本,网络想要解决的问题从分辨哪个词是上下文词变为哪个词是真正的上下文词,哪些词是随机选的噪声词。
具体来说网络结构变成了下图:
在前述的基本skip-gram中,对于每个上下文词,我们需要计算,其计算规模和词汇表大小成正比,而负采样后则是和k成正比,论文原文中建议对于小数据集k可以取5-20,对于大数据集,可以取2-5,在大数据集中由于多个样本的共同作用,可以抵消负样本较少的缺点,k就可以小一些。
实现代码如下:
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
v = embed_v(center) #获得中心词的嵌入向量(batch_size,seq_len,嵌入向量维度)
u = embed_u(contexts_and_negatives) #获得上下文词及噪声词的嵌入向量
pred = torch.bmm(v, u.permute(0, 2, 1)) #转置并做矩阵乘法(对每批次做矩阵乘法)
return pred
每个样本被取到作为噪声词的概率和他在所有样本中出现的频率的0.75次方成正比。