I could never quite figure out einsum, so today I spent some time learning it. These are my study notes.
First, I found this video:
https://www.youtube.com/watch?v=CLrTj7D2fLM&ab_channel=FacultyofKhan
Why use einsum? Because it is extremely convenient and compact.
The video gives four rules. For example, for ij,j->i:
i is the free index, occurs only once in the expression and cannot be replaced by another free index
Even with the rules written out, I still didn't really get it. My intuition is: any index that does not appear in the output of the expression gets summed over, while the indices that do appear are only looped over.
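To check that intuition, here is a tiny snippet of my own (not from the video): in ij,j->i, the index j is absent from the output, so it gets summed; i is present, so it is just a loop.

import numpy as np

A = np.random.rand(2, 3)
x = np.random.rand(3)

# i appears in the output -> plain loop; j does not -> it gets summed over
out = np.zeros(2)
for i in range(2):
    for j in range(3):
        out[i] += A[i, j] * x[j]

print(np.allclose(out, np.einsum('ij,j->i', A, x)))  # True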
Then I watched
https://www.youtube.com/watch?v=pkVwUVEHmfI&ab_channel=AladdinPersson
which works through a lot of examples, such as matrix multiplication:
$$M = A \times B$$
$$M_{i,j} = \sum_k A_{i,k} \times B_{k,j}$$
$$\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} = \begin{bmatrix} 22 & 28 \\ 49 & 64 \end{bmatrix}$$
Implemented with numpy:
import numpy as np

A = np.arange(1, 6+1).reshape(2, 3)
B = np.arange(1, 6+1).reshape(3, 2)
print(A)
print(B)

M = np.empty((2, 2))
for i in range(2):
    for j in range(2):
        total = 0
        for k in range(3):
            total += A[i, k] * B[k, j]
        M[i, j] = total
print(M)
Written with einsum, this is just ik,kj->ij.
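Continuing the snippet above (A, B and the loop result M still in scope), the einsum one-liner gives the same answer as the triple loop and as A @ B:

M_einsum = np.einsum('ik,kj->ij', A, B)
print(M_einsum)
print(np.allclose(M, M_einsum))      # True
print(np.allclose(M_einsum, A @ B))  # True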
The video also covers many other examples (a few of them are checked in the sketch after this list):
ij->ji (transpose)
ij-> (sum of all elements)
ij->j (column sums)
ij->i (row sums)
ij,j->i (matrix-vector multiplication)
ij,kj->ik (matrix multiplication with the second operand transposed)
i,i-> (vector dot product)
ij,ij-> (sum of the element-wise product)
ij,ij->ij (element-wise / Hadamard product)
i,j->ij (outer product)
ijk,ikl->ijl (batch matrix multiplication)
ii->i (matrix diagonal)
ii-> (trace)
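Here is a quick sketch of my own (arbitrary values, not from the video) checking a few of these expressions against the equivalent numpy calls:

import numpy as np

A = np.arange(6).reshape(2, 3).astype(float)
B = np.arange(6).reshape(2, 3).astype(float)
u = np.arange(3).astype(float)
v = np.arange(4).astype(float)
S = np.arange(9).reshape(3, 3).astype(float)

print(np.allclose(np.einsum('ij->ji', A), A.T))                 # transpose
print(np.allclose(np.einsum('ij->', A), A.sum()))               # sum of all elements
print(np.allclose(np.einsum('ij->j', A), A.sum(axis=0)))        # column sums
print(np.allclose(np.einsum('ij->i', A), A.sum(axis=1)))        # row sums
print(np.allclose(np.einsum('ij,ij->ij', A, B), A * B))         # element-wise product
print(np.allclose(np.einsum('i,j->ij', u, v), np.outer(u, v)))  # outer product
print(np.allclose(np.einsum('ii->i', S), np.diag(S)))           # diagonal
print(np.allclose(np.einsum('ii->', S), np.trace(S)))           # trace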
On Zhihu, the article
https://zhuanlan.zhihu.com/p/44954540
also writes up a few more examples.
Tensor contraction: pqrs,tuqvr->pstuv
https://www.wikiwand.com/en/Tensor_contraction
I don't really understand this one:
$$C_{pstuv} = \sum_q \sum_r A_{pqrs} B_{tuqvr} = A_{pqrs} B_{tuqvr}$$
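To at least convince myself about the shapes, here is my own small check with arbitrary dimension sizes; np.tensordot contracting A's q, r axes with B's q, r axes gives exactly the pstuv ordering:

import numpy as np

# arbitrary sizes: p=2, q=3, r=4, s=5, t=6, u=7, v=8
A = np.random.rand(2, 3, 4, 5)      # indices p, q, r, s
B = np.random.rand(6, 7, 3, 8, 4)   # indices t, u, q, v, r

C = np.einsum('pqrs,tuqvr->pstuv', A, B)
print(C.shape)  # (2, 5, 6, 7, 8)

# tensordot sums over the paired axes (q, r); the remaining axes of A come
# first (p, s), then the remaining axes of B (t, u, v) -> same as pstuv
C_ref = np.tensordot(A, B, axes=([1, 2], [2, 4]))
print(np.allclose(C, C_ref))  # True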
Bilinear transformation: ik,jkl,il->ij
https://pytorch.org/docs/master/nn.html#torch.nn.Bilinear
I don't understand this one either:
$$D_{ij} = \sum_k \sum_l A_{ik} B_{jkl} C_{il} = A_{ik} B_{jkl} C_{il}$$
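Another small check of my own, spelling ik,jkl,il->ij out as explicit sums over k and l (this is also, without the bias term, what the linked torch.nn.Bilinear computes, with B as the weight):

import numpy as np

i, j, k, l = 2, 3, 4, 5                 # batch, out_features, in1, in2
A = np.random.rand(i, k)                # first input
B = np.random.rand(j, k, l)             # bilinear weight
C = np.random.rand(i, l)                # second input

D = np.einsum('ik,jkl,il->ij', A, B, C)

# explicit double sum over k and l for each (i, j)
D_ref = np.zeros((i, j))
for a in range(i):
    for b in range(j):
        D_ref[a, b] = sum(A[a, kk] * B[b, kk, ll] * C[a, ll]
                          for kk in range(k) for ll in range(l))

print(np.allclose(D, D_ref))  # True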
I've come across einsum a few times in the open-source code of other people's papers; I've forgotten the earlier ones. This time it was the multi-head attention in Transformer-XL. The code is at
https://github.com/kimiyoung/transformer-xl/blob/master/tf/model.py#L42
def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
  scale = 1 / (d_head ** 0.5)
  with tf.variable_scope(scope):
    qlen = tf.shape(w)[0]
    rlen = tf.shape(r)[0]
    bsz = tf.shape(w)[1]

    cat = tf.concat([mems, w],
                    0) if mems is not None and mems.shape.ndims > 1 else w
    w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
                              kernel_initializer=kernel_initializer, name='qkv')
    r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
                               kernel_initializer=kernel_initializer, name='r')

    w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
    w_head_q = w_head_q[-qlen:]

    klen = tf.shape(w_head_k)[0]

    w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
    w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
    w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])

    r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

    rw_head_q = w_head_q + r_w_bias
    rr_head_q = w_head_q + r_r_bias

    AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
    BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
    BD = rel_shift(BD)

    attn_score = (AC + BD) * scale
    attn_mask_t = attn_mask[:, :, None, None]
    attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
    size_t = tf.shape(attn_vec)
    attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])

    attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
                               kernel_initializer=kernel_initializer, name='o')
    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)

    output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)

  return output
einsum is used in three places:
AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
With four dimensions involved, these are a bit harder to follow.
The first one, ibnd,jbnd->ijbn, looks a lot like batch matrix multiplication ijk,ikl->ijl. If you treat bn here as the i there, it becomes id,jd->ij, which is very close. In terms of the attention computation, this is $K^T Q$.
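A toy check of that reading (shapes made up by me): for every fixed (b, n) pair, the expression really is the id,jd->ij pattern, i.e. dot products of query and key vectors over the d axis:

import numpy as np

qlen, klen, bsz, n_head, d_head = 3, 4, 2, 2, 5
q = np.random.rand(qlen, bsz, n_head, d_head)   # ibnd
k = np.random.rand(klen, bsz, n_head, d_head)   # jbnd

AC = np.einsum('ibnd,jbnd->ijbn', q, k)

# same thing slice by slice: for each (b, n), it is id,jd->ij
AC_ref = np.empty((qlen, klen, bsz, n_head))
for b in range(bsz):
    for n in range(n_head):
        AC_ref[:, :, b, n] = np.einsum('id,jd->ij', q[:, b, n], k[:, b, n])

print(np.allclose(AC, AC_ref))  # True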
The second one, ibnd,jnd->ijbn: the second tensor is missing the b dimension, which makes it trickier; otherwise it would be the same as the first. To be honest, I still haven't fully understood this one.
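My current guess, which the following toy check of mine at least supports: since r_head_k has no b axis, the relative-position keys are shared across the batch, so this is the first pattern with r broadcast along b:

import numpy as np

qlen, rlen, bsz, n_head, d_head = 3, 4, 2, 2, 5
q = np.random.rand(qlen, bsz, n_head, d_head)   # ibnd
r = np.random.rand(rlen, n_head, d_head)        # jnd (no batch axis)

BD = np.einsum('ibnd,jnd->ijbn', q, r)

# broadcast r across the batch axis and reuse the first pattern
r_b = np.broadcast_to(r[:, None], (rlen, bsz, n_head, d_head))  # jbnd
BD_ref = np.einsum('ibnd,jbnd->ijbn', q, r_b)

print(np.allclose(BD, BD_ref))  # True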
The third one, ijbn,jbnd->ibnd: likewise, after dropping bn it becomes ij,jd->id, which again looks very much like batch matrix multiplication. In terms of the attention computation, this is $V \hat{A}$; the j dimension of ijbn carries the attention coefficients.
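And the corresponding toy check for the third one (mine): for each (b, n), it is ij,jd->id, a weighted sum over j of the value vectors:

import numpy as np

qlen, klen, bsz, n_head, d_head = 3, 4, 2, 2, 5
attn_prob = np.random.rand(qlen, klen, bsz, n_head)   # ijbn
v = np.random.rand(klen, bsz, n_head, d_head)         # jbnd

attn_vec = np.einsum('ijbn,jbnd->ibnd', attn_prob, v)

# slice by slice: ij,jd->id for each (b, n)
ref = np.empty((qlen, bsz, n_head, d_head))
for b in range(bsz):
    for n in range(n_head):
        ref[:, b, n] = np.einsum('ij,jd->id', attn_prob[:, :, b, n], v[:, b, n])

print(np.allclose(attn_vec, ref))  # True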
In self-attention, the concrete computation goes like this (reference: https://www.bilibili.com/video/BV1JE411g7XF?p=23).
For an input sequence $x$, first embed it:
$$a^i = W x^i$$
Then compute $q, k, v$:
$$q^i = W^q a^i, \quad k^i = W^k a^i, \quad v^i = W^v a^i$$
Next compute the attention. This is scaled dot-product attention; per the examples above, the dot product of two vectors is i,i->:
$$\alpha_{1,i} = \frac{q^1 \cdot k^i}{\sqrt{d}}, \qquad \hat{\alpha}_{1,i} = \frac{\exp(\alpha_{1,i})}{\sum_j \exp(\alpha_{1,j})}, \qquad b^1 = \sum_i \hat{\alpha}_{1,i} v^i$$
Written in matrix form:
$$Q = W^q I, \quad K = W^k I, \quad V = W^v I$$
$$A = K^T Q, \quad \hat{A} = \mathrm{softmax}(A), \quad O = V \hat{A}$$
In the actual matrix computation, $K$ needs to be transposed.
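A minimal numpy sketch of this matrix form, my own, with the vectors stored as columns as in the lecture; the softmax is taken over each column of A:

import numpy as np

d, seq_len = 4, 5
rng = np.random.default_rng(0)

I = rng.standard_normal((d, seq_len))        # embedded inputs a^i stacked as columns
W_q = rng.standard_normal((d, d))
W_k = rng.standard_normal((d, d))
W_v = rng.standard_normal((d, d))

Q, K, V = W_q @ I, W_k @ I, W_v @ I

A = np.einsum('di,dj->ij', K, Q) / np.sqrt(d)              # A = K^T Q, scaled
A_hat = np.exp(A) / np.exp(A).sum(axis=0, keepdims=True)   # softmax over each column
O = np.einsum('di,ij->dj', V, A_hat)                       # O = V A_hat

print(np.allclose(A * np.sqrt(d), K.T @ Q))  # True
print(np.allclose(O, V @ A_hat))             # True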
Multi-head attention (with two heads):
$$q^{i,1} = W^{q,1} q^i, \quad q^{i,2} = W^{q,2} q^i, \quad b^i = \mathrm{concat}(b^{i,1}, b^{i,2}), \quad b^i = W^O b^i$$
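A rough two-head sketch of my own to close the loop; note that, unlike the formulas above, I simply split Q, K, V into heads along the feature axis instead of applying separate W^{q,h} projections, which is a simplification:

import numpy as np

d, seq_len, n_head = 4, 6, 2
d_head = d // n_head
rng = np.random.default_rng(1)

Q = rng.standard_normal((d, seq_len))   # as before, columns are tokens
K = rng.standard_normal((d, seq_len))
V = rng.standard_normal((d, seq_len))
W_O = rng.standard_normal((d, d))

# split the feature axis into heads: (d, seq) -> (n_head, d_head, seq)
Qh = Q.reshape(n_head, d_head, seq_len)
Kh = K.reshape(n_head, d_head, seq_len)
Vh = V.reshape(n_head, d_head, seq_len)

A = np.einsum('hdi,hdj->hij', Kh, Qh) / np.sqrt(d_head)    # per-head K^T Q
A_hat = np.exp(A) / np.exp(A).sum(axis=1, keepdims=True)   # softmax over the key axis i
B = np.einsum('hdi,hij->hdj', Vh, A_hat)                   # per-head V A_hat

b = B.reshape(d, seq_len)   # concat the head outputs along the feature axis
out = W_O @ b               # final mixing, b^i = W^O b^i
print(out.shape)            # (4, 6)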
I'll keep studying the rest later.