Beam Search与Prefix Beam Search的理解与python实现

引言

Beam search是一种动态规划算法,能够极大的减少搜索空间,增加搜索效率,并且其误差在可接受范围内,常被用于Sequence to Sequence模型,CTC解码等应用中

时间复杂度

对于 T × N T\times N T×N的时间序列,如果我们要遍历所有可能能,则其所需的时间复杂度为 O ( N + N 2 + N 3 + . . . + N T ) \mathcal{O}(N+N^2+N^3+...+N^T) O(N+N2+N3+...+NT),在每一时间节点,所需遍历的节点数呈指数增加。对于Viterbi算法来说,时间复杂度为 O ( N + ( T − 1 ) N 2 ) \mathcal{O}(N+(T-1)N^2) O(N+(T1)N2),在每个时间节点输入为N个best节点,需要比较的次数为 N 2 N^2 N2,然而这个时间复杂度还是太高。在N比较大的情况下,Beam Search为更好的选择,其时间复杂度为 O ( N + ( T − 1 ) ∗ b e a m s i z e ∗ N ) \mathcal{O}(N+(T-1)*beamsize*N) O(N+(T1)beamsizeN),每个时间节点的输入为beamsize个best节点,需要比较的次数为 b e a m s i z e ∗ N beamsize*N beamsizeN

常规Beam Search (BS)

Beam Search与Prefix Beam Search的理解与python实现_第1张图片
如上图所示,常规的beam search在每个时间节点,对输入的每个节点比较N次,并从 b e a m s i z e ∗ N beamsize*N beamsizeN个比较结果中,选择 b e a m s i z e beamsize beamsize个结果作为下一时间节点的输入,其python的简单实现如下

import numpy as np
import math

def beam_search(nodes, topk=1):
    # log-likelihood可以相加
    paths = {'A':math.log(nodes[0]['A']), 'B': math.log(nodes[0]['B']), 'C':math.log(nodes[0]['C'])}
    calculations = []
    for l in range(1, len(nodes)):
        # 拷贝当前路径
        paths_ = paths.copy()
        paths = {}
        nows = {}
        cur_cal = 0
        for i in nodes[l].keys():
            # 计算到达节点i的所有路径
            for j in paths_.keys():
                nows[j+i] = paths_[j]+math.log(nodes[l][i])
                cur_cal += 1
        calculations.append(cur_cal)
        # 选择topk条路径
        indices = np.argpartition(list(nows.values()), -topk)[-topk:]
        # 保存topk路径
        for k in indices:
            paths[list(nows.keys())[k]] = list(nows.values())[k]

    print(f'calculation number {calculations}')
    return paths


nodes = [{'A':0.1, 'B':0.3, 'C':0.6}, {'A':0.2, 'B':0.4, 'C':0.4}, {'A':0.6, 'B':0.2, 'C':0.2},
         {'A': 0.3, 'B': 0.3, 'C': 0.4}]
print(beam_search(nodes, topk=2))

输出结果:
calculation number [9, 6, 6]
{'CBAA': -3.1419147837320724, 'CBAC': -2.854232711280291, 'CCAC': -2.854232711280291}

我们可以看到,在 N = 3 N=3 N=3, b e a m s i z e = 2 beamsize=2 beamsize=2的情况下,每个节点的比较次数为6。

Prefix(前缀)Beam Search (PBS)

在CTC算法中,由于添加了blank以及重复字符串无blank合并的规则,例如ab可能aab,abb,a blank b等多种情况的输入,因此ab的可能性应该为多种情况log概率之和,而不能通过单条beam进行搜索,因此可以采用改进版的prefix beam search,其代码如下

"""
Code from https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
Author: Awni Hannun
CTC decoder in python, 简单例子可能不太效率
用于CTC模型的输出的前缀beam search
更多细节参考
  https://distill.pub/2017/ctc/#inference
  https://arxiv.org/abs/1408.2873
"""

import numpy as np
import math
import collections

NEG_INF = -float("inf")


def make_new_beam():
    fn = lambda: (NEG_INF, NEG_INF)
    return collections.defaultdict(fn)


def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def decode(probs, beam_size=100, blank=0):
    """
    对给定输出概率进行预测
    Arguments:
        probs: 输出概率 (e.g. post-softmax) for each
          time step. Should be an array of shape (time x output dim).
        beam_size (int): Size of the beam to use during inference.
        blank (int): Index of the CTC blank label.
    Returns the output label sequence and the corresponding negative
    log-likelihood estimated by the decoder.
    """
    T, S = probs.shape
    probs = np.log(probs)

    # 在beam中的元素为(prefix, (p_blank, p_no_blank))
    # 初始beam为空序列,第一个是前缀,第二个是后接blank的log概率,第三个是后接非blank的log概率
    # 我们需要后接blank和后接非blank两种情况,来区分重复字符是否应该被合并,对于后接blank的情况,重复字符就不会被合并
    beam = [(tuple(), (0.0, NEG_INF))]

    for t in range(T):  # 沿时间维度循环

        # 存储下一个候选集的预设置字典,每次新的时间节点都会重设
        next_beam = make_new_beam()

        for s in range(S):  # 沿词表维度循环
            p = probs[t, s]

            # p_b和p_nb分别为在当前时刻下前缀后接blank和非blank的log概率
            for prefix, (p_b, p_nb) in beam:  # 对beam进行循环

                # 如果s为blank,那么前缀不会改变
                # 因为后接的是blank,所以只需要更新前缀不变的情况下后接blank的log概率
                if s == blank:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)
                    continue

                # 记录前缀最后一个字符,用于判断当前字符与前缀最后一个字符是否相同
                end_t = prefix[-1] if prefix else None
                n_prefix = prefix + (s,)  # n_prefix代表next prefix
                n_p_b, n_p_nb = next_beam[n_prefix]  # n_p_b代表 next probability of blank
                # 将新的字符s加到prefix后面并将整体加入到beam中
                # 因为后接的是非blank,所以只需要更新后接非blank的log概率
                if s != end_t:
                    n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                else:
                    # 如果后接s是重复的,那么我们在更新后接非blank的log概率时,
                    # 不包括上一时刻后接非blank的概率。CTC算法会合并没有用blank分隔的重复字符
                    n_p_nb = logsumexp(n_p_nb, p_b + p)

                # 这里是加入语言模型分数的好地方
                next_beam[n_prefix] = (n_p_b, n_p_nb)

                # 这是合并的情况,如果s重复出现了,前缀也不会改变,我们也更新前缀不变的情况下后接非blank的log概率
                if s == end_t:
                    n_p_b, n_p_nb = next_beam[prefix]
                    n_p_nb = logsumexp(n_p_nb, p_nb + p)
                    next_beam[prefix] = (n_p_b, n_p_nb)

        # 在进入下一时间步之前,排序并裁剪beam
        beam = sorted(next_beam.items(),
                      key=lambda x: logsumexp(*x[1]),
                      reverse=True)
        beam = beam[:beam_size]

    best = beam[0]
    return best[0], -logsumexp(*best[1])


if __name__ == "__main__":
    np.random.seed(3)

    time = 50
    output_dim = 20

    probs = np.random.rand(time, output_dim)
    probs = probs / np.sum(probs, axis=1, keepdims=True)

    labels, score = decode(probs)
    print(labels)
    print("Score {:.3f}".format(score))

与常规BS不同的地方主要在于, PBS区分了几种情况以及log probability的计算方式

  1. 对于BS来说, l o g l i k e l i h o o d = l o g ( p 1 ) + l o g ( p 2 ) + . . . loglikelihood=log(p1)+log(p2)+... loglikelihood=log(p1)+log(p2)+...,对于PBS来说,由于区分了存在blank和不存在blank的情况,并且其中之一的可能性为0,相加log probability等于负无穷的情况,因此不能直接相加,所以采用了一种稳定的logsumexp的方式来计算loglikelihood
  2. 当前缀后接blank时,前缀不变,更新当前前缀后接blank的log概率:
    n _ p _ b = l o g s u m e x p ( n _ p _ b , p _ b + p , p _ n b + p ) n\_p\_b = logsumexp(n\_p\_b, p\_b + p, p\_nb + p) n_p_b=logsumexp(n_p_b,p_b+p,p_nb+p)
  3. 当前缀后接重复字符且中间没有blank隔开时,前缀也不变,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ n b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_nb + p) n_p_nb=logsumexp(n_p_nb,p_nb+p)
  4. 当前缀后接不同字符时,前缀变化,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ b + p , p _ n b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_b + p, p\_nb + p) n_p_nb=logsumexp(n_p_nb,p_b+p,p_nb+p)
  5. 当前缀后接重复字符,且中间有blank隔开,前缀变化,更新当前前缀后接非blank的log概率:
    n _ p _ n b = l o g s u m e x p ( n _ p _ n b , p _ b + p ) n\_p\_nb = logsumexp(n\_p\_nb, p\_b + p) n_p_nb=logsumexp(n_p_nb,p_b+p)

总结

BS根据不同的场景可以有不同的写法,其主要目的在于在每个时间点选择TOPK的路径继续搜索,达到增加搜索效率的目的,在BS的搜索过程中,如果是生成字符串,我们还可以加入语言模型的分数,得到更好的结果:
Y ∗ = l o g P ( Y ∣ X ) + α l o g ( P l m ( Y ) ) + β l e n ( Y ) Y^*=logP(Y|X)+\alpha log(P_{lm}(Y))+\beta len(Y) Y=logP(YX)+αlog(Plm(Y))+βlen(Y)
语言模型的加入地方一般为字符串扩增时。

参考

Sequence Modeling With CTC

你可能感兴趣的:(algorithm,python,自然语言处理,语音识别)