beamsearch的计算过程和代码实现

Beam search(束搜索)是一种用于生成序列的搜索算法,常用于序列生成任务,例如机器翻译、语音识别和文本生成。它是一种启发式算法,旨在在生成序列时平衡搜索空间的广度和深度。

Beam search使用一个参数称为"beam width"(束宽度)来控制搜索的宽度,即在每个时间步骤选择保留的最有希望的候选项数量。在每个时间步骤,Beam search保留最有希望的K个候选项,其中K是束宽度。

下面是Beam search算法的详细步骤:

  1. 初始化:将初始输入作为序列的起始点,并将其放入候选项列表中。

  2. 生成候选项:对于每个候选项,使用模型(例如神经网络)生成下一个可能的元素或单词。

  3. 扩展候选项:对于每个候选项,将生成的元素添加到当前序列中,并计算相应的分数或概率。这些分数用于评估候选项的好坏。

  4. 剪枝:根据分数或概率对候选项进行排序,并选择当前分数最高的K个候选项,将其保留为下一步的候选项。

  5. 终止条件:如果生成的序列达到了预定的长度,或者满足特定的终止条件(例如遇到了终止标记),则停止搜索。

  6. 重复步骤2至5,直到到达终止条件。

  7. 返回结果:从最终的K个候选项中选择得分最高的序列作为最终的输出。

Beam search的优点是可以在生成序列时保持一定的多样性,因为它保留了多个候选项,并在每个时间步骤维护了一个较小的搜索空间。这有助于避免过于确定性的结果,并提供更多选择的可能性。

然而,Beam search也存在一些限制。它可能会陷入局部最优解,因为它只考虑了当前时间步骤的最有希望的候选项,并没有全局优化。此外,束宽度的选择也会影响结果,较小的束宽度可能会导致搜索空间不足,而较大的束宽度会增加计算成本。

算法关键点:在解码过程中,每次都挑选当前解码字的前k个最大概率的字符,第一轮可以得到k个结果,第二轮可以k2个结果,然后在这k2个结果中选择前k个最大概率的结果。依次类推...

具体步骤:

1.初始化Result列表用来存储每次得到的最大k个概率结果,初始化为[[list(),1]] 1为当前初始化的成绩

2.遍历解码长度S(解码出来S个字),

3.编历Result,用来为每个当前为止的最大k个结果解码出候选集

4.每个解码出的k个结果统一存储在Condidate列表中

5.按照成绩选取前k个作为Result,继续遍历,直到解码出S长度或者

from math import log
from numpy import array
from numpy import argmax

# 集束搜索
def beam_search_decoder(data, k):
	sequences = [[list(), 1.0]]#初始化存储最后结果的列表,存储k个
	# 遍历序列中的每一步
	for row in data:#序列的最大长度
		all_candidates = list()
		# 扩展每个候选项,即解码当前所得序列的下一个字
		for i in range(len(sequences)):
			seq, score = sequences[i]
			for j in range(len(row)):#计算每个词表中的字的成绩
				candidate = [seq + [j], score * -log(row[j])]
				all_candidates.append(candidate)
		# 根据分数排列所有候选项
		ordered = sorted(all_candidates, key=lambda tup:tup[1])
		# 选择k个最有可能的
		sequences = ordered[:k]
	return sequences

# 定义一个由10个单词组成的序列,单词来自于大小为5的词汇表
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1],
		[0.1, 0.2, 0.3, 0.4, 0.5],
		[0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# 解码输出序列
result = beam_search_decoder(data, 3)
# 打印结果

for seq in result:
	print(seq)

你可能感兴趣的:(自然语言处理,自然语言处理,深度学习,人工智能,神经网络)