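参考链接：https://zhuanlan.zhihu.com/p/645263524
一开始对 tensor 的 reshape 和切片操作不熟悉，还以为 v2 没有做 rotate（旋转）操作；请教别人之后才弄明白：它是先用 reshape 把最后一维两两分组成 (..., rot_dim // 2, 2)，再用 [..., 0] / [..., 1] 切片取出相邻的偶数位、奇数位分量，从而完成旋转交换。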
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    """Apply a rotary position embedding to the leading features of ``x``.

    Only the first ``rot_dim = 2 * rope_cache.shape[-2]`` features of the last
    dimension are rotated; the remaining features are passed through
    unchanged and re-appended after the rotated ones.

    Features are rotated in *adjacent* pairs ``(2i, 2i + 1)``. The reshape to
    ``[..., rot_dim // 2, 2]`` groups each pair, and the slices ``[..., 0]`` /
    ``[..., 1]`` pick out the even / odd member. The rotation is therefore done
    with reshape + slicing alone, with no explicit "rotate half" swap.

    Args:
        x: 4-D tensor laid out as ``[sq, b, np, hn]``. These are presumably
            sequence, batch, heads and head size; confirm against callers.
        rope_cache: Rotation table. Its second-to-last dim is the number of
            rotated feature pairs, and its last dim holds two channels:
            channel 0 is used as cos(theta) and channel 1 as sin(theta). The
            formula below is a 2-D rotation only under that interpretation.
            Dim 0 must cover at least ``sq`` positions; it is truncated to
            ``sq``. After truncation it must ``view`` as
            ``[sq, B, 1, rot_dim // 2, 2]`` with ``B`` broadcastable against
            ``x``'s batch dim (``B == b`` or ``B == 1``).

    Returns:
        Tensor of the same shape as ``x``: the rotated ``rot_dim`` features
        followed by the untouched pass-through features.
    """
    seq_len = x.size(0)
    num_heads = x.size(2)
    # The cache stores one (cos, sin) pair per rotated feature pair, so the
    # number of rotated features is twice its second-to-last dim.
    rot_dim = rope_cache.shape[-2] * 2
    # Only the first rot_dim features are rotated; the tail is passed through.
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # Truncate the cache to the positions actually present in x.
    rope_cache = rope_cache[:seq_len]
    # [sq, b, np, rot_dim] -> [sq, b, np, rot_dim // 2, 2]: groups adjacent
    # features (2i, 2i + 1) into one pair along the new last dim.
    x_pairs = x_rot.reshape(seq_len, -1, num_heads, rot_dim // 2, 2)
    # -> [sq, B, 1, rot_dim // 2, 2]; the singleton dim broadcasts over heads.
    rope_cache = rope_cache.view(seq_len, -1, 1, x_pairs.size(3), 2)
    x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
    cos, sin = rope_cache[..., 0], rope_cache[..., 1]
    # 2-D rotation of each (even, odd) pair by its angle theta:
    #   even' = even * cos - odd * sin
    #   odd'  = odd * cos + even * sin
    x_out = torch.stack(
        [
            x_even * cos - x_odd * sin,
            x_odd * cos + x_even * sin,
        ],
        -1,
    )
    # Re-interleave the pairs: [sq, b, np, rot_dim // 2, 2] -> [sq, b, np, rot_dim].
    x_out = x_out.flatten(3)
    return torch.cat((x_out, x_pass), dim=-1)