#encoding:utf-8
from math import sqrt
import torch
import torch.nn as nn
class Self_Attention(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        # linear projections for queries, keys and values
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        print("x.shape:", x.shape)
        Q = self.q(x)
        print("Q.shape:", Q.shape)
        K = self.k(x)
        print("K.shape:", K.shape)
        V = self.v(x)
        print("V.shape:", V.shape)
        # scale the logits before the softmax, not the attention weights after it
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self.norm_fact)
        output = torch.bmm(atten, V)
        return output
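For reference, the forward pass above is the standard scaled dot-product attention from "Attention Is All You Need", which is why the 1/sqrt(d_k) factor has to scale the logits inside the softmax rather than the attention weights after it:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V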
print("\n")
print("self attention:")
x = torch.randn(4, 3, 1024)   # (batch, seq_len, input_dim)
# print(x)
print("input size:", x.size())
self_attention = Self_Attention(1024, 128, 5)
res = self_attention(x)
# print("\n")
# print(res)
print("output size:", res.size())
print("\n")
class Self_Attention_Muti_Head(nn.Module):
    def __init__(self, input_dim, dim_k, dim_v, nums_head):
        super(Self_Attention_Muti_Head, self).__init__()
        assert dim_k % nums_head == 0
        assert dim_v % nums_head == 0
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self.nums_head = nums_head
        self.dim_k = dim_k
        self.dim_v = dim_v
        # scale by the per-head dimension, as in standard multi-head attention
        self._norm_fact = 1 / sqrt(dim_k // nums_head)

    def forward(self, x):
        batch, seq_len, _ = x.shape
        # project, then split into heads: (batch, nums_head, seq_len, head_dim)
        Q = self.q(x).reshape(batch, seq_len, self.nums_head, self.dim_k // self.nums_head).transpose(1, 2)
        K = self.k(x).reshape(batch, seq_len, self.nums_head, self.dim_k // self.nums_head).transpose(1, 2)
        V = self.v(x).reshape(batch, seq_len, self.nums_head, self.dim_v // self.nums_head).transpose(1, 2)
        print("x.shape:", x.shape)
        print("Q.shape:", Q.size())
        # scale the logits before the softmax
        atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.transpose(-2, -1)) * self._norm_fact)
        # merge the heads back: (batch, seq_len, dim_v)
        output = torch.matmul(atten, V).transpose(1, 2).reshape(batch, seq_len, -1)
        return output
print("\n")
print("multi head attention:")
x = torch.randn(4, 3, 1024)
# print(x)
print("input size:", x.size())
self_attention = Self_Attention_Muti_Head(1024, 128, 6, 2)
res = self_attention(x)
print("\n")
# print(res)
print("output size:", res.size())
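To sanity-check the hand-rolled head splitting, the module can be compared against torch.nn.functional.scaled_dot_product_attention. A minimal sketch, assuming PyTorch >= 2.0 (where that function is available); it reuses the module's own projection layers so both paths see identical weights, and the two results should agree up to floating-point tolerance:

import torch.nn.functional as F

mha = Self_Attention_Muti_Head(1024, 128, 6, 2)
x_check = torch.randn(4, 3, 1024)
with torch.no_grad():
    # same split as the module: 2 heads of size 64 for Q/K and size 3 for V
    q = mha.q(x_check).reshape(4, 3, 2, 64).transpose(1, 2)
    k = mha.k(x_check).reshape(4, 3, 2, 64).transpose(1, 2)
    v = mha.v(x_check).reshape(4, 3, 2, 3).transpose(1, 2)
    ref = F.scaled_dot_product_attention(q, k, v)   # scales by 1/sqrt(64) internally
    ref = ref.transpose(1, 2).reshape(4, 3, -1)
    print(torch.allclose(mha(x_check), ref, atol=1e-6))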
-----------------------------------------------------------------
A question:
Based on the paper https://arxiv.org/pdf/1911.02150.pdf, it feels like the Multi-Head Attention described here and Grouped-Query Attention mean the same thing,
i.e. the same as the "Grouped-query" variant in the classic figure below.
Where is my understanding off?
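For what it's worth, arXiv 1911.02150 is Shazeer's multi-query attention paper. Plain multi-head attention (as in the code above) gives every query head its own K/V head, multi-query attention shares a single K/V head across all query heads, and grouped-query attention sits in between, with one K/V head per group of query heads. A hypothetical shape sketch of the three variants, for illustration only (names like n_query_heads, n_kv_heads and d_head are made up here, not taken from the code above):

batch, seq_len, d_head = 4, 3, 64
n_query_heads = 8

# MHA (the code above): one K/V head per query head
n_kv_heads_mha = n_query_heads
# MQA (arXiv 1911.02150): a single K/V head shared by all query heads
n_kv_heads_mqa = 1
# GQA: groups of query heads share one K/V head (here 2 groups of 4)
n_kv_heads_gqa = 2

Q = torch.randn(batch, n_query_heads, seq_len, d_head)
for n_kv in (n_kv_heads_mha, n_kv_heads_mqa, n_kv_heads_gqa):
    K = torch.randn(batch, n_kv, seq_len, d_head)
    V = torch.randn(batch, n_kv, seq_len, d_head)
    # broadcast each K/V head to the query heads in its group before the matmul
    group = n_query_heads // n_kv
    K_full = K.repeat_interleave(group, dim=1)
    V_full = V.repeat_interleave(group, dim=1)
    atten = torch.softmax(Q @ K_full.transpose(-2, -1) / d_head ** 0.5, dim=-1)
    out = atten @ V_full
    print(n_kv, "KV heads ->", out.shape)   # always (4, 8, 3, 64)

In all three cases the output shape is the same; what changes is only the number of distinct K/V heads (and hence the size of the KV cache).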