RWKV:一种鱼和熊掌兼得的线性transformer模型

众所周知,现在transformer及其变种是NLP和CV领域已经杀疯了。但其中最核心的self-attention机制因为其O(N2)的时间复杂度(二次依赖问题)被诟病。

在不改变transformer block这个整体架构的前提下,现在学术界解决二次依赖问题的主要是两个思路。一种是实现self-attention的线性化。这方面的工作是很多的,比如Performer[5]、Reformer[6]、Linformer[7]、Nyströmformer[9]、AdaMRA[10]等。关于这部分工作更多的内容大家可以在苏剑林的博客中了解到[8].虽然关于线性attention的工作很多,但参考AdaMRA[10]论文的图。只有Nyströmformer[9]和AdaMRA[10]相较于Transformer能获得速度和效果的双重提升,其他的大多需要付出效果的代价才能获取一定的速度提升。但就是这哥俩由于用了平均池化作为特征聚类,因此无法mask未来信息从而丧失了自回归的能力。因此通过替换线性attention从而提升transformer速度这一思路是必须付出代价的。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第1张图片

另一种思路将self-attention换成其他线性复杂度的部件。比如前段时间谷歌发现用膨胀卷积取代self-attention也能取到不错的效果[1]。而在CV领域杀疯的MLP-Mixer[2],兼具CV和NLP能力的gMLP、aMLP,[3]MLP-Mixer的NLP版本Synthesizer[4]。但都有或多或少的缺点,就比如Synthesizer和gMLP在NLP领域相较于self-attention还是差了点的。而aMLP虽然效果好了吧,但其实还是要用到self-attention,提速的目的还是没达到。不过今年暑假那会,苹果提出的AFT模型[11]号称自己是最快的transformer模型。

上述是标准AFT的公式,其中σ是sigmoid函数,QKV就是sefl-attention的那一套,w是一个训练出来的参数矩阵。不难看出AFT是通过点乘的方式实现的注意力,在做自回归时只需要对W矩阵进行mask即可。并且W矩阵是自带位置信息的,不仅解决了部分线性attention不能做自回归的问题,还顺便把transformer里位置编码的问题给解决了。可以说AFT实现了一举三得。但成也萧何败也萧何,W矩阵是AFT成功的核心也是AFT的最大缺点。一般来说W应该是一个[max_len,max_len]大小的方阵。换而言之AFT所能处理的文本长度受限于W矩阵的大小,如果想要处理一万字的长文本,W矩阵的参数量就快赶上Bert了。为了解决这个问题,下面该本文的主角RWKV出场了。RWKV的原文在RWKV is all you need?一种新语言模型,改进 Transformer - 知乎,不过原文实在过于简短了不便阅读和理解。因此笔者写了此文介绍一下RWKV是怎么实现鱼和熊掌兼得的。

RWKV

整体结构 RWKV的整体结构依然采用的是transformer block的思路,其整体结构如图所示。相较于原始transformer block的结构,RWKV将self-attention替换为Position Encoding和TimeMix,将FFN替换为ChannelMix。其余部分与transfomer一致的。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第2张图片

Position Matrix RWKV采用的位置编码类似于AliBi编码[12]的形式。原文作者并没有给他的位置编码命名,为了便于介绍参考该位置编码主要考虑距离衰减的特性,本文将其命名为distance编码。对于第i个head的第j个token而言,其位置编码如下述公式所示。其中nhead表示头的数量,max_len表示为所允许的最大长度。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第3张图片

目前学术界的主流观点是RNN结构是天然的时序结构,不需要transformer模型必须的位置编码。而如果我们查看RNN的计算流程,可以发现RNN只考虑到当前token及之前的信息,而随着距离的延长前面的信息会逐渐减少。而distance位置编码便是参考RNN时序特点所设计的。

不过RWKV模型中,不会直接对输入的X进行上述计算。而是得到类似AFT中的W矩阵参与后续Time-Mix计算。其中W矩阵的形状为[n_head,seq_len,seq_len]。因此对于W矩阵中的而言,其数值如下述公式所示。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第4张图片

从这里不难看出,AFT中的W矩阵在RWKV中是通过公式得到而不是训练得到的,因此解决了AFT中无法解决长文本,或解决长文本时参数爆炸的问题。

当然,在处理的任务文本长度有限的情况下。比如机器翻译,或者是RWKV目前应用的ai写小说这类应用场景。在这类应用场景中,由于不会面临长文本的情况,因此可以为W矩阵添加更多的位置信息。参考公式如下

其中和分别为形状[n_head,seq_len,1]和[n_head,1,seq_len]的向量,在初始化时为全1矩阵。即将作为W矩阵的初始化。结合该步后,在形式上W矩阵融合了distance编码中的距离信息与相对信息。

值得注意的是,原作者是设计distance编码时专门设计了一个不考虑位置信息衰减的头。即该头的W矩阵是一个全一的下三角矩阵。

Time-shit 在介绍TimeMix之前,要先介绍一下RWKV所使用的Time-shit技巧。

原文:Time-shift: 一行代码,免费提高 Transformer 性能(无参数,无耗时) - 知乎

Time-shiit是原作者提出的一种几乎零成本提升模型效果的trick,实现代码如下所示。

Torch实现
C=x.shape[-1]
self.time_shift = nn.ZeroPad2d((0,0,1,0))
x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim = -1)
Keras实现
d=K.shape(x)[-1]
x=K.concatenate(K.temporal_padding(x,(1,0))[:,:-1,:d//2],x[:,:,d//2:])

可以看出不论哪个框架也就两行就能实现了,为了便于读者理解。假设存在一个3x4的矩阵。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第5张图片

在经过time-shift后变为

RWKV:一种鱼和熊掌兼得的线性transformer模型_第6张图片

其实就相当于插入一个小的RNN,实验表明简单的trick能让模型的更快更好地收敛。

TimeMix TimeMix是RWKV中用于代替self-attention的部分,基于AFT的基础上做出改进兼具了线性的速度和较好的性能。在进行该步前,需要对输入的x进行time-shift。

同self-attention中的QKV矩阵一样,RWKV中也有对应的RKV矩阵。对与输出矩阵中第i个头的第j个token而言计算步骤如下所示。

RWKV:一种鱼和熊掌兼得的线性transformer模型_第7张图片

这其中是一个[hiden_size,hiden_size]大小的方阵,与常规attention一样用于最后的输出。而是一个[seq_len,hiden_size]大小的矩阵,其作用笔者猜测应该是类似于bias的作用。

ChannelMix ChannelMix 是RWKV中用于替代FFN的部分。类似于tiny attention之于attention。ChannelMix本质上来说是一个tiny TimeMix。

在进行该步计算前,和TimeMix一样要先进行一次time-shift。随后依然要计算出RKV矩阵和W权重。不过有所不同的是在这一步中假设输入x的维度是embed_size,则R的维度应和X相同。KV的维度是用户所自定义的hidden_size,W的形状为[hidden_size,embed_size].

通过设置较小的hidden_size可以实现一个tiny版TimeMix,能在对性能影响较小的情况下实现提速。当hidden_size==embed_size时,可以看作一个不考虑位置信息和归一化的TimeMix或者看作点乘式的FFN。

具体计算公式如下所示

总结 本文介绍了一种鱼和熊掌兼得的模型。既能和AFT一样兼具通用性和高效,distance位置编码的设计使得模型也具备面对超长文本的能力。

实际实验效果可以去看原文的内容,本文只对其结构进行介绍。但总体而言,笔者测试过基于GPT的ai写小说和基于RWKV的ai写小说。相比较而言,RWKV的写出来的文章会更流畅,并且在训练时收敛速度页更快。

参考文献

[1] Are Pre-trained Convolutions Better than Pre-trained Transformers https://arxiv.org/pdf/2105.03322.pdf

[2] MLP-Mixer: An all-MLP Architecture for Vision https://arxiv.org/pdf/2105.01601.pdf

[3] Pay Attention to MLPs https://arxiv.org/pdf/2105.08050.pdf

[4] Synthesizer: Rethinking Self-Attention in Transformer Models https://arxiv.org/abs/2005.00743

[5] Rethinking Attention with Performers https://arxiv.org/abs/2009.14794

[6] Reformer: The Efficient Transformer https://arxiv.org/abs/2001.04451

[7] Linformer: Self-Attention with Linear Complexity https://arxiv.org/abs/2006.04768

[8] 线性Attention的探索:Attention必须有个Softmax吗? https://spaces.ac.cn/archives/7546

[9] Nyströmformer: A Nyström-Based Algorithm for Approximating Self-Attention https://arxiv.org/abs/2102.03902

[10] Adaptive Multi-Resolution Attention with Linear Complexity https://arxiv.org/abs/2108.04962

[11] An Attention Free Transformer https://arxiv.org/abs/2105.14103

[12] Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

https://arxiv.org/abs/2108.12409




 RWKV:一种鱼和熊掌兼得的线性transformer模型 - 知乎

https://www.youtube.com/watch?v=oaP8_fUFVWw 

你可能感兴趣的:(#,Transformer,transformer,深度学习,人工智能)