一种可理解的线性transformer

transformer O ( n 2 ) O(n^2) O(n2)复杂度的关键点在于,对每个token都查询了一次。因此降低复杂度的一个行之有效的方法是降低查询的次数。因此提出竞争查询的方法。公式如下:

Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=ijexp(zj)exp(zi)Qi

x n ′ = x n + e x p ( z n ) ∑ j e x p ( z j ) ∑ m ( e x p ( Q ∗ K m ) ∑ k e x p ( Q ∗ K k ) V m ) x_n' = x_n + \dfrac{exp(z_n)}{\sum_jexp(z_j)}\sum_m (\dfrac{exp( Q* K_m)}{\sum_kexp(Q*K_k)}V_m) xn=xn+jexp(zj)exp(zn)m(kexp(QKk)exp(QKm)Vm)

Z Z Z向量为竞争向量,通过softmax归一化得到分布在tokens上的权重,根据 Z Z Z的权重对所有的query向量 Q i Q_i Qi进行求和,得到竞争成功的 Q Q Q向量。可以理解为这一步将所有要查询的东西编码到同一个向量中。然后正常按照transformer的办法用Q向量与每个token的key向量 K m K_m Km V m V_m Vm得到更新向量。然后按照 Q Q Q向量的比例,依次按照比例把更新向量赋值给所有的token。可以看出当 Q i Q_i Qi的比例为极限情况(0,0,…,1,…,0,0)时,相当于只对比例为1的token做查询。

另外一种公式是:

z i = w x i z_i = w x_i zi=wxi

Q = ∑ i e x p ( z i ) ∑ j e x p ( z j ) Q i Q = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}Q_i Q=ijexp(zj)exp(zi)Qi

V = ∑ i e x p ( z i ) ∑ j e x p ( z j ) V i V = \sum_i\dfrac{exp(z_i)}{\sum_jexp(z_j)}V_i V=ijexp(zj)exp(zi)Vi

x n ′ = x n + r e l u ( Q K n ) V x_n' = x_n + relu(QK_n)V xn=xn+relu(QKn)V

也是具备非常明确的意义。

第二种存在GPT形式(casual mask attention):

要依次根据mask,生成每个token的 Q n , V n Q^n,V^n Qn,Vn。每个token的更新如下:

Q n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) Q i Q^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}Q_i Qn=i=0nj=0nexp(zj)exp(zi)Qi

V n = ∑ i = 0 n e x p ( z i ) ∑ j = 0 n e x p ( z j ) V i V^n = \sum_{i = 0}^n\dfrac{exp(z_i)}{\sum_{j = 0}^nexp(z_j)}V_i Vn=i=0nj=0nexp(zj)exp(zi)Vi

x n ′ = x n + r e l u ( Q n K n ) V n x_n' = x_n + relu(Q^nK_n)V^n xn=xn+relu(QnKn)Vn

注意 Q n Q_n Qn的计算是存在递归公式的,因此其复杂度为seq len线性相关,缺点是无法并行。

第二个公式是有明确意义的。每一层都预定了一个w,用来判断query的重要程度。因此w是一个先验知识,用来决定应该先查询什么,然后再查询什么。但是真正的查询向量Q又是和序列有关的,不同的序列有不同的查询向量。

你可能感兴趣的:(python,算法,机器学习,transformer)