Beam search是一种动态规划算法,能够极大的减少搜索空间,增加搜索效率,并且其误差在可接受范围内,常被用于Sequence to Sequence模型,CTC解码等应用中
对于 T × N T\times N T×N的时间序列,如果我们要遍历所有可能能,则其所需的时间复杂度为 O ( N + N 2 + N 3 + . . . + N T ) \mathcal{O}(N+N^2+N^3+...+N^T) O(N+N2+N3+...+NT),在每一时间节点,所需遍历的节点数呈指数增加。对于Viterbi算法来说,时间复杂度为 O ( N + ( T − 1 ) N 2 ) \mathcal{O}(N+(T-1)N^2) O(N+(T−1)N2),在每个时间节点输入为N个best节点,需要比较的次数为 N 2 N^2 N2,然而这个时间复杂度还是太高。在N比较大的情况下,Beam Search为更好的选择,其时间复杂度为 O ( N + ( T − 1 ) ∗ b e a m s i z e ∗ N ) \mathcal{O}(N+(T-1)*beamsize*N) O(N+(T−1)∗beamsize∗N),每个时间节点的输入为beamsize个best节点,需要比较的次数为 b e a m s i z e ∗ N beamsize*N beamsize∗N
如上图所示,常规的beam search在每个时间节点,对输入的每个节点比较N次,并从 b e a m s i z e ∗ N beamsize*N beamsize∗N个比较结果中,选择 b e a m s i z e beamsize beamsize个结果作为下一时间节点的输入,其python的简单实现如下
import numpy as np
import math
def beam_search(nodes, topk=1):
# log-likelihood可以相加
paths = {'A':math.log(nodes[0]['A']), 'B': math.log(nodes[0]['B']), 'C':math.log(nodes[0]['C'])}
calculations = []
for l in range(1, len(nodes)):
# 拷贝当前路径
paths_ = paths.copy()
paths = {}
nows = {}
cur_cal = 0
for i in nodes[l].keys():
# 计算到达节点i的所有路径
for j in paths_.keys():
nows[j+i] = paths_[j]+math.log(nodes[l][i])
cur_cal += 1
calculations.append(cur_cal)
# 选择topk条路径
indices = np.argpartition(list(nows.values()), -topk)[-topk:]
# 保存topk路径
for k in indices:
paths[list(nows.keys())[k]] = list(nows.values())[k]
print(f'calculation number {calculations}')
return paths
nodes = [{'A':0.1, 'B':0.3, 'C':0.6}, {'A':0.2, 'B':0.4, 'C':0.4}, {'A':0.6, 'B':0.2, 'C':0.2},
{'A': 0.3, 'B': 0.3, 'C': 0.4}]
print(beam_search(nodes, topk=2))
输出结果:
calculation number [9, 6, 6]
{'CBAA': -3.1419147837320724, 'CBAC': -2.854232711280291, 'CCAC': -2.854232711280291}
我们可以看到,在 N = 3 N=3 N=3, b e a m s i z e = 2 beamsize=2 beamsize=2的情况下,每个节点的比较次数为6。
在CTC算法中,由于添加了blank以及重复字符串无blank合并的规则,例如ab
可能aab
,abb
,a blank b
等多种情况的输入,因此ab
的可能性应该为多种情况log概率之和,而不能通过单条beam进行搜索,因此可以采用改进版的prefix beam search,其代码如下
"""
Code from https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
Author: Awni Hannun
CTC decoder in python, 简单例子可能不太效率
用于CTC模型的输出的前缀beam search
更多细节参考
https://distill.pub/2017/ctc/#inference
https://arxiv.org/abs/1408.2873
"""
import numpy as np
import math
import collections
NEG_INF = -float("inf")
def make_new_beam():
fn = lambda: (NEG_INF, NEG_INF)
return collections.defaultdict(fn)
def logsumexp(*args):
"""
Stable log sum exp.
"""
if all(a == NEG_INF for a in args):
return NEG_INF
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max)
for a in args))
return a_max + lsp
def decode(probs, beam_size=100, blank=0):
"""
对给定输出概率进行预测
Arguments:
probs: 输出概率 (e.g. post-softmax) for each
time step. Should be an array of shape (time x output dim).
beam_size (int): Size of the beam to use during inference.
blank (int): Index of the CTC blank label.
Returns the output label sequence and the corresponding negative
log-likelihood estimated by the decoder.
"""
T, S = probs.shape
probs = np.log(probs)
# 在beam中的元素为(prefix, (p_blank, p_no_blank))
# 初始beam为空序列,第一个是前缀,第二个是后接blank的log概率,第三个是后接非blank的log概率
# 我们需要后接blank和后接非blank两种情况,来区分重复字符是否应该被合并,对于后接blank的情况,重复字符就不会被合并
beam = [(tuple(), (0.0, NEG_INF))]
for t in range(T): # 沿时间维度循环
# 存储下一个候选集的预设置字典,每次新的时间节点都会重设
next_beam = make_new_beam()
for s in range(S): # 沿词表维度循环
p = probs[t, s]
# p_b和p_nb分别为在当前时刻下前缀后接blank和非blank的log概率
for prefix, (p_b, p_nb) in beam: # 对beam进行循环
# 如果s为blank,那么前缀不会改变
# 因为后接的是blank,所以只需要更新前缀不变的情况下后接blank的log概率
if s == blank:
n_p_b, n_p_nb = next_beam[prefix]
n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
continue
# 记录前缀最后一个字符,用于判断当前字符与前缀最后一个字符是否相同
end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,) # n_prefix代表next prefix
n_p_b, n_p_nb = next_beam[n_prefix] # n_p_b代表 next probability of blank
# 将新的字符s加到prefix后面并将整体加入到beam中
# 因为后接的是非blank,所以只需要更新后接非blank的log概率
if s != end_t:
n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
else:
# 如果后接s是重复的,那么我们在更新后接非blank的log概率时,
# 不包括上一时刻后接非blank的概率。CTC算法会合并没有用blank分隔的重复字符
n_p_nb = logsumexp(n_p_nb, p_b + p)
# 这里是加入语言模型分数的好地方
next_beam[n_prefix] = (n_p_b, n_p_nb)
# 这是合并的情况,如果s重复出现了,前缀也不会改变,我们也更新前缀不变的情况下后接非blank的log概率
if s == end_t:
n_p_b, n_p_nb = next_beam[prefix]
n_p_nb = logsumexp(n_p_nb, p_nb + p)
next_beam[prefix] = (n_p_b, n_p_nb)
# 在进入下一时间步之前,排序并裁剪beam
beam = sorted(next_beam.items(),
key=lambda x: logsumexp(*x[1]),
reverse=True)
beam = beam[:beam_size]
best = beam[0]
return best[0], -logsumexp(*best[1])
if __name__ == "__main__":
np.random.seed(3)
time = 50
output_dim = 20
probs = np.random.rand(time, output_dim)
probs = probs / np.sum(probs, axis=1, keepdims=True)
labels, score = decode(probs)
print(labels)
print("Score {:.3f}".format(score))
与常规BS不同的地方主要在于, PBS区分了几种情况以及log probability的计算方式
BS根据不同的场景可以有不同的写法,其主要目的在于在每个时间点选择TOPK的路径继续搜索,达到增加搜索效率的目的,在BS的搜索过程中,如果是生成字符串,我们还可以加入语言模型的分数,得到更好的结果:
Y ∗ = l o g P ( Y ∣ X ) + α l o g ( P l m ( Y ) ) + β l e n ( Y ) Y^*=logP(Y|X)+\alpha log(P_{lm}(Y))+\beta len(Y) Y∗=logP(Y∣X)+αlog(Plm(Y))+βlen(Y)
语言模型的加入地方一般为字符串扩增时。
Sequence Modeling With CTC