A usage example of torch.nn.MultiheadAttention

# Source: https://towardsdatascience.com/illustrated-self-attention-2d627e33b20a
import torch
import torch.nn as nn

batch = 2

# MHA toy sizes: sequence lengths and per-token embedding dims
n_q, n_k, n_v = 3, 4, 4  # query / key / value sequence lengths (n_k must equal n_v)
d_q, d_k, d_v = 5, 5, 5  # embedding dims; must all be equal here (kdim/vdim are left at their defaults)

embed_dim = d_q
num_heads = 8
dmodel = embed_dim * num_heads

q = torch.randn(batch, n_q, d_q)  # [2, 3, 5]
k = torch.randn(batch, n_k, d_k)  # [2, 4, 5]
v = torch.randn(batch, n_v, d_v)  # [2, 4, 5]
mask = torch.zeros(batch, n_q, n_k).bool()  # all-False mask (defined here but not used below)
# tile the 5-dim embeddings num_heads times so the last dim equals dmodel = 40
qm = q.tile(1, 1, num_heads)
km = k.tile(1, 1, num_heads)
vm = v.tile(1, 1, num_heads)
print("\nq.shape=", q.shape)
print('k.shape=', k.shape)
print('v.shape=', v.shape)
print("qm.shape=", qm.shape)
print('km.shape=', km.shape)
print('vm.shape=', vm.shape)
print('embed_dim=', embed_dim)
print('num_heads=', num_heads)
print('dmodel=', dmodel)
###############################################################
multihead_attn = nn.MultiheadAttention(dmodel, num_heads, batch_first=True)
att_o, att_o_w = multihead_attn(qm, km, vm)  # attn_output, attn_output_weights
print('att_o.shape=', att_o.size())
print('att_o_w.shape=', att_o_w.size())

Run results:

q.shape= torch.Size([2, 3, 5])
k.shape= torch.Size([2, 4, 5])
v.shape= torch.Size([2, 4, 5])
qm.shape= torch.Size([2, 3, 40])
km.shape= torch.Size([2, 4, 40])
vm.shape= torch.Size([2, 4, 40])
embed_dim= 5
num_heads= 8
dmodel= 40
att_o.shape= torch.Size([2, 3, 40])
att_o_w.shape= torch.Size([2, 3, 4])
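These printed shapes match the documented convention for batch_first=True: attn_output is (N, L, E) and attn_output_weights is (N, L, S), the weights being averaged over the heads by default. A minimal self-contained check of the same shapes (variable names here are ad hoc, not from the snippet above):

import torch
import torch.nn as nn

N, L, S, E = 2, 3, 4, 40                               # batch, target len, source len, embed_dim
mha = nn.MultiheadAttention(E, num_heads=8, batch_first=True)
out, w = mha(torch.randn(N, L, E), torch.randn(N, S, E), torch.randn(N, S, E))
assert out.shape == (N, L, E)                          # torch.Size([2, 3, 40])
assert w.shape == (N, L, S)                            # torch.Size([2, 3, 4])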

1. Explanation of MultiheadAttention's forward parameters from the PyTorch documentation

query (Tensor) – Query embeddings of shape (L,Eq) for unbatched input, (L,N,Eq) when batch_first=False or (N,L,Eq) when batch_first=True, where L is the target sequence length, N is the batch size, and Eq is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.


key (Tensor) – Key embeddings of shape (S,Ek) for unbatched input, (S,N,Ek) when batch_first=False or (N,S,Ek) when batch_first=True, where S is the source sequence length, N is the batch size, and Ek is the key embedding dimension kdim. See “Attention Is All You Need” for more details.


value (Tensor) – Value embeddings of shape (S,Ev) for unbatched input, (S,N,Ev) when batch_first=False or (N,S,Ev) when batch_first=True, where S is the source sequence length, N is the batch size, and Ev is the value embedding dimension vdim. See “Attention Is All You Need” for more details.

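The key and value descriptions above mention kdim and vdim. These are constructor arguments of nn.MultiheadAttention: when they are given, keys and values may use embedding dimensions different from embed_dim, and only the query's last dimension has to equal embed_dim. A short sketch, with kdim=12 and vdim=24 chosen arbitrarily for illustration:

import torch
import torch.nn as nn

N, L, S = 2, 3, 4
embed_dim, kdim, vdim = 40, 12, 24                     # Eq, Ek, Ev in the documentation above
mha = nn.MultiheadAttention(embed_dim, num_heads=8, kdim=kdim, vdim=vdim, batch_first=True)
q = torch.randn(N, L, embed_dim)                       # (N, L, Eq)
k = torch.randn(N, S, kdim)                            # (N, S, Ek)
v = torch.randn(N, S, vdim)                            # (N, S, Ev)
out, w = mha(q, k, v)
print(out.shape, w.shape)                              # torch.Size([2, 3, 40]) torch.Size([2, 3, 4])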

2. What "Attention Is All You Need" says about the dimensions of Q, K, and V

 3.2.2 Multi-Head Attention

Instead of performing a single attention function with dmodel-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values h times with different, learned linear projections to dk, dk and dv dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding dv-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure 2.


Figure 2: Multi-Head Attention

 

where the projections are parameter matrices

W^Q_i \in \mathbb{R}^{d_{model} \times d_k}, \quad W^K_i \in \mathbb{R}^{d_{model} \times d_k}, \quad W^V_i \in \mathbb{R}^{d_{model} \times d_v}, \quad W^O \in \mathbb{R}^{h d_v \times d_{model}}

→ By matrix multiplication, the last dimension of Q, K, and V must therefore all be dmodel.
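A small from-scratch sketch of this project / attend / concatenate / project pattern (purely illustrative, with randomly initialized projection matrices; this is not the internal implementation of nn.MultiheadAttention):

import torch

N, L, S = 2, 3, 4
d_model, h = 40, 8
d_k = d_v = d_model // h                                # 5 per head

Q = torch.randn(N, L, d_model)
K = torch.randn(N, S, d_model)
V = torch.randn(N, S, d_model)

W_q = torch.randn(h, d_model, d_k)                      # one W^Q_i per head
W_k = torch.randn(h, d_model, d_k)                      # one W^K_i per head
W_v = torch.randn(h, d_model, d_v)                      # one W^V_i per head
W_o = torch.randn(h * d_v, d_model)                     # W^O

heads = []
for i in range(h):
    q_i, k_i, v_i = Q @ W_q[i], K @ W_k[i], V @ W_v[i]  # (N,L,d_k), (N,S,d_k), (N,S,d_v)
    scores = q_i @ k_i.transpose(-2, -1) / d_k ** 0.5   # (N, L, S)
    heads.append(scores.softmax(dim=-1) @ v_i)          # (N, L, d_v)

out = torch.cat(heads, dim=-1) @ W_o                    # (N, L, h*d_v) -> (N, L, d_model)
print(out.shape)                                        # torch.Size([2, 3, 40])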

In this work we employ h = 8 parallel attention layers, or heads. For each of these we use dk = dv = dmodel/h = 64. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

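nn.MultiheadAttention follows the same convention: each head operates on head_dim = embed_dim // num_heads, so embed_dim must be divisible by num_heads. This is why the opening example tiles the 5-dimensional inputs up to dmodel = 5 * 8 = 40 instead of passing them in directly. A quick illustration (the exact error type and message may vary between PyTorch versions):

import torch.nn as nn

mha = nn.MultiheadAttention(40, 8, batch_first=True)   # OK: head_dim = 40 // 8 = 5
try:
    nn.MultiheadAttention(5, 8)                        # 5 is not divisible by 8
except AssertionError as err:
    print(err)                                         # e.g. "embed_dim must be divisible by num_heads"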

To summarize:

Q → [N, L, dmodel], K → [N, S, dmodel], V → [N, S, dmodel]

attn_output → [N, L, E] (with E = embed_dim = dmodel here), attn_output_weights → [N, L, S]

For the example at the beginning:

Q → [2, 3, 40], K → [2, 4, 40], V → [2, 4, 40]

attn_output → [2, 3, 40], attn_output_weights → [2, 3, 4]
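In recent PyTorch versions attn_output_weights is averaged across heads by default; passing average_attn_weights=False to forward returns the per-head weights of shape (N, num_heads, L, S) instead. A short sketch using the same sizes as the example:

import torch
import torch.nn as nn

N, L, S, E, h = 2, 3, 4, 40, 8
mha = nn.MultiheadAttention(E, h, batch_first=True)
q, k, v = torch.randn(N, L, E), torch.randn(N, S, E), torch.randn(N, S, E)

_, w_avg = mha(q, k, v)                                # default: averaged over the heads
_, w_all = mha(q, k, v, average_attn_weights=False)    # per-head attention weights
print(w_avg.shape)                                     # torch.Size([2, 3, 4])    -> (N, L, S)
print(w_all.shape)                                     # torch.Size([2, 8, 3, 4]) -> (N, h, L, S)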
