本文将对 Scaled Dot-Product Attention,Multi-head attention,Self-attention,Transformer等概念做一个简要介绍和区分。最后对通用的 Multi-head attention 进行代码实现和应用。
在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。
是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。
自注意力机制 是在Scaled Dot-Product Attention 以及Multi-head attention的基础上的一种应用场景,就是指 QKV的来源是相同的, 自己和自己计算attention,类似于经过一个线性层等,输入输出等长。
如果QKV的来源是不同的,不能叫做 self-attention,只能是attention。比如GST中的KV是随机初始化的多个token,而Q是reference encoder得到的梅尔谱的一帧。同理,Q也可以是随机初始化的一个,而KV是来自于输入,这样就可以将某一变长长度为N的输入计算attention得到一个长度为1的向量。
Transformer 是指 在Scaled Dot-Product Attention 以及Multi-head attention以及Self-attention的基础上的一种通用的模型框架,它包括Positional Encoding,Encoder,Decoder等等。Transformer不等于Self-attention。
平时经常会用到Attention操作,接下来对Multi-head Attention 进行代码整理和实现,方便以后可以直接调用接口,其中单头注意力机制作为其中的一种特殊情况。
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
'''
input:
query --- [N, T_q, query_dim]
key --- [N, T_k, key_dim]
mask --- [N, T_k]
output:
out --- [N, T_q, num_units]
scores -- [h, N, T_q, T_k]
'''
def __init__(self, query_dim, key_dim, num_units, num_heads):
super().__init__()
self.num_units = num_units
self.num_heads = num_heads
self.key_dim = key_dim
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
def forward(self, query, key, mask=None):
querys = self.W_query(query) # [N, T_q, num_units]
keys = self.W_key(key) # [N, T_k, num_units]
values = self.W_value(key)
split_size = self.num_units // self.num_heads
querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
## score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim ** 0.5)
## mask
if mask is not None:
## mask: [N, T_k] --> [h, N, T_q, T_k]
mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
scores = scores.masked_fill(mask, -np.inf)
scores = F.softmax(scores, dim=3)
## out = score * V
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
return out,scores
## 类实例化
attention = MultiHeadAttention(3,4,5,1)
## 输入
qurry = torch.randn(8, 2, 3)
key = torch.randn(8, 6 ,4)
mask = torch.tensor([[False, False, False, False, True, True],
[False, False, False, True, True, True],
[False, False, False, False, True, True],
[False, False, False, True, True, True],
[False, False, False, False, True, True],
[False, False, False, True, True, True],
[False, False, False, False, True, True],
[False, False, False, True, True, True],])
## 输出
out, scores = attention(qurry, key, mask)
print('out:', out.shape) ## torch.Size([8, 2, 5])
print('scores:', scores.shape) ## torch.Size([1, 8, 2, 6])