
# Source: https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
import torch
import torch.nn as nn

batch = 2

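n_q, n_k, n_v = 3, 4, 4    # n_k and n_v must be equal: both are the source length S
d_q, d_k, d_v = 5, 5, 5    # must be equal here, because kdim/vdim default to embed_dim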

embed_dim = d_q
num_heads = 8
dmodel = embed_dim * num_heads

q = torch.randn(batch, n_q, d_q)  # [2, 3, 5]
k = torch.randn(batch, n_k, d_k)  # [2, 4, 5]
v = torch.randn(batch, n_v, d_v)  # [2, 4, 5]
mask = torch.zeros(batch, n_q, n_k).bool()  # not used below; a 3-D attn_mask would need shape [batch*num_heads, n_q, n_k]
# Repeat each vector num_heads times along the last dim: 5 -> 5 * 8 = 40 = dmodel
qm = q.tile(1, 1, num_heads)
km = k.tile(1, 1, num_heads)
vm = v.tile(1, 1, num_heads)
print("\nq.shape=", q.shape)
print('k.shape=', k.shape)
print('v.shape=', v.shape)
print("qm.shape=", qm.shape)
print('km.shape=', km.shape)
print('vm.shape=', vm.shape)
print('embed_dim=', embed_dim)
print('num_heads=', num_heads)
print('dmodel=', dmodel)
multihead_attn = nn.MultiheadAttention(dmodel, num_heads, batch_first=True)
att_o, att_o_w = multihead_attn(qm, km, vm)
print('att_o.shape=', att_o.size())
print('att_o_w.shape=', att_o_w.size())


Output:

q.shape= torch.Size([2, 3, 5])
k.shape= torch.Size([2, 4, 5])
v.shape= torch.Size([2, 4, 5])
qm.shape= torch.Size([2, 3, 40])
km.shape= torch.Size([2, 4, 40])
vm.shape= torch.Size([2, 4, 40])
embed_dim= 5
num_heads= 8
dmodel= 40
att_o.shape= torch.Size([2, 3, 40])
att_o_w.shape= torch.Size([2, 3, 4])
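The example passes batch_first=True, so every input is laid out as (N, L, E). As a minimal sketch (assuming PyTorch ≥ 1.9, where batch_first exists; the names mha_bf and mha_sf are illustrative), a second module with the same weights but batch_first=False gives the same result once the batch and sequence axes are swapped:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, n_q, n_k, dmodel, num_heads = 2, 3, 4, 40, 8

mha_bf = nn.MultiheadAttention(dmodel, num_heads, batch_first=True)   # expects (N, L, E)
mha_sf = nn.MultiheadAttention(dmodel, num_heads, batch_first=False)  # expects (L, N, E)
mha_sf.load_state_dict(mha_bf.state_dict())  # identical weights; only the layout differs

qm = torch.randn(batch, n_q, dmodel)
km = torch.randn(batch, n_k, dmodel)
vm = torch.randn(batch, n_k, dmodel)

out_bf, w_bf = mha_bf(qm, km, vm)
# Swap axes 0 and 1 to feed the sequence-first module
out_sf, w_sf = mha_sf(qm.transpose(0, 1), km.transpose(0, 1), vm.transpose(0, 1))

print(out_sf.shape)  # (L, N, E) = (3, 2, 40)
print(w_sf.shape)    # attention weights are (N, L, S) = (2, 3, 4) in both layouts
print(torch.allclose(out_bf, out_sf.transpose(0, 1), atol=1e-5))
```

Only the output tensor changes layout; the returned attention weights keep the (N, L, S) shape either way.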


query (Tensor) – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.

key (Tensor) – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.

value (Tensor) – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

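The quoted shapes can be checked directly. The sketch below is illustrative and makes some assumptions: it passes the kdim/vdim constructor arguments so that E_k and E_v differ from E_q (the values 7 and 9 are arbitrary), and the unbatched call assumes PyTorch ≥ 1.11. Key and value still share the same source length S.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

N, L, S = 2, 3, 4          # batch size, target length, source length
E_q, E_k, E_v = 40, 7, 9   # query / key / value embedding dims (arbitrary example values)
num_heads = 8              # E_q must be divisible by num_heads

# kdim / vdim allow key and value to have last dims different from embed_dim
mha = nn.MultiheadAttention(E_q, num_heads, kdim=E_k, vdim=E_v, batch_first=True)

query = torch.randn(N, L, E_q)  # (N, L, E_q) because batch_first=True
key = torch.randn(N, S, E_k)    # (N, S, E_k)
value = torch.randn(N, S, E_v)  # (N, S, E_v), same S as key

out, weights = mha(query, key, value)
print(out.shape)      # (N, L, E_q) = (2, 3, 40)
print(weights.shape)  # (N, L, S)   = (2, 3, 4)

# Unbatched input: drop the N dimension entirely
out_u, w_u = mha(query[0], key[0], value[0])
print(out_u.shape, w_u.shape)  # (L, E_q) = (3, 40) and (L, S) = (3, 4)
```

The output always has the query's embedding size E_q, since the final output projection maps back to embed_dim.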
2. What “Attention Is All You Need” says about the dimensions of Q, K and V

3.2.2 Multi-Head Attention

Instead of performing a single attention function with d_model-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to d_k, d_k and d_v dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding d_v-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2.
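For reference, the paper writes this procedure (together with scaled dot-product attention from Section 3.2.1) as the following equations:

```latex
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}(QW^Q_i,\; KW^K_i,\; VW^V_i) \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
\end{aligned}
```

Following the shapes for n queries: QW^Q_i is n × d_k, each head_i is n × d_v, the concatenation is n × h d_v, and multiplying by W^O brings the result back to n × d_model.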


[Figure 2: Multi-Head Attention]



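In the paper, the projections are parameter matrices

W^Q_i\in \mathbb{R}^{d_{model}\times d_k}

W^K_i\in \mathbb{R}^{d_{model}\times d_k}

W^V_i\in \mathbb{R}^{d_{model}\times d_v}

W^O\in \mathbb{R}^{h d_v\times d_{model}}

→ Because Q, K and V are multiplied on the right by W^Q_i, W^K_i and W^V_i, the rules of matrix multiplication require the last dimension of Q, K and V to be d_model. (nn.MultiheadAttention relaxes this for key and value when kdim and vdim are set.)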
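To make these shapes concrete, here is a from-scratch sketch with explicit per-head matrices W^Q_i, W^K_i, W^V_i and W^O. The matrices are random stand-ins for learned weights, and the sizes (d_model = 40, h = 8) are borrowed from the example at the top rather than from the paper. This is not how nn.MultiheadAttention stores its weights; PyTorch packs the Q, K and V projections into a single in_proj_weight.

```python
import math
import torch

torch.manual_seed(0)

d_model, h = 40, 8
d_k = d_v = d_model // h          # 5 per head
N, L, S = 2, 3, 4

Q = torch.randn(N, L, d_model)    # last dim must be d_model to multiply by W_i^Q
K = torch.randn(N, S, d_model)
V = torch.randn(N, S, d_model)

# Random stand-ins for the learned projections
W_Q = [torch.randn(d_model, d_k) for _ in range(h)]   # each W_i^Q: d_model x d_k
W_K = [torch.randn(d_model, d_k) for _ in range(h)]   # each W_i^K: d_model x d_k
W_V = [torch.randn(d_model, d_v) for _ in range(h)]   # each W_i^V: d_model x d_v
W_O = torch.randn(h * d_v, d_model)                   # W^O: h*d_v x d_model

def attention(q, k, v):
    # Scaled dot-product attention: softmax(q k^T / sqrt(d_k)) v
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v

heads = [attention(Q @ W_Q[i], K @ W_K[i], V @ W_V[i]) for i in range(h)]  # each (N, L, d_v)
concat = torch.cat(heads, dim=-1)   # (N, L, h*d_v)
out = concat @ W_O                  # (N, L, d_model)
print(out.shape)                    # (2, 3, 40)

# A query whose last dim is not d_model cannot be projected
shape_error = False
try:
    torch.randn(N, L, 5) @ W_Q[0]
except RuntimeError:
    shape_error = True
print(shape_error)
```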

In this work we employ h = 8 parallel attention layers, or heads. For each of these we use d_k = d_v = d_model/h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
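The per-head size can be read off nn.MultiheadAttention, which sets head_dim = embed_dim // num_heads. A small sketch (the embed_dim values are chosen for illustration: 512 matches the paper, 40 matches the example at the top) also shows one way to see the cost claim: per projection, the h heads together hold as many weights as a single full-width projection.

```python
import torch.nn as nn

# Paper setting: d_model = 512, h = 8  ->  d_k = d_v = 512 / 8 = 64
paper = nn.MultiheadAttention(embed_dim=512, num_heads=8)
print(paper.head_dim)               # 64
print(paper.in_proj_weight.shape)   # Q, K, V projections packed together: (3 * 512, 512)

# Example at the top: dmodel = 5 * 8 = 40, h = 8  ->  each head works on 5 dims (= d_q)
example = nn.MultiheadAttention(embed_dim=40, num_heads=8, batch_first=True)
print(example.head_dim)             # 5

# Weight count per projection: h heads x (d_model x d_k) == d_model x d_model
d_model, h = 512, 8
d_k = d_model // h
same_size = h * d_model * d_k == d_model * d_model
print(same_size)

# embed_dim must be divisible by num_heads, otherwise construction fails
rejected = False
try:
    nn.MultiheadAttention(embed_dim=42, num_heads=8)
except (AssertionError, ValueError):
    rejected = True
print(rejected)
```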




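With batch_first=True: attn_output → [N, L, E], attn_output_weights → [N, L, S], where E = embed_dim. By default the weights are averaged over the heads. In the example above this gives [2, 3, 40] and [2, 3, 4].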
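To see the weights of each head separately, nn.MultiheadAttention accepts average_attn_weights=False (assumed PyTorch ≥ 1.11). A minimal sketch with the example's sizes (N = 2, L = 3, S = 4, E = 40, 8 heads) shows that the per-head weights have shape [N, num_heads, L, S] and that their mean over heads reproduces the default averaged weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, L, S, E, num_heads = 2, 3, 4, 40, 8

mha = nn.MultiheadAttention(E, num_heads, batch_first=True)  # dropout defaults to 0.0
q = torch.randn(N, L, E)
k = torch.randn(N, S, E)
v = torch.randn(N, S, E)

out, w_avg = mha(q, k, v)                              # weights averaged over heads
_, w_heads = mha(q, k, v, average_attn_weights=False)  # one weight matrix per head

print(out.shape)      # [N, L, E] = (2, 3, 40)
print(w_avg.shape)    # [N, L, S] = (2, 3, 4)
print(w_heads.shape)  # [N, num_heads, L, S] = (2, 8, 3, 4)

# Averaging the per-head weights gives the default output
print(torch.allclose(w_heads.mean(dim=1), w_avg, atol=1e-6))
# Each row is a softmax over the S source positions, so it sums to 1
print(torch.allclose(w_avg.sum(dim=-1), torch.ones(N, L), atol=1e-6))
```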



