【技术记录1】:Prefixtunning在BART中是如何实现的?

浅浅的记录一次在实验中遇到的坑,这里主要记录一下prefixtunning在BART中是如何实现的。

先看看其他人是如何实现的,我目前在网上找到的代码包括openprompt版本[1]的代码和prefixtunning作者自己上传到github上的代码[2]。

1、https://github.com/thunlp/OpenPrompt

2、https://github.com/XiangLi1999/PrefixTuning

这两个代码都存在一定的问题。作者这个源代码写的太乱了,加了很多if,我实在是看不清楚,而且作者也是修改了huggingface的源代码,但具体改了哪里也没说明(也可能我没注意到?)。openprompt这个代码就简洁明了很多了,但这个代码是by-case的,也就是没有实现针对BART的prefixtunning,实现的是针对T5和GPT2的prefixtunnig。

但是观察两个代码的共同点,这两个代码都用到了transformers库里的一个特定的参数:

past_key_values

我们首先看一下官方文档里对这个参数的解释,

【技术记录1】:Prefixtunning在BART中是如何实现的?_第1张图片

 简单翻译一下,也就是这个参数是用来通过添加自定义的key和value值加速解码过程的。实话实说,这个解释看完后我一头雾水,这跟prefixtunning有什么关系呢。prefixtunning的思想是加入一组可以微调的prompts,然后在训练的时候冻结预训练语言模型的参数只训练这组参数。

后来看到记录一次对past_key_values用法的理解 - 知乎这篇博客后,就恍然大明白了。重点在pre-computed这句话。prefixtunning原论文对他们加入的可训练参数是怎么具体实现的写的特别隐晦:

【技术记录1】:Prefixtunning在BART中是如何实现的?_第2张图片

原论文用Acitivation一笔带过了,实际上就是通过自己生成一组key和value然后传入了模型中,具体实现如下:

class PrefixEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(seq_len, dim_ebd)
        self.trans = torch.nn.Sequential(
            torch.nn.Linear(dim_ebd, number),
            torch.nn.Tanh(),
            torch.nn.Linear(number, num_layer * n * dim_ebd)
        ).to(device)
    def forward(self, prefix):
        prefix_tokens = self.embedding(prefix)
        past_key_values = self.trans(prefix_tokens)
        return past_key_values

这里seq_len就是prefix的长度,dim_ebd是模型的维度(bart-base是768)。number可以自己定义,越大参数量就越大(相对应的训练时间就会越长,但这跟提示学习的思想就相违背了),n取决于具体哪个模型,对于bart来说应该是4。

实际上到这步应该就结束了,但是问题就是past_key_values原本不是用来做prefixtunning的,是用来加速decoding的,这带来了第一个大坑:也就是模型训练的时候用不上这个参数(因为是teacher-forcing),但是在模型做自回归生成的时候tansformers自己会调用这个参数。你如果用model.generate()直接尝试解码是会报错的。这里有两种解决方法:

1、自己写decoding的代码,反正的decoding是一个很模块化的过程,就可以调用各种包。

2、参考openprompt的方法,对transformers源代码进行修修补补。但是我觉得这种做法挺难的。

另一个大坑是BART独有的,我们具体看一下past_key_value在bart里是怎么传递的。

首先从BartForConditionalGeneartion传到BartModel:(self.model)——

——>然后从BartModel传到BartDecoder:(self.decoder)——

——>从BartDecoder传到BartDecoderLayer:(decoder_layer)——>

然后就是关键了,这里开始对past_key_values里的值进行操作了:

首先模型会把past_key_values(一个长度为4的tuple)对半分开,分别给self attention和cross attention使用

 这两个会分别送入两个BartAttetion模块中,

【技术记录1】:Prefixtunning在BART中是如何实现的?_第3张图片

 这一部分非常关键!其实看到这里也就明白了,由于在cross attention中,key_states和value_states的值是通过你传入的参数直接赋值的,而不是像self attention一样是拼接的,这会导致模型学习的时候学习不到编码器的表示,换言之,本来应该传递给解码器的编码器表示被忽略掉了。具体在val的时候,就会发现模型对任何不同的输入输出都是一样的。

解决办法也比较简单,直接把第一个if删掉就好了,然后传past_key_value的时候不要传后两个tuple了。


2022-7-23 更新

清华大学的opendelta项目上也实现了BART做prefixtunning的代码,google一下就能看到了。

你可能感兴趣的:(自然语言处理,python)