RoPE为苏剑林大佬之作,最早应用于他自研的RoFormer (Rotary Transformer),属于相对位置编码。效果优于绝对位置编码和经典式相对位置编码。出自论文:《RoFormer: Enhanced Transformer with Rotary Position Embedding》
据我了解,最近发布的大语言模型:Meta的LLaMA、清华的ChatGLM都采用了RoPE。这也足以证明了RoPE的优势。
本文讲解下个人对RoPE原理的理解以及自己用torch复现了一下,更详细地请参阅苏神的原文(文末已附上链接)。
如对RoPE公式推导有任何疑问,可评论区或私信反馈,我将做出详细解答。
最原始的正余弦位置编码(即sinusoidal位置编码)是一种绝对位置编码,但从其原理中的正余弦的和差化积公式来看,引入的其实也是相对位置编码。
绝对位置编码的讲解可看我的博客:随记·手撕coding | absolute positional embedding
优势: 实现简单,可预先计算好,不用参与训练,速度快。
劣势: 没有外推性,即如果预训练最大长度为512的话,那么最多就只能处理长度为512的句子,再长就处理不了了。当然,也可以将超过512的位置向量随机初始化,然后继续微调。
⭐ 那rope是怎么在q,k中注入这种相对位置信息的呢?我看了苏神的推导。大概是这样的:先假设q,k是二维的情形,因为复数可用二维向量表示,所以借助复数域来求解。在推导的过程中,用的最多的一句话就是:“为简单起见,假设xxx” 这对推导十分关键。
首先,假设新的qk向量(即假设已注入绝对位置信息)的内积会引入相对位置信息。并在最后假设合理的初始化条件:
不是一般性,考虑其q,k向量为二维的情形,借助复数域推导出为q,k向量编码绝对位置信息的函数 f 。
别看公式多,理解起来并不难。下面我细说一下其中几个关键的推导步骤:
上面我们设了q,k的绝对位置编码函数为:
然后又求出了:
而:
那带入(4)式就可以得出q,k的绝对位置编码函数了(下面以q为例,k同理)
为避免这个正交矩阵过于稀疏,浪费算力,代码实现时都是依据下面公式来计算RoPE:
注:苏神在θ的选择上沿用了tansformer的θi = 10000-2i/d 。因为苏神实验发现,在RoPE中采用这个θ也可以带来一定的远程衰减性(意思就是token之间的依赖关系会随着距离的变远而衰减,这也符合我们的直观理解)。当然别的θ也可,只要满足远程衰减。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# %%
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
# (max_len, 1)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
# (output_dim//2)
ids = torch.arange(0, output_dim // 2, dtype=torch.float) # 即公式里的i, i的范围是 [0,d/2]
theta = torch.pow(10000, -2 * ids / output_dim)
# (max_len, output_dim//2)
embeddings = position * theta # 即公式里的:pos / (10000^(2i/d))
# (max_len, output_dim//2, 2)
embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
# (bs, head, max_len, output_dim//2, 2)
embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape)))) # 在bs维度重复,其他维度都是1不重复
# (bs, head, max_len, output_dim)
# reshape后就是:偶数sin, 奇数cos了
embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim))
embeddings = embeddings.to(device)
return embeddings
# %%
def RoPE(q, k):
# q,k: (bs, head, max_len, output_dim)
batch_size = q.shape[0]
nums_head = q.shape[1]
max_len = q.shape[2]
output_dim = q.shape[-1]
# (bs, head, max_len, output_dim)
pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device)
# cos_pos,sin_pos: (bs, head, max_len, output_dim)
# 看rope公式可知,相邻cos,sin之间是相同的,所以复制一遍。如(1,2,3)变成(1,1,2,2,3,3)
cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 将奇数列信息抽取出来也就是cos 拿出来并复制
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 将偶数列信息抽取出来也就是sin 拿出来并复制
# q,k: (bs, head, max_len, output_dim)
q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
q2 = q2.reshape(q.shape) # reshape后就是正负交替了
# 更新qw, *对应位置相乘
q = q * cos_pos + q2 * sin_pos
k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1)
k2 = k2.reshape(k.shape)
# 更新kw, *对应位置相乘
k = k * cos_pos + k2 * sin_pos
return q, k
# %%
def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
# q.shape: (bs, head, seq_len, dk)
# k.shape: (bs, head, seq_len, dk)
# v.shape: (bs, head, seq_len, dk)
if use_RoPE:
q, k = RoPE(q, k)
d_k = k.size()[-1]
att_logits = torch.matmul(q, k.transpose(-2, -1)) # (bs, head, seq_len, seq_len)
att_logits /= math.sqrt(d_k)
if mask is not None:
att_scores = att_logits.masked_fill(mask == 0, -1e-9) # mask掉为0的部分,设为负无穷大
att_scores = F.softmax(att_logits, dim=-1) # (bs, head, seq_len, seq_len)
if dropout is not None:
att_scores = dropout(att_scores)
# (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
return torch.matmul(att_scores, v), att_scores
if __name__ == '__main__':
# (bs, head, seq_len, dk)
q = torch.randn((8, 12, 10, 32))
k = torch.randn((8, 12, 10, 32))
v = torch.randn((8, 12, 10, 32))
res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True)
# (bs, head, seq_len, dk), (bs, head, seq_len, seq_len)
print(res.shape, att_scores.shape)