在上一篇文章中我们具体探讨了Beam search的思想以及Beam search的大致工作流程。根据对Beam search的大致流程我们已经清楚了,在这我们来具体实现一下Beam search并应用在我们的seq2seq任务中。
堆是一种特殊的树形数据结构。堆分为大根堆和小根堆两种类型,其中:
堆的应用场景主要是以下两个:
1. 堆排序,完成升序或降序排列;
2. 优先级队列,其中元素按照优先级顺序排列,优先级越低越先出队。在每次插入元素时,堆会自动调整以确保最高(或最低)优先级的元素位于堆的根部。
我们通过构建堆来实现Beam search,主要流程:
1. 构造
2. 取出堆中的数据,开始forward操作,获取当前时间步的输出output、hidden;
3. 从output中选择top k个数据输出,做为下一个时间步的输入(其中Beam width = k);
4. 把下一个时间步需要的输入数据保存在一个新的堆中;
5. 获取新的堆中概率最大的数据,判断数据是否为
class Beam:
def __init__(self):
self.heap = list()
self.beam_width = 3
def add(self, probability, complete, seq, decoder_input, decoder_hidden):
"""
入队
:param probability: 概率乘积
:param complete: 句子是否输出完成
:param seq: 句子 包含token的list
:param decoder_input: 下一个时间步进行解码的输入
:param decoder_hidden: 下一个时间步进行解码的hidden
:return:
"""
heapq.heappush(self.heap, [probability, complete, seq, decoder_input, decoder_hidden])
# 如果数据的个数大于beam_width则弹出
if len(self.heap) > self.beam_width:
# heappop会根据优先级从小到大弹出,所以优先级最大的beam_widt会被保存在堆中
# 当两个元素的probability的优先级相同时,则根据complete优先级弹出
heapq.heappop(self.heap)
def __iter__(self):
return iter(self.heap)
现在我们完成了保存数据的数据结构。
在decoder中我们先定义一个函数处理序列
def _prepar_seq(self, seq):
"""去除seq中的和的token"""
if seq[0].item() == ws.SOS:
seq = seq[1:]
if seq[-1].item() == ws.EOS:
seq = seq[:-1]
seq = [i.item() for i in seq]
return seq
接下来在decoder中使用beam search
def beam_search(self, encoder_outputs, encoder_hidden):
"""使用堆来完成beam search
:param encoder_outputs: [batch_size, seq_len, encoder_hidden_size]
:param encoder_hidden: [1, batch_size, encoder_hidden_size]
"""
batch_size = encoder_hidden.size(1)
# 1. 构造第一次需要的输入数据,保存在堆中
decoder_input = torch.LongTensor([[ws.SOS]*batch_size]).to(device) # [batch_size, 1]
# 要输入的hidden
decoder_hidden = encoder_hidden
prev_beam = Beam()
prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden)
while True:
cur_beam = Beam()
# 2. 取出堆中的数据,进行forward_step操作,获得当前时间步的output, hidden
for _probability, _complete, _seq, _decoder_input, _decoder_hidden in prev_beam:
# 判断前一次的 _complete是否为True,如果是则不需要forward
# 有可能为True,但是概率并不是最大
if _complete == True:
cur_beam.add(_probability, _complete, _seq, _decoder_input, _decoder_hidden)
else:
# 需要进行forward操作
decoder_output_t, decoder_hidden = self.forward_step(_decoder_input, _decoder_hidden, encoder_outputs)
# 3. 从output中选择最大的beam width个输出,作为下一次的input
value, index = torch.topk(decoder_output_t, config.beam_width) # [batch_size, beam_width]
for m, n in zip(value[0], index[0]):
decoder_input = torch.LongTensor([[n]]).to(config.device)
seq = _seq + [n] # 更新句子序列
probability = _probability * m # 更新概率乘积
if n.item() == config.chatbot_ws_by_word_target.SOS:
complete = True
else:
complete = False
# 4. 把下个时间步需要的输入等数据保存在一个新的堆中
cur_beam.add(probability, complete, seq, decoder_input, decoder_hidden)
# 5. 获取新的堆中的优先级最高(概率最大)的数据,判断数据是否以EOS结尾或者是达到最大长度
# 若是则停止迭代
# 若不是则继续
best_prob, best_complete, best_seq, _, _ = max(cur_beam)
if best_complete == True or len(best_seq) - 1 == config.chatbot_target_max_seq_len + 1:
return self._perpar_seq(best_seq)
else:
prev_beam = cur_beam