Contrastive Search Decoding——一种对比搜索解码文本生成算法

目录

一、contrastive search decoding

二、代码实现理解和实验

1、代码走读

2、生成效果展示

3、方案的缺陷


最近在做文本生成相关的任务,调研的时候刷到一篇文本生成的论文:

《A Contrastive Framework for Neural Text Generation》

它认为GPT2生成模型再生成的token具有各异向性,使得token之间的相似性非常接近没有很好的区分度,最后解码的时候造成了文本重复——text degeneration;因此论文提出了一种新的训练策略(SimCTG)+解码算法(contrastive search),在多语言任务和实际的工业场景中进行人工评测,很显著的提升了文本生成的质量。关于该论文提出的text degeneration的原因知乎上有很多大佬和论文作者进行讨论和剖析,最后得出的结论是text degeneration的原因并不是SIMCTG提出的Contrastive Training,它并不能保证表征各向同质性,之所以在文本生成的质量上(少无意义的重复)有实实在在的提升,完全来自于新提出的解码策略——contrastive search decoding。既然这么有效的解码策略,是应该好好学习一下。

一、contrastive search decoding

这是一种非topK、topP以及BeamSearch的解码策略,感觉非常有意思。其核心思想就是对比——把当前要生成的token和已经生成的所有token做相似度计算,得到最大的相似度值;然后使得该token的概率与最大的相似度值的差值最大化的那个token就是我们要生成的token;具体的公式如下:

 V(k)是指token在模型输出的分布中top_k个最可能的结果,论文中提出K值通常设置3~10。看完公式觉得思想很简单,一下子就理解了公式要表达的思想,但是这里还是有几个值得注意的地方:

1、如何高效的得到当前token的embedding,也就是hv;以及如何得到h1,.....ht-1(已经生成的token的embedding)

2、如何高效的计算当前token的embedding和之前所有文本的embedding的相似度的最大值

3、如何计算整体上的最大值得到V(k)最佳的v

在问题1已经解决的情况下,2和3问题比较好解决,直接采用矩阵计算使用GPU并行计算,就可以很好的解决计算的效率问题;第一个问题理解起来有点点难,对于不太熟悉GPT2模型的人来说,确实不太好理解。本人再阅读起实现源码后,和作者沟通后,再加上对GPT2生成流程的理解后,才完全理解到底应该怎么求hv的。

Contrastive Search Decoding——一种对比搜索解码文本生成算法_第1张图片

 contrastive search decoding大体上的解码流程如上图所示,当前轮次文本输入gpt2模型,使用hm得到新的k个候选生成tokens;然后把这些tokens和之前的文本拼接起来输入到下一轮模型,得到hm+1。这里的hm+1就是前面说的上一轮应该生成的token的embedding,通过解码公式的计算,选出最佳的hm+1也就得到了tm+1——当前轮最佳的那个token。按照上述流程循坏下去就可以得到生成一个句子了。

二、代码实现理解和实验

1、代码走读

上面的核心思想简单的分析了,下面看看如何具体的使用代码实现,先上整体的实现代码,然后再慢慢解析:

def contrastive_search_decode(curr_input_tensor,attention_mask,tokenizer):
    """
    对比搜索解码策略
    """
    alpha = 0.5
    beam_width = 5
    generated = [item for item in curr_input_tensor.tolist()]
    past_key_values = None

    max_length = 64 + curr_input_tensor.shape[1]
    stop = False

    with torch.no_grad():
        for index in range(max_length):
            if index == 0:
                inputs = prepare_inputs_for_generation(curr_input_tensor, attention_mask, past=past_key_values)
                output = model(**inputs,return_dict = True,use_cache=True,output_hidden_states=True)
                past_key_values = output.past_key_values
                last_hidden_states = output.hidden_states[-1]  # [B, S, E]
                logit_for_next_step = output.logits[:, -1, :]  # [B, V]

            bsz, seqlen, embed_dim = last_hidden_states.size()

            next_probs = F.softmax(logit_for_next_step, dim=-1)
            _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]
            top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # [B, K]

            # compute new hidden
            past_key_values = enlarge_past_key_values(past_key_values, beam_width)
            output = model(
                input_ids=top_k_ids.view(-1, 1),
                attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
                past_key_values=past_key_values,
                output_hidden_states=True,
                use_cache=True,
            )
            # past_key_values是一个二维list;里层list元素是tensor
            past_key_values = output.past_key_values
            logits = output.logits[:, -1, :]  # [B*K, V]
            next_hidden = output.hidden_states[-1]  # [B*K, 1, E]
            context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz * beam_width,seqlen,embed_dim)  # [B*K, S, E]

            selected_idx = ranking_fast(
                context_hidden,
                next_hidden,
                top_k_probs,  # [B, K]
                alpha,
                beam_width,
            )  # [B]

            # prepare for the next step
            next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)  # [B, 1]
            temp = torch.split(next_hidden.squeeze(dim=1), beam_width)
            next_hidden = torch.stack(temp)  # [B, K, E]
            next_hidden = next_hidden[range(bsz), selected_idx, :]  # [B, E]
            last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)  # [B, S+1, E]
            past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
            temp = torch.split(logits, beam_width)
            logit_for_next_step = torch.stack(temp)[range(bsz), selected_idx, :]  # [B, V]

            tokens = next_id.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)

            for token in tokens:
                if token == 102:
                    stop = True
                    break
            if stop:
                break

    res = tokenizer.batch_decode(generated, skip_special_tokens=True)

说说几个细节

a、past_key_values扩充和压缩

由于每次需要传入past_key_values加快模型的推理速度,并且要在top_k中得到最佳的那个token,因此需要把K个token都要纳入计算中,为了能够矩阵计算需要把每次输入都扩充K倍:

past_key_values扩充

def enlarge_past_key_values(past_key_values, beam_width):
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            # item is the key and value matrix
            bsz, num_head, seq_len, esz = item.size()
            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度变化[B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]

past_key_values压缩

def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''select_idx: [B]'''
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam//beam_width)
            temp = torch.split(item, beam_width, dim=0)
            item = torch.stack(temp)    # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]   # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

past_key_values中每个tensor的维度从[B*K, num_head, seq_len, esz]变回到[B, num_head, seq_len, esz]

b、当前token和之前所有token的相似度并行计算

def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    temp = torch.split(scores, beam_width)
    scores = torch.stack(temp)    # [B, K]
    selected_idx = scores.max(dim=-1)[1]    # [B]
    return selected_idx

需要注意到这里的torch.matmul()的计算

context_hidden:[B*K,S,D]

next_hidden:[B*K,1,D]

需要计算batch中每一条数据(每个token的embedding)和之前所有token的embedding的cos相似度

torch.matmul([B*K,S,D],B*K,1,D].T(2,1))=torch.matmul([B*K,S,D],B*K,D,1])=[B*K,S,1]

然后再求最大的那个score的index即可

2、生成效果展示

Contrastive Search Decoding——一种对比搜索解码文本生成算法_第2张图片

 Contrastive Search Decoding——一种对比搜索解码文本生成算法_第3张图片

 生成的语句还是比较流畅的,重复性得到改善,逻辑性这个是模型本身的问题;但是具体比之前采用beamsearch + sample效果具体能好多少,这边我没有做太多的验证,需要上线使用机器人聊一段时间才知道,不过beamsearch + sample在实际使用的时候就算加上了重复惩罚系数,生成的时候也会有部分重复的,生成例子:

现在财务下班了,财务下班了,明天下午到账

不是,我们不是一个公司的,不是一个公司的

好的,那我给您改一下。那我这边给您改一下

[让我看看][让我看看][让我看看][让我看看]

代理点:506经办200019经办200019经办200019经办

2000块钱,2000块钱,2000块,2000块钱,20002000块钱,2000200020

真实的contrastive search decoding效果,还有待观察,不过目前简单的测试几条来看生成还可以。

3、方案的缺陷

一般而言,我们都要求生成的句子具有多样性——有不同的生成,contrastive search decoding是一个确定性方案,每次只能生成固定的结果。这里作者有提出一个比较合适的方法:

就是先使用beamsearch + sample等方法生成部分句子,然后再使用contrastive search decoding对生成的句子进行补齐。

具体的实现不是特别困难,这里就不实现了。

还有一种方法,实现上比较麻烦,我也提一下思想:就是那个公式中选择v的时候,不选最大的那一个,多选择几个,但是要小于K值。

公式中的argmax 换成 top_n,n取2、3、4这种比K/2小的值感觉比较合适。

参考文章:

如何评价剑桥,腾讯, DeepMind以及港大团队新作 SimCTG ? - 王琰的回答 - 知乎

2022 - A Contrastive Framework for Neural Text Generation

你可能感兴趣的:(#,文本生成和翻译,人工智能,深度学习,机器学习,pytorch)