Beam search 算法的通俗理解

Beam search 算法在文本生成中用得比较多,用于选择较优的结果(可能并不是最优的)。接下来将以seq2seq机器翻译为例来说明这个Beam search的算法思想。
在机器翻译中,beam search算法在测试的时候用的,因为在训练过程中,每一个decoder的输出是有与之对应的正确答案做参照,也就不需要beam search去加大输出的准确率。
有如下从中文到英语的翻译:
中文:

我 爱 学习,学习 使 我 快乐

英语:

I love learning, learning makes me happy

在这个测试中,中文的词汇表是{我,爱,学习,使,快乐},长度为5。英语的词汇表是{I, love, learning, make, me, happy}(全部转化为小写),长度为6。那么首先使用seq2seq中的编码器对中文序列(记这个中文序列为 X X X)进行编码,得到语义向量 C C C
Beam search 算法的通俗理解_第1张图片
得到语义向量 C C C后,进入解码阶段,依次翻译成目标语言。在正式解码之前,有一个参数需要设置,那就是beam search中的beam size,这个参数就相当于top-k中的k,选择前k个最有可能的结果。在本例中,我们选择beam size=3。

来看解码器的第一个输出 y 1 y_1 y1,在给定语义向量 C C C的情况下,首先选择英语词汇表中最有可能k个单词,也就是依次选择条件概率 P ( y 1 ∣ C ) P(y_1|C) P(y1C)前3大对应的单词,比如这里概率最大的前三个单词依次是 I I I l e a r n i n g learning learning h a p p y happy happy

接着生成第二个输出 y 2 y_2 y2,在这个时候我们得到了那些东西呢,首先我们得到了编码阶段的语义向量 C C C,还有第一个输出 y 1 y_1 y1。此时有个问题, y 1 y_1 y1有三个,怎么作为这一时刻的输入呢(解码阶段需要将前一时刻的输出作为当前时刻的输入),答案就是都试下,具体做法是:

  • 确定 I I I为第一时刻的输出,将其作为第二时刻的输入,得到在已知 ( C , I ) (C, I) (C,I)的条件下,各个单词作为该时刻输出的条件概率 P ( y 2 ∣ C , I ) P(y_2|C,I) P(y2C,I),有6个组合,每个组合的概率为 P ( I ∣ C ) P ( y 2 ∣ C , I ) P(I|C)P(y_2|C, I) P(IC)P(y2C,I)
  • 确定 l e a r n i n g learning learning为第一时刻的输出,将其作为第二时刻的输入,得到该条件下,词汇表中各个单词作为该时刻输出的条件概率 P ( y 2 ∣ C , l e a r n i n g ) P(y_2|C, learning) P(y2C,learning),这里同样有6种组合;
  • 确定 h a p p y happy happy为第一时刻的输出,将其作为第二时刻的输入,得到该条件下各个单词作为输出的条件概率 P ( y 2 ∣ C , h a p p y ) P(y_2|C, happy) P(y2C,happy),得到6种组合,概率的计算方式和前面一样。

这样就得到了18个组合,每一种组合对应一个概率值 P ( y 1 ∣ C ) P ( y 2 ∣ C , y 1 ) P(y_1|C)P(y_2|C, y_1) P(y1C)P(y2C,y1),接着在这18个组合中选择概率值top3的那三种组合,假设得到 I l o v e I love Ilove I h a p p y I happy Ihappy l e a r n i n g m a k e learning make learningmake
接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:
P ( Y ∣ C ) = P ( y 1 ∣ C ) P ( y 2 ∣ C , y 1 ) , . . . , P ( y 6 ∣ C , y 1 , y 2 , y 3 , y 4 , y 5 ) P(Y|C)=P(y_1|C)P(y_2|C,y_1),...,P(y_6|C,y_1,y_2,y_3,y_4,y_5) P(YC)=P(y1C)P(y2C,y1),...,P(y6C,y1,y2,y3,y4,y5)
以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。

Beam search算法也有许多改进的地方,根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数(记生成序列的单词数为 N N N),这样比较的就是每个单词的平均概率大小。
此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。

参考文献:
吴恩达-《序列模型》课程

你可能感兴趣的:(常见算法,搜索算法)