当我们训练完成一个自然语言生成模型后,需要使用这个模型生成新的语言(句子),如何生成这些句子,使用如下的方法:贪心搜索,集束搜索,随机搜索。
贪心搜索最为简单,直接选择每个输出的最大概率,直到出现终结符或最大句子长度。
在每个阶段都选择分值最高的项。此方法经常奏效,但显然不是最优的。
集束搜索是一种启发式图搜索算法,在图的解空间比较大的情况下,为了减少搜索所占用的空间和时间,在每一步深度扩展的时候,剪掉一些质量比较差的结点,保留下一些质量较高的结点。
在sequence2sequence模型[中,beam search的方法只用在测试的情况(decoder解码的时候),因为在训练过程中,每一个decoder的输出是有正确答案的,也就不需要beam search去加大输出的准确率。
具体过程为:使用广度优先策略在树的每一层建立搜索树,按照启发代价对节点进行排序,然后仅留下预先确定的个数(Beam Width-集束宽度)的节点,仅这些节点在下一层次继续扩展,其他节点就被剪掉了。(注意:如果集束宽度无穷大,那该搜索就是宽度优先搜索)
好处:减少了空间消耗,并提高了时间效率。
假设字典为[a,b,c],beam size选择2,则如下图有:
一个正确且高效的算法需要处理的问题大概有两个:
充分利用硬件,可以处理批量数据,且尽量使用并行计算少用循环
处理好长短不同的生成结果
序列扩展的每一步迭代过程如图所示, 使用束搜索找出下一步分数最大的beam个
基础代码实现如下:
十分钟读懂Beam Search(1/2)
论文: The Curious Case of Neural Text Degeneration
随机采样是一种对Beam search进行改进的尝试。
解码过程用随机采样(sampling)代替取概率最大的词。采样的依据就是解码器输出的词典中每个词的概率分布。相比于按概率“掐尖”,这样会增大所选词的范围,引入更多的随机性。这个方法是谷歌开放式聊天机器人Meena[DialoGPT、Meena]采用的方式。当时那篇论文的结论就是这种随机采样的方法远好于Beam Search。但这其实也是有条件的,随机采样容易产生前后不一致的问题。而在开放闲聊领域,生成文本的长度都比较短,这种问题就被自然的淡化了。
采样的时候有一个可以控制的超参数,称为温度(temperature, T)。解码器的输出层后面通常会跟一个softmax函数来将输出概率归一化,通过改变T可以控制概率的形貌。softmax的公式如下,当T大的时候,概率分布趋向平均,随机性增大;当T小的时候,概率密度趋向于集中,即强者愈强,随机性降低,会更多地采样出“放之四海而皆准”的词汇。
这个方法就是在采样前将输出的概率分布截断,取出概率最大的k个词构成一个集合,然后将这个子集词的概率再归一化,最后从新的概率分布中采样词汇。这个办法据说可以获得比Beam Search好很多的效果,但也有一个问题,就是这个k不太好选。因为这个概率分布变化比较大,有时候可能很均匀(flat),有的时候比较集中(peaked)。对于集中的情况还好说,当分布均匀时,一个较小的k容易丢掉很多优质候选词。但如果k定的太大,这个方法又会退化回普通采样。
top-p相比前面那些都更好的采样方式,他不再取一个固定的k,而是固定候选集合的概率密度和在整个概率分布中的比例。也就是构造一个最小候选集,使得
选出来这个集合之后也和top-k采样一样,重新归一化集合内词的概率,并把集合外词的概率设为0。
为了解决重复问题,还可以通过惩罚因子将出现过词的概率变小或者强制不使用重复词来解决。惩罚因子来自于同样广为流传的《CTRL: A Conditional Transformer Language Model for Controllable Generation》。
Top_k, Top_p代码实现
import torch
def top_k_top_p_filtering(logits: torch.FloatTensor, top_k: int = 0, top_p: float = 1.0,
filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1) -> torch.FloatTensor:
'''
logits: logit分布的shape (batch_size, vocabulary size)
'''
scores = logits
if top_k > 0:
top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
indices_to_remove = scores < torch.topk(scores, top_k)[0][:, -1, None] #None方便扩展维度而不该改变数据排列顺序
#结果为Bool类型表示是否大于第k个值
scores = scores.masked_fill(indices_to_remove, filter_value)
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(scores, descending = True) #降序排列
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) #softmax操作后,累计计算概率分布
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1:
#保持至少有min_tokens_to_keep个单词可选
sorted_indices_to_remove[:, :min_tokens_to_keep - 1] = 0
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
x = torch.randn(4, 8)
print(x)
print(top_k_top_p_filtering(x, top_p = 0.6, min_tokens_to_keep = 2))
Module | Complexity | # Parameters |
Self-Attention | ||
FFN |
参考:
集束搜索(beam search)和贪心搜索(greedy search)
自然语言生成-集束搜索beam search和随机搜索random search
十分钟读懂Beam Search(1/2)
十分钟读懂Beam Search 2