Sampled Softmax


sampled softmax原论文:On Using Very Large Target Vocabulary for Neural Machine Translation
以及tensorflow关于candidate sampling的文档:candidate sampling


1. 问题背景

在神经机器翻译中,训练的复杂度以及解码的复杂度和词汇表的大小成正比。当输出的词汇表巨大时,传统的softmax由于要计算每一个类的logits就会有问题。在论文Neural Machine Translation by Jointly Learning to Align and Translate 中,带有attention的decoder中权重的公式如下:

其中的 a a a为一个单层的前馈神经网络,根据 α t \alpha_t αt和输入的因状态,我们就可以得到一个context vector c t c_t ct。在decoder的t时刻,输出的目标词汇的概率可以使用如下公式计算:

其中, y t − 1 y_{t-1} yt1是上一个次的输出, z t z_t zt为当前decoder的隐状态, c t c_t ct为context vector。
因为我们输出的是一个概率值,所以(6)式的归一化银子 Z Z Z的计算就需要将词汇表当中的logits都计算一遍,这个代价是很大的。
基于此,作者提出了一种采样的方法,使得我们在训练的时候,输出为原来输出的一个子集。(关于其它的解决方法,作者也有提,感兴趣的可以看原文,本篇博客只关注Sampled Softmax)

2. 解决方法

上面已经说过,计算归一化的因子 Z Z Z,因为所用的词太多造成复杂度的上升,那么原文的方法就是使用一个子集 V ′ V' V来近似的计算出 Z Z Z, 假设我们现在已经知道的这个子集,那么之前计算输出的概率公式就为:

Sampled Softmax_第1张图片

好了,那么 V ′ V' V怎么取?

我们看看tensorflow中的文档吧: https://www.tensorflow.org/extras/candidate_sampling.pdf
对于Sampled Softmax的每一个训练样例 ( x i , { t i } ) (x_i,\{t_i\}) (xi,{ti}),我们根据采样函数 Q ( y ∣ x ) Q(y|x) Q(yx),从所有的输出集合中挑选一个小的子集 S i S_{i} Si。要求选择子集的函数和具体的训练样本无关。假设full softmax的输出全集为 L L L, 那么在给定 x i x_i xi的情况下,根据分布 Q Q Q L L L中抽取的子集似然函数为:

然后我们生成一个包含 S i S_i Si和训练目标类的候选集合 V V V
V ′ = S i ∪ t i V'=S_i \cup{t_i} V=Siti
之后我们的训练目标就是找出样本为 V ′ V' V的哪一个类别了。
(感觉还是tensorflow文档说的清楚一点,最初看论文的时候还以为是相当于把一个单词划分到最近的一个类,那样的话,应该会有不同类别的关系啊不然也不make sense啊,但是看tensorflow源码就只有采样的过程啊,笑cry)

3. tensorflow的实现

def sampled_softmax_loss(weights,
                         biases,
                         labels,
                         inputs,
                         num_sampled, # 每一个batch随机选择的类别
                         num_classes, # 所有可能的类别
                         num_true=1, #每一个sample的类别数量
                         sampled_values=None,
                         remove_accidental_hits=True,
                         partition_strategy="mod",
                         name="sampled_softmax_loss"):

tensorflow对于使用的建议:仅仅在训练阶段使用,在inference或者evaluation的时候还是需要使用full softmax。

原文:
This operation is for training only. It is generally an underestimate of
the full softmax loss.
A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference.

这个函数的主体主要调用了另外一个函数:

logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
 

上述函数的返回值shape为:[batch_size, num_true + num_sampled]即可能的class为: S i ∪ t i S_i \cup{t_i} Siti
而这个函数采样集合的代码如下:

sampled_values=candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,# 真实的label
          num_true=num_true,
          num_sampled=num_sampled, # 需要采样的子集大小
          unique=True,
          range_max=num_classes)

而这个函数主要是按照log-uniform distribution(Zipfian distribution)来采样出一个子集,Zipfian distribution
即Zipf法则,以下为Wikipedia关于Zipf’s law的解释:

Zipf’s law states that given some corpus of natural language utterances, the frequency of any word is inversely proportional to its rank in the frequency table.

你可能感兴趣的:(Machine,Learning,Deep,Learning)