ChatGLMv2 RoPE的代码实现

前言

参考链接:https://zhuanlan.zhihu.com/p/645263524
一开始对tensor的 reshape, 片选操作不熟, 还以为v2没有做rotate动作, 请教了之后才算搞懂了。

代码注释

@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    # Tag:q, k: [sq, b, np, hn], hn=128
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    # rot_dim =64
    rot_dim = rope_cache.shape[-2] * 2
    # Tag: x-->(sq, b, 64, 2)
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]

    # (sq,b, 32, 2)
    # Tag:总体上就是用的reshape+片选 实现rotate 交换动作;
    # 最后一维是2, 举例如下:[q1, q2],[q3, q4]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    # (sq,b, 1, 32, 2); 同样最后一维是2, 举例如下:[cos1, sin1],[cos2, sin2], ...
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    # rope_cache[..., 0]; shape-->(sq,b, 32)
    # xshaped[..., 0]-->[q1, q3,...]; 利用片选分离出单数, 双数的q
    x_out2 = torch.stack(
        [
            # [q1, q3, ] *[cos1, cos2, cos3] - [q2, q4, ] *[sin1, sin2, sin3]
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            # [q2, q4, ] *[cos1, cos2, cos3] - [q1, q3, ] *[sin1, sin2, sin3]
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)

最终实现下图公式的效果
ChatGLMv2 RoPE的代码实现_第1张图片

你可能感兴趣的:(深度学习,python,大模型,语言模型)