Einsum Is All You Need

Basics

I could never quite figure out einsum, so today I spent some time learning it; these are my notes.
First I found this video:
https://www.youtube.com/watch?v=CLrTj7D2fLM&ab_channel=FacultyofKhan

  • Why use Einsum
    Extremely Convenient and Compact

  • There are 4 rules:

  1. Any twice-repeated index in a single term is summed over
  2. In ij,j->i, i is the free index; it occurs only once in the expression and cannot be arbitrarily replaced by another free index
  3. No index may occur 3 or more times in a given term
  4. In an equation involving Einstein notation, free indices on both sides must match

Even with the rules written down, I still didn't quite get it.
My intuition: any index that does not appear in the output is summed over; indices that do appear are simply kept as loop/output dimensions.
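For instance (a tiny check of my own, not from the video), with numpy:

import numpy as np

A = np.arange(6).reshape(2, 3)

# 'ij->ij' keeps both indices, so nothing is summed and A comes back unchanged
print(np.array_equal(np.einsum('ij->ij', A), A))             # True
# 'ij->i' drops j from the output, so j is summed over (row sums)
print(np.array_equal(np.einsum('ij->i', A), A.sum(axis=1)))  # True
# 'ij->' drops both indices, so everything is summed
print(np.einsum('ij->', A) == A.sum())                       # True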

Then I watched
https://www.youtube.com/watch?v=pkVwUVEHmfI&ab_channel=AladdinPersson
which walks through many examples.

For example, matrix multiplication:
$M = A \times B$

$M_{i,j} = \sum_k A_{i,k} \times B_{k,j}$

$\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix} = \begin{bmatrix} 22 & 28 \\ 49 & 64 \end{bmatrix}$
Implemented with numpy:

import numpy as np

A = np.arange(1, 6+1).reshape(2, 3)
B = np.arange(1, 6+1).reshape(3, 2)
print(A)
print(B)

# M[i, j] = sum_k A[i, k] * B[k, j], written out as explicit loops
M = np.empty((2, 2))
for i in range(2):
    for j in range(2):
        total = 0
        for k in range(3):
            total += A[i, k] * B[k, j]
        M[i, j] = total

print(M)

Written with einsum, this is ik,kj->ij.
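Checking it with np.einsum (reusing the A, B, and M from the loop version above):

M_einsum = np.einsum('ik,kj->ij', A, B)
print(M_einsum)
print(np.allclose(M, M_einsum))  # True: identical to the triple-loop result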

Common operations

The video also covers many other examples (a few of them are checked in the numpy sketch after this list):

  • Permutation of Tensors ij->ji
  • Summation ij->
  • Column Sum ij->j
  • Row Sum ij->i
  • Matrix-Vector Multiplication ij,j->i
  • Matrix-Matrix Multiplication ij,kj->ik (as written, j is the shared index, so this computes A times the transpose of B)
  • Dot product with vector i,i->
  • Dot product with matrix ij,ij->
  • Hadamard Product (element-wise multiplication) ij,ij->ij
  • Outer Product i,j->ij
  • Batch Matrix Multiplication ijk,ikl->ijl
  • Matrix Diagonal ii->i
  • Matrix Trace ii->
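A quick numpy sketch of my own that checks several of these entries against the plain numpy equivalents:

import numpy as np

A = np.arange(6, dtype=float).reshape(2, 3)
B = np.arange(6, dtype=float).reshape(2, 3)
x = np.arange(3, dtype=float)
y = np.arange(2, dtype=float)
S = np.arange(9, dtype=float).reshape(3, 3)
T = np.random.rand(4, 2, 3)                                      # batch of 4 matrices
U = np.random.rand(4, 3, 5)

print(np.allclose(np.einsum('ij->ji', A), A.T))                  # permutation
print(np.allclose(np.einsum('ij->', A), A.sum()))                # summation
print(np.allclose(np.einsum('ij->j', A), A.sum(axis=0)))         # column sum
print(np.allclose(np.einsum('ij->i', A), A.sum(axis=1)))         # row sum
print(np.allclose(np.einsum('ij,j->i', A, x), A @ x))            # matrix-vector
print(np.allclose(np.einsum('ij,kj->ik', A, B), A @ B.T))        # matrix-matrix (= A @ B.T)
print(np.allclose(np.einsum('i,i->', x, x), x @ x))              # dot product, vectors
print(np.allclose(np.einsum('ij,ij->', A, B), (A * B).sum()))    # dot product, matrices
print(np.allclose(np.einsum('ij,ij->ij', A, B), A * B))          # Hadamard product
print(np.allclose(np.einsum('i,j->ij', x, y), np.outer(x, y)))   # outer product
print(np.allclose(np.einsum('ijk,ikl->ijl', T, U), T @ U))       # batch matmul
print(np.allclose(np.einsum('ii->i', S), np.diag(S)))            # diagonal
print(np.allclose(np.einsum('ii->', S), np.trace(S)))            # trace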

This article on Zhihu lists a few more:
https://zhuanlan.zhihu.com/p/44954540

  • Tensor contraction pqrs,tuqvr->pstuv
    https://www.wikiwand.com/en/Tensor_contraction
    I don't understand this one yet; the sketch after this list checks it against explicit loops.
    $C_{pstuv} = \sum_q \sum_r A_{pqrs} B_{tuqvr} = A_{pqrs} B_{tuqvr}$

  • Bilinear transform ik,jkl,il->ij
    https://pytorch.org/docs/master/nn.html#torch.nn.Bilinear
    Don't understand this one either; it is also checked in the sketch below.
    $D_{ij} = \sum_k \sum_l A_{ik} B_{jkl} C_{il} = A_{ik} B_{jkl} C_{il}$
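These two become clearer when checked against explicit loops. A numpy sketch of my own, with arbitrarily chosen shapes:

import numpy as np

# Tensor contraction pqrs,tuqvr->pstuv: sum over the shared indices q and r
A = np.random.rand(2, 3, 4, 5)                     # axes p, q, r, s
B = np.random.rand(6, 7, 3, 8, 4)                  # axes t, u, q, v, r
C = np.einsum('pqrs,tuqvr->pstuv', A, B)
C_loop = np.zeros((2, 5, 6, 7, 8))
for q in range(3):
    for r in range(4):
        # A[:, q, r, :] is (p, s); B[:, :, q, :, r] is (t, u, v)
        C_loop += np.multiply.outer(A[:, q, r, :], B[:, :, q, :, r])
print(np.allclose(C, C_loop))                      # True

# Bilinear transform ik,jkl,il->ij (cf. the torch.nn.Bilinear link above)
X1 = np.random.rand(10, 4)                         # axes i, k
W = np.random.rand(6, 4, 5)                        # axes j, k, l
X2 = np.random.rand(10, 5)                         # axes i, l
D = np.einsum('ik,jkl,il->ij', X1, W, X2)
D_loop = np.array([[X1[i] @ W[j] @ X2[i] for j in range(6)] for i in range(10)])
print(np.allclose(D, D_loop))                      # True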

Usage in real code

I have run into einsum a few times in the open-source code of other people's papers (I have forgotten the earlier ones). This time it is the multi-head attention in Transformer-XL; the code is at
https://github.com/kimiyoung/transformer-xl/blob/master/tf/model.py#L42

def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
                       n_head, d_head, dropout, dropatt, is_training,
                       kernel_initializer, scope='rel_attn'):
  scale = 1 / (d_head ** 0.5)
  with tf.variable_scope(scope):
    qlen = tf.shape(w)[0]
    rlen = tf.shape(r)[0]
    bsz = tf.shape(w)[1]

    cat = tf.concat([mems, w],
                    0) if mems is not None and mems.shape.ndims > 1 else w
    w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
                              kernel_initializer=kernel_initializer, name='qkv')
    r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
                               kernel_initializer=kernel_initializer, name='r')

    w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
    w_head_q = w_head_q[-qlen:]

    klen = tf.shape(w_head_k)[0]

    w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
    w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
    w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])

    r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])

    rw_head_q = w_head_q + r_w_bias
    rr_head_q = w_head_q + r_r_bias

    AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
    BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
    BD = rel_shift(BD)

    attn_score = (AC + BD) * scale
    attn_mask_t = attn_mask[:, :, None, None]
    attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t

    attn_prob = tf.nn.softmax(attn_score, 1)
    attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)

    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
    size_t = tf.shape(attn_vec)
    attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])

    attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
                               kernel_initializer=kernel_initializer, name='o')
    attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)

    output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)
  return output

einsum is used in three places:


    AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
    BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
    attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)

Since the tensors here are 4-dimensional, these are harder to follow.
The first one, ibnd,jbnd->ijbn,
looks a lot like Batch Matrix Multiplication ijk,ikl->ijl:
treating bn as the batch index (the i in that pattern), it reduces to id,jd->ij.
In terms of the attention computation, this is $K^T Q$.
The second one, ibnd,jnd->ijbn,
is trickier because the second tensor is missing the b dimension; otherwise it is the same as the first. I still haven't fully understood this one. (My guess: the relative positional encoding r is shared across the whole batch, so r_head_k has no b axis and the einsum reuses it for every b; the last check in the sketch below confirms this is the same as broadcasting r over b first.)
The third one, ijbn,jbnd->ibnd:
dropping bn again, it becomes ij,jd->id, which is once more very much like Batch Matrix Multiplication.
In terms of the attention computation, this is $V \hat{A}$, where the j dimension of ijbn carries the attention coefficients.
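A small numpy sketch (shapes chosen by me) confirming that, once b and n are treated as batch axes, these einsums are just batch matrix multiplications, and that the missing b in the second one amounts to broadcasting:

import numpy as np

qlen, klen, bsz, n_head, d_head = 5, 7, 2, 3, 4
q = np.random.rand(qlen, bsz, n_head, d_head)        # i b n d
k = np.random.rand(klen, bsz, n_head, d_head)        # j b n d
v = np.random.rand(klen, bsz, n_head, d_head)        # j b n d

# First: ibnd,jbnd->ijbn, i.e. for each (b, n): scores[i, j] = sum_d q[i, d] * k[j, d]
AC = np.einsum('ibnd,jbnd->ijbn', q, k)
q_bnid = q.transpose(1, 2, 0, 3)                     # b n i d
k_bndj = k.transpose(1, 2, 3, 0)                     # b n d j
AC2 = (q_bnid @ k_bndj).transpose(2, 3, 0, 1)        # b n i j -> i j b n
print(np.allclose(AC, AC2))                          # True

# Third: ijbn,jbnd->ibnd, i.e. for each (b, n): out[i, d] = sum_j prob[i, j] * v[j, d]
prob = np.random.rand(qlen, klen, bsz, n_head)
out = np.einsum('ijbn,jbnd->ibnd', prob, v)
p_bnij = prob.transpose(2, 3, 0, 1)                  # b n i j
v_bnjd = v.transpose(1, 2, 0, 3)                     # b n j d
out2 = (p_bnij @ v_bnjd).transpose(2, 0, 1, 3)       # b n i d -> i b n d
print(np.allclose(out, out2))                        # True

# Second: ibnd,jnd->ijbn; r has no b axis, so the same r is used for every batch element
r = np.random.rand(klen, n_head, d_head)             # j n d
BD = np.einsum('ibnd,jnd->ijbn', q, r)
r_b = np.broadcast_to(r[:, None], (klen, bsz, n_head, d_head))   # insert a b axis
print(np.allclose(BD, np.einsum('ibnd,jbnd->ijbn', q, r_b)))     # True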
In self-attention, the concrete computation goes like this
(reference: https://www.bilibili.com/video/BV1JE411g7XF?p=23).
For an input sequence $x$, first embed it:
$a^i = W x^i$
Then compute $q, k, v$:
$q^i = W^q a^i, \quad k^i = W^k a^i, \quad v^i = W^v a^i$
Next compute the attention. This is scaled dot-product attention; in terms of the list above, each score is a dot product of two vectors, i,i->:
$\alpha_{1,i} = \frac{q^1 \cdot k^i}{\sqrt{d}}, \quad \hat{\alpha}_{1,i} = \frac{\exp(\alpha_{1,i})}{\sum_j \exp(\alpha_{1,j})}, \quad b^1 = \sum_i \hat{\alpha}_{1,i} v^i$
In matrix form:
$Q = W^q I, \quad K = W^k I, \quad V = W^v I, \quad A = K^T Q, \quad \hat{A} = \mathrm{softmax}(A), \quad O = V \hat{A}$
In the actual matrix computation, $K$ has to be transposed.
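A minimal numpy sketch of this matrix form, in the lecture's one-column-per-position convention (toy dimensions of my own):

import numpy as np

d, seq_len = 4, 5
I = np.random.rand(d, seq_len)                       # input: one column per position
Wq, Wk, Wv = (np.random.rand(d, d) for _ in range(3))

Q = Wq @ I                                           # (d, seq_len)
K = Wk @ I
V = Wv @ I

# A = K^T Q: with einsum the transpose is just a matter of where the indices go
A = np.einsum('di,dj->ij', K, Q) / np.sqrt(d)        # scaled dot products
A_hat = np.exp(A) / np.exp(A).sum(axis=0, keepdims=True)   # softmax over each column
O = np.einsum('di,ij->dj', V, A_hat)                 # O = V @ A_hat
print(np.allclose(A, K.T @ Q / np.sqrt(d)))          # True
print(np.allclose(O, V @ A_hat))                     # True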
Multi-head attention (with 2 heads):
$q^{i,1} = W^{q,1} q^i, \quad q^{i,2} = W^{q,2} q^i, \quad b^i = \mathrm{concat}(b^{i,1}, b^{i,2}), \quad b^i = W^O b^i$

How einsum itself is implemented

I'll dig into that later.
