Implementing torch.nn.MultiheadAttention() in PyTorch (the 1-D case)
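The snippet below feeds 1×1 Q/K/V tensors through nn.MultiheadAttention, then reproduces the same output by hand from the layer's own projection weights. The operation being reproduced is standard scaled dot-product attention,

Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V

followed by a learned output projection.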

import torch
import torch.nn as nn
import numpy as np


# Make the run reproducible by fixing every random source
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True


# Set the random seed
setup_seed(20)

Q = torch.tensor([[1]], dtype=torch.float32)  # shape (1, 1): (seq_len, embed_dim), unbatched
K = torch.tensor([[3]], dtype=torch.float32)  # shape (1, 1)
V = torch.tensor([[5]], dtype=torch.float32)  # shape (1, 1)

multiHead = nn.MultiheadAttention(1, 1)  # embed_dim=1, num_heads=1
att_o, att_o_w = multiHead(Q, K, V)  # attention output and (head-averaged) attention weights
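# Note: passing 2-D (seq_len, embed_dim) tensors relies on unbatched-input
# support, added in PyTorch 1.11; older versions expect (seq_len, batch, embed_dim).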

################################

# Reproduce the multi-head attention computation by hand
w = multiHead.in_proj_weight  # stacked [W_q; W_k; W_v], shape (3*embed_dim, embed_dim)
b = multiHead.in_proj_bias
w_o = multiHead.out_proj.weight
b_o = multiHead.out_proj.bias

# Split the stacked in-projection into its Q, K, V parts
w_q, w_k, w_v = w.chunk(3)
b_q, b_k, b_v = b.chunk(3)

# Input projections for Q, K, V. nn.Linear computes x @ W.T + b; the
# transpose is a no-op for a 1x1 weight, but it matters in general.
q = Q @ w_q.T + b_q
k = K @ w_k.T + b_k
v = V @ w_v.T + b_v
dk = q.shape[-1]  # per-head dimension (embed_dim // num_heads; here 1)
# Attention weights: softmax over the scaled dot products
softmax_2 = torch.nn.Softmax(dim=-1)
att_o_w2 = softmax_2(q @ k.transpose(-2, -1) / np.sqrt(dk))
# Weighted sum of the values. Use the recomputed weights att_o_w2 and a
# matmul (not an element-wise product) so the formula generalizes beyond 1x1.
out = att_o_w2 @ v
# Output projection (again x @ W.T + b)
att_o2 = out @ w_o.T + b_o
print(att_o, att_o_w)
print(att_o2, att_o_w2)

Output (both prints match; with a single key, the softmax weight is trivially 1):

tensor([[-0.4038]], grad_fn=<...>) tensor([[1.]], grad_fn=<...>)
tensor([[-0.4038]], grad_fn=<...>) tensor([[1.]], grad_fn=<...>)
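To confirm the recipe is not an artifact of the 1×1 case, the same reproduction can be checked in a batched, multi-head setting. Below is a minimal sketch, assuming PyTorch's defaults (batch_first=False, average_attn_weights=True); the dimension names E, H, L, S, N and the split_heads helper are my own:

import math
import torch
import torch.nn as nn

torch.manual_seed(0)

E, H, L, S, N = 8, 2, 3, 5, 2       # embed dim, heads, target len, source len, batch
mha = nn.MultiheadAttention(E, H)   # batch_first=False: inputs are (seq, batch, embed)

Q = torch.randn(L, N, E)
K = torch.randn(S, N, E)
V = torch.randn(S, N, E)
ref_out, ref_w = mha(Q, K, V)       # ref_w is averaged over heads by default

w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
d = E // H                          # per-head dimension


def split_heads(x, w, b):           # (T, N, E) -> (N*H, T, d)
    x = x @ w.T + b
    return x.reshape(x.shape[0], N * H, d).transpose(0, 1)


q, k, v = split_heads(Q, w_q, b_q), split_heads(K, w_k, b_k), split_heads(V, w_v, b_v)
att_w = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)  # (N*H, L, S)
out = (att_w @ v).transpose(0, 1).reshape(L, N, E)
out = out @ mha.out_proj.weight.T + mha.out_proj.bias

print(torch.allclose(ref_out, out, atol=1e-6))                                   # True
print(torch.allclose(ref_w, att_w.reshape(N, H, L, S).mean(dim=1), atol=1e-6))   # True

The head split mirrors PyTorch's internal layout: after projection, the embedding axis is carved into H contiguous chunks of size d, and batch and head are merged into one leading dimension of size N*H before the batched matmuls.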
