上一节介绍了如何训练输入和输出均为不定长序列的编码器—解码器。本节我们介绍如何使用编码器—解码器来预测不定长的序列。
上一节里已经提到,在准备训练数据集时,我们通常会在样本的输入序列和输出序列后面分别附上一个特殊符号"
让我们先来看一个简单的解决方案:贪婪搜索(greedy search)。对于输出序列任一时间步 t ′ t' t′,我们从 ∣ Y ∣ |\mathcal{Y}| ∣Y∣个词中搜索出条件概率最大的词
y t ′ = argmax y ∈ Y P ( y ∣ y 1 , … , y t ′ − 1 , c ) y _ { t ^ { \prime } } = \underset { y \in \mathcal { Y } } { \operatorname { argmax } } P \left( y | y _ { 1 } , \ldots , y _ { t ^ { \prime } - 1 } , c \right) yt′=y∈YargmaxP(y∣y1,…,yt′−1,c)
作为输出。一旦搜索出"
我们在描述解码器时提到,基于输入序列生成输出序列的条件概率是 ∏ t ′ = 1 T ′ P ( y t ′ ∣ y 1 , … , y t ′ − 1 , c ) \prod_{t'=1}^{T'} P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}) ∏t′=1T′P(yt′∣y1,…,yt′−1,c)。我们将该条件概率最大的输出序列称为最优输出序列。而贪婪搜索的主要问题是不能保证得到最优输出序列。
下面来看一个例子。假设输出词典里面有“A”“B”“C”和“
接下来,观察图10.10演示的例子。与图10.9中不同,图10.10在时间步2中选取了条件概率第二大的词“C”。由于时间步3所基于的时间步1和2的输出子序列由图10.9中的“A”“B”变为了图10.10中的“A”“C”,图10.10中时间步3生成各个词的条件概率发生了变化。我们选取条件概率最大的词“B”。此时时间步4所基于的前3个时间步的输出子序列为“A”“C”“B”,与图10.9中的“A”“B”“C”不同。因此,图10.10中时间步4生成各个词的条件概率也与图10.9中的不同。我们发现,此时的输出序列“A”“C”“B”“
如果目标是得到最优输出序列,我们可以考虑穷举搜索(exhaustive search):穷举所有可能的输出序列,输出条件概率最大的序列。
虽然穷举搜索可以得到最优输出序列,但它的计算开销 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|^{T'}) O(∣Y∣T′)很容易过大。例如,当 ∣ Y ∣ = 10000 |\mathcal{Y}|=10000 ∣Y∣=10000且 T ′ = 10 T'=10 T′=10时,我们将评估 1000 0 10 = 1 0 40 10000^{10} = 10^{40} 1000010=1040个序列:这几乎不可能完成。而贪婪搜索的计算开销是 O ( ∣ Y ∣ T ′ ) \mathcal{O}(\left|\mathcal{Y}\right|T') O(∣Y∣T′),通常显著小于穷举搜索的计算开销。例如,当 ∣ Y ∣ = 10000 |\mathcal{Y}|=10000 ∣Y∣=10000且 T ′ = 10 T'=10 T′=10时,我们只需评估 10000 × 10 = 1 0 5 10000\times10=10^5 10000×10=105个序列。
束搜索(beam search)是对贪婪搜索的一个改进算法。它有一个束宽(beam size)超参数。我们将它设为 k k k。在时间步1时,选取当前时间步条件概率最大的 k k k个词,分别组成 k k k个候选输出序列的首词。在之后的每个时间步,基于上个时间步的 k k k个候选输出序列,从 k ∣ Y ∣ k\left|\mathcal{Y}\right| k∣Y∣个可能的输出序列中选取条件概率最大的 k k k个,作为该时间步的候选输出序列。最终,我们从各个时间步的候选输出序列中筛选出包含特殊符号“
图10.11通过一个例子演示了束搜索的过程。假设输出序列的词典中只包含5个元素,即 Y = { A , B , C , D , E } \mathcal{Y} = \{A, B, C, D, E\} Y={A,B,C,D,E},且其中一个为特殊符号“
在最终候选输出序列的集合中,我们取以下分数最高的序列作为输出序列:
1 L α log P ( y 1 , … , y L ) = 1 L α ∑ t ′ = 1 L log P ( y t ′ ∣ y 1 , … , y t ′ − 1 , c ) , \frac{1}{L^\alpha} \log P(y_1, \ldots, y_{L}) = \frac{1}{L^\alpha} \sum_{t'=1}^L \log P(y_{t'} \mid y_1, \ldots, y_{t'-1}, \boldsymbol{c}), Lα1logP(y1,…,yL)=Lα1t′=1∑LlogP(yt′∣y1,…,yt′−1,c),
其中 L L L为最终候选序列长度, α \alpha α一般可选为0.75。分母上的 L α L^\alpha Lα是为了惩罚较长序列在以上分数中较多的对数相加项。分析可知,束搜索的计算开销为 O ( k ∣ Y ∣ T ′ ) \mathcal{O}(k\left|\mathcal{Y}\right|T') O(k∣Y∣T′)。这介于贪婪搜索和穷举搜索的计算开销之间。此外,贪婪搜索可看作是束宽为1的束搜索。束搜索通过灵活的束宽 k k k来权衡计算开销和搜索质量。
注:本节与原书基本相同,原书传送门