tf.nn.sampled_softmax_loss的计算原理

可以看作是Softmax计算的另一种加速。 (注意区分下Word2Vec的huffman softmax)

在正常每个time-step输出的时候,我们使用的是直接在整个字典大小上softmax:


其中的分母Z是所有的词语的响应和:

Z就是每个词响应的累加,可见Z的计算量太大了,如果每个time-step我们都需要计算每个词的响应并且累加的话,耗时很长。

所以我们进一步分析发现,式子(6)对于目标词的梯度可以写成(左右取log,所以分子分母变加减法):

其中E定义为某个词语的能量(其实也是这个词的响应)

Z对目标词的梯度就是式子8中的负项
不看负号,这个项实际是能量E在字典V上的的梯度期望啊(每个词y_k的概率×它的梯度),也就是:

作者用importance sampling的方法,找到只跟目标词相关的词语。
我们假设预定义一个分布Q,并且有一个从Q采样的字典V',用Q逼近式子9

如何得到Q分布?就是把语料进行partition,每个partition当字典大小到达阈值thres的时候,停止增加词语。比如第i个partition的字典是V'_i。

这个时候的Q_i定义为:


代入式子10-11,Q项可以被抵消,所以最后的p项(式子8中的负项的p)为

(12)

可以看成,训练的时候的softmax就是在一个小的set V'上计算的

而我猜想V'是有正相关和负相关词语的。所以只在小范围内optimize,让网络向目标的词汇更快地迈进,也是合理的。相关的词会直接update到W和b权重里面。

在测试的时候,恢复在整个字典V上进行softmax。


引用

  1. 论文原文
  2. TF文档

你可能感兴趣的:(tf.nn.sampled_softmax_loss的计算原理)