多头注意力机制

前面已经讲完了自注意力机制,简单来讲,就是对一组向量空间分别求内积,然后进行缩放,最后对不同的向量使用压缩后的分数累加求和。

1.多头是个什么东西?

        实际上很简单,自注意力层的输出空间被分解为一组独立的子空间,对这些子空间分别进行学习,也就是说,初始的Q,K,V三组独立的密集投影生成三组独立的向量[1],每个向量都通过神经注意力进行处理,然后将多个输出拼接为一个输出序列[2],然后将输出序列经过线性变换[3],每个这样的子空间叫做一个头。密集投影层是可学习层,因此投影过程是可以学习的,独立的头也有助于该层为每个词元学习多组特征,其中每一组内的特征彼此相关,但与其他组的特征几乎无关。

我标记出了三个位置,这三个位置的描述就是实现多头注意力的关键

按照之前我们实现了一个注意力层,我们将其打包为attention(q,k,v)

(1).Q,K,V三组投影,实际上就是线性变化Y = W X

newQ = W_q*Q\\ newK = W_k*K\\ newV = W_v*V

import numpy as np

#假设有矩阵Q,K,V,矩阵大小都一样,[batch_size, N, feature_numbers]
head_num = 3 #三个头

#这里的w矩阵需要能够学习,这里是选择了一个初始化为0的矩阵
w_q = np.random.random((head_num, N, feature_numbers))
w_k = np.random.random((head_num, N, feature_numbers))
w_v = np.random.random((head_num, N, feature_numbers))

#线性变换
newQ = np.matmul(w_q, Q)
newK = np.matmul(w_k, K)
newV = np.matmul(w_v, V)

#使用多头注意力
result = attention(newQ, newK, newV)

#这里只能算伪代码了
#拼接多个头,假设各个矩阵大小一样,因此可以直接转换维度作为拼接
output = result.reshape(Q.shape, head_num)

#最终输出到密集层
head_output = output * Wo

然后经过注意力机制,生成一个头,这是其中一个头而已,根据需要可以生成多个

h_i=attention(newQ, newK, newV)

(2).拼接多个头

output = concat(h_1,h_2,h_3,...,h_n)

(3).全连接

result = W*output

这个代码顶多算伪代码,以后有空修改吧

你可能感兴趣的:(人工智能,算法,机器学习)