保留网络(RetNet)具有与相同大小的转换器相当的性能,可以并行训练,但支持递归模式,允许每个令牌的O(1)推理复杂性。
非官方但完整的实现可以在下面的我的回购中找到:
GitHub - syncdoth/RetNet:RetNet 的完整实现(Retentive Networks...
RetNet(保留网络,https://arxiv.org/pdf/2307.08621.pdf)的完整实现,包括并行...
github.com
对于序列模型,尤其是生成模型,我们有上述三个特点:快速推理、并行训练和强大的性能。(在我看来,还有一个维度:序列长度外推。RetNet 可能支持这一点,但没有明确的实验。
RNN 具有快速推理但训练缓慢,线性变压器的性能较弱,变压器每个令牌推理具有 O(n)。RetNet满足所有三个条件: 并行训练、O(1) 推理和节拍变压器。
有多种方法可以减轻生成变压器的昂贵推理。著名的作品包括Linear Transformers,Attention-Free Transformers(AFT;来自Apple)和RWKV(来自BlinkDL,基于AFT)。
这些值得单独发布,所以我不会详细介绍:但在我看来,它们在数学上都非常优雅,尤其是 RNN 如何并行化的推导。而我发现 RetNet 更有趣,因为它也有块表示和一些漂亮的技巧,如 xpos。
RetNet 是在同一 Transformer 架构中将“注意力”替换为“保留”的即插即用替代。
我将以自上而下的方式介绍它们。
每个 RetNet 块的公式。
在最高级别,RetNet 由几个相同的块堆栈组成,每个堆栈都包含 MultiScaleRetention (MSR) 和 FeedForwardNetwork (FFN)。它们还具有层规范和跳过连接,与变形金刚相同。FFN也几乎与变形金刚相同,后者是2层MLP,隐藏的暗光尺寸= 2倍嵌入尺寸,并具有gelu激活功能。
如果我们用MultiHeadAttention代替MSR,这只是Transformer。因此,所有差异都可以在MSR中找到。
多尺度类似于多头。在上面的等式中,γ是一些用于保留的超参数,这是为每个头部单独定义的。在群体规范之前,这是普通的多头关注,但保留。
门控MSR在输出端增加了组范数、旋门和输出投影,可视为辅助设计选择。(组规范允许缩放点积,但目前并不那么重要。 最重要的区别(保留模块)尚未到来。
最后,让我们看看什么是保留。保留有 3 种范式:并行、循环和块递归。让我们一一看一下。
并行保留
保留的并行表示
专注于最后一行。忽略 D,再次,这是没有 softmax 的点积关注。所以重要的细节又在D和Theta中。
请参阅 xpos 白皮书。我还发现这篇讲义有助于理解这一点。
如果绘制 D,则 D 如下所示:
gamma = 0.9
exponent = [[0, 0, 0, 0],
[1, 0, 0, 0],
[2, 1, 0, 0],
[3, 2, 1, 0]]
D = tril(gamma**exponent)
# [[1., 0., 0., 0.],
# [0.9000, 1., 0., 0.],
# [0.8100, 0.9000, 1., 0.],
# [0.7290, 0.8100, 0.9000, 1.]])
经常性保留
经常性保留
Sn类似于变压器中的KV缓存。RetNet 不是按顺序连接所有这些矩阵,而是将它们聚合成一个矩阵,循环在第一行。然后,此值乘以当前步骤的查询。
这与并行保留完全相同。
非正式证明草图:
设 S_0 = 0。 如果我们解决了S_n的复发,
回想一下平行表示中 D 的指数矩阵的最后一行,即 [3, 2, 1, 0]。请注意,n=4。当我们计算第 4 个代币与第 1 个代币的保留期时,我们将其衰减 3 倍,相当于上式中的 n — i = 3! 由于其余部分相同,因此并行表示和循环表示彼此相同。
分块保留
这看起来很复杂,但它实际上是每个块的并行计算 + 块的循环连接。 唯一重要的是应用的衰减次数。
实际上,论文对 Ri 的分块表示(上面的等式)是错误的!事实上,它应该是
其中 X 运算符是叉积,D_B 是 D 矩阵的最后一行。直观地说,这是从平行表示和循环表示的衰减乘法得出的。
就是这样!以上是两种表示的摘要图。
所以基本上,最重要的细节是它使用了一种叫做衰减的东西,并且应用正确的衰减次数允许并行化。但我们必须了解这种衰败背后的动机是什么。推导(在高级别)非常简单。
3.现在,我们将A矩阵对角化为以下内容。
4. 然后,可以将 Λ 符号吸收到其他可学习的参数中(Q_n = X * W_k,因此 Λ 可以吸收到 W_k!因此,我们只剩下中间部分。
中间部分正是我们之前观察到的γ(衰变)和θ。
直观地说,它们作为一种“封闭式位置编码”工作,它也具有递归形式,因此可以提前计算时间n的编码,从而实现并行化。
崔世贤
对于那些感兴趣的人,请看一下我对RetNet的实现:
GitHub - syncdoth/RetNet: Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.