Official documentation: MultiheadAttention — PyTorch 1.12 documentation
Walkthrough of the official parameter descriptions:
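The structure as a whole is called a single attention head, because the computation produces only one output per input. In a network, that output is usually regarded as a feature the network has extracted, and we naturally want to extract several kinds of features. [For example, if my input is a sequence of vectors from a picture of a puppy, I want the network to pick up shape, color, texture, and so on, so a single pass of attention is not enough.]
The purple block (Scaled Dot-Product Attention) is the single attention head from the previous figure, with its internal structure not drawn. When h single heads are concatenated, they are arranged as the figure shows.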
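The idea of running h single heads and concatenating them can be sketched by hand. This is only an illustrative sketch: the sizes, the random projection matrices `w_q`, `w_k`, `w_v`, `w_o`, and the helper name `scaled_dot_product_attention` are assumptions made for the example, not PyTorch's internal implementation.

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    """Single attention head: softmax(q k^T / sqrt(d_k)) v."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (L, S)
    weights = torch.softmax(scores, dim=-1)             # each row sums to 1
    return weights @ v                                  # (L, d_head)

torch.manual_seed(0)
L, S, d_model, h = 4, 6, 8, 2   # illustrative sizes, not from the article
d_head = d_model // h

x_q = torch.rand(L, d_model)    # target sequence
x_kv = torch.rand(S, d_model)   # source sequence

# One hypothetical (w_q, w_k, w_v) projection triple per head.
heads_out = []
for _ in range(h):
    w_q, w_k, w_v = (torch.rand(d_model, d_head) for _ in range(3))
    heads_out.append(
        scaled_dot_product_attention(x_q @ w_q, x_kv @ w_k, x_kv @ w_v)
    )

# Concatenate the h head outputs, then mix them with an output projection.
w_o = torch.rand(h * d_head, d_model)
multi_head_out = torch.cat(heads_out, dim=-1) @ w_o
print(multi_head_out.shape)  # torch.Size([4, 8])
```

Each head sees the same input but uses its own projections, which is what lets different heads pick up different features.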
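import torch
import torch.nn as nn
# Choose the parameters first
dims = 256 * 10  # total embedding dimension shared by all heads (embed_dim)
heads = 10  # number of attention heads
dropout_pro = 0.0  # dropout probability applied to the attention weights
# Pass the parameters in to build the multi-head attention layer
layer = nn.MultiheadAttention(embed_dim=dims, num_heads=heads, dropout=dropout_pro)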
embed_dim – Total dimension of the model (the total input dimension).
num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
Note the sentence in parentheses: each head has dimension embed_dim // num_heads, so embed_dim must be divisible by num_heads.
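The split across heads can be checked directly. The sketch below uses the article's values (256 * 10 and 10) and reads the module's `head_dim` attribute. The second call uses made-up values (100 and 3) to show that a non-divisible combination is rejected. The exact exception type is an assumption on my part, so the example catches both likely candidates.

```python
import torch

embed_dim, num_heads = 256 * 10, 10
layer = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
print(layer.head_dim)  # 256

# 100 is not divisible by 3, so the constructor should refuse it.
try:
    torch.nn.MultiheadAttention(embed_dim=100, num_heads=3)
    rejected = False
except (AssertionError, ValueError) as err:
    rejected = True
    print("rejected:", err)
```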
dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
bias – If specified, adds bias to input / output projection layers. Default: True
add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False
add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False
kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
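The effect of batch_first is easiest to see from tensor shapes. A minimal sketch, with sizes chosen only for illustration:

```python
import torch

N, L, E = 2, 5, 16  # batch size, sequence length, embed_dim (illustrative)

seq_first = torch.nn.MultiheadAttention(embed_dim=E, num_heads=4)
batch_first = torch.nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)

x_seq = torch.rand(L, N, E)      # (seq, batch, feature): the default layout
x_batch = x_seq.transpose(0, 1)  # (batch, seq, feature)

out_seq, _ = seq_first(x_seq, x_seq, x_seq)
out_batch, _ = batch_first(x_batch, x_batch, x_batch)
print(out_seq.shape)    # torch.Size([5, 2, 16])
print(out_batch.shape)  # torch.Size([2, 5, 16])
```

The output follows the same layout as the input, so the flag only needs to be set once, at construction time.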
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None, average_attn_weights=True)
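query – Query embeddings of shape (L, E_q) for unbatched input, (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.
In other words, for unbatched input the query has shape (L, E_q), where L is the target sequence length and E_q is the query embedding dimension, that is, the dimension of q once the input word vector has been turned into q. The doc gives this as embed_dim, so the input word vectors and q have the same dimension.
For batched input, the query has shape (L, N, E_q) when batch_first=False, or (N, L, E_q) when batch_first=True.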
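key – Key embeddings of shape (S, E_k) for unbatched input, (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.
So the key has shape (S, N, E_k) in the default layout. E_k is the key embedding dimension, which by default is also the same as E_q (that is, embed_dim), and S is the source sequence length.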
value – Value embeddings of shape (S, E_v) for unbatched input, (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.
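The query, key and value shapes can be told apart by giving every dimension a different size. The sketch below assumes the default batch_first=False layout. All sizes, including the kdim and vdim values, are made up for illustration.

```python
import torch

L, S, N = 3, 7, 2         # target length, source length, batch size
E, E_k, E_v = 16, 12, 20  # embed_dim, kdim, vdim (illustrative values)

mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=4, kdim=E_k, vdim=E_v)

query = torch.rand(L, N, E)    # (L, N, E_q)
key = torch.rand(S, N, E_k)    # (S, N, E_k)
value = torch.rand(S, N, E_v)  # (S, N, E_v)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([3, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 3, 7])
```

The output takes its sequence length from the query (L) and its feature size from embed_dim, while key and value only need to agree with each other on S.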
key_padding_mask – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
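A small sketch of key_padding_mask with a boolean mask. The batch contents and the choice of which positions count as padding are invented for the example. Dropout is left at its default of 0.0, so the attention rows still sum to 1.

```python
import torch

torch.manual_seed(0)
N, S, E = 2, 4, 8
mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=2, batch_first=True)
x = torch.rand(N, S, E)

# True marks a padded key position.
# Sample 0 has no padding; the last two tokens of sample 1 are padding.
key_padding_mask = torch.tensor([
    [False, False, False, False],
    [False, False, True, True],
])

_, weights = mha(x, x, x, key_padding_mask=key_padding_mask)
print(weights[1])  # the last two columns are zero
```

No query in sample 1 attends to the padded keys, and the remaining weights are renormalized over the two real tokens.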
need_weights – If specified, returns attn_output_weights in addition to attn_outputs. Default: True
attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N⋅num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight.
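A common use of a 2D attn_mask is a causal mask, where position i may only attend to positions up to i. The sketch below builds it as a boolean mask and as an equivalent float mask with -inf in the blocked entries. The sizes are illustrative, and the causal pattern is my choice of example, not something the article prescribes.

```python
import torch

torch.manual_seed(0)
L, N, E = 4, 1, 8
mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=2)
x = torch.rand(L, N, E)

# Boolean mask of shape (L, S): True above the diagonal blocks future positions.
causal_mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)

# Float version: the values are added to the attention scores before softmax.
float_mask = torch.zeros(L, L).masked_fill(causal_mask, float("-inf"))

_, w_bool = mha(x, x, x, attn_mask=causal_mask)
_, w_float = mha(x, x, x, attn_mask=float_mask)
print(w_bool[0])  # upper triangle is zero
```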
average_attn_weights – If true, indicates that the returned attn_weights should be averaged across heads. Otherwise, attn_weights are provided separately per head. Note that this flag only has an effect when need_weights=True. Default: True (i.e. average weights across heads)
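attn_output – Attention outputs of shape (L, E) when input is unbatched, (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.
The output has shape (L, N, E) in the default layout: L is the target sequence length, N is the batch size, and E is embed_dim (set when the layer was instantiated in the first step).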
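The unbatched and batched output shapes can be compared side by side. This sketch assumes a PyTorch version that accepts unbatched input, as the 1.12 documentation quoted here describes. The sizes are illustrative.

```python
import torch

L, N, E = 5, 3, 16
mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=4)

unbatched = torch.rand(L, E)   # (L, E): a single sequence, no batch dimension
batched = torch.rand(L, N, E)  # (L, N, E): default batch_first=False layout

out_u, _ = mha(unbatched, unbatched, unbatched)
out_b, _ = mha(batched, batched, batched)
print(out_u.shape)  # torch.Size([5, 16])
print(out_b.shape)  # torch.Size([5, 3, 16])
```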
attn_output_weights – Only returned when need_weights=True. If average_attn_weights=True, returns attention weights averaged across heads of shape (L, S) when input is unbatched or (N, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. If average_attn_weights=False, returns attention weights per head of shape (num_heads, L, S) when input is unbatched or (N, num_heads, L, S).
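The three ways the weights can come back (averaged, per head, or not at all) can be sketched as follows. The sizes are illustrative. The example assumes a PyTorch version that has the average_attn_weights argument, as in the documentation quoted here.

```python
import torch

torch.manual_seed(0)
L, S, N, E, H = 3, 6, 2, 16, 4
mha = torch.nn.MultiheadAttention(embed_dim=E, num_heads=H)
q = torch.rand(L, N, E)
kv = torch.rand(S, N, E)

_, w_avg = mha(q, kv, kv)                                # default: averaged over heads
_, w_heads = mha(q, kv, kv, average_attn_weights=False)  # one map per head
out, w_none = mha(q, kv, kv, need_weights=False)         # weights not returned

print(w_avg.shape)    # torch.Size([2, 3, 6])
print(w_heads.shape)  # torch.Size([2, 4, 3, 6])
print(w_none)         # None
```

Averaging the per-head maps over the head dimension reproduces the default output, which is a quick sanity check when inspecting attention.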
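multihead_attn = nn.MultiheadAttention(embed_dim=dims, num_heads=heads)
query = key = value = torch.rand(5, 2, dims)  # self-attention input of shape (L, N, E); L=5 and N=2 are illustrative
attn_output, attn_output_weights = multihead_attn(query, key, value)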