快速矩阵乘法的研究——上

快速矩阵乘法的研究

最近的工作主要在于深度学习框架的性能优化。深度学习框架在工程的优化(内存池、SIMD、汇编、GPU、DSP等等)做到接近极限之后,突破点便集中于算法。

深度学习的性能瓶颈主要在于卷积,卷积的运算方法主要是通过 Im2Col / Winograd / FFT 转化为矩阵乘,完成矩阵乘法之后,再转化为目标结果。

深度学习框架的输入是算法工程产出的网络模型,而目前网络模型都渐渐地转变为 mobilenet 那样 1x1 convolution + depthwise 的形式,在精度几乎无损的情况下,既减少了计算量,又减少了模型体积。而这类网络模型,都以 1x1 卷积为主要耗时点。

对 1x1 卷积而言,其本身就是一个矩阵乘法,FFT / Winograd 等卷积算法已经失去价值,因此研读了一些矩阵乘法相关的论文,整理如下。

传统矩阵乘算法

定义

在 1968 年之前,矩阵乘算法只有按定义实现的传统算法,:
设:
A = ( a 11 a 12 . . . a 21 a 22 . . . . . . . . . . . . a n 1 a n 2 . . . ) B = ( b 11 b 12 . . . b 21 b 22 . . . . . . . . . . . . b n 1 b n 2 . . . ) A=\begin{pmatrix} a_{11} &a_{12} &... \\ a_{21} &a_{22} &... \\ ... & ... & ... \\ a_{n1} & a_{n2} & ... \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} &... \\ b_{21} &b_{22} &... \\ ... & ... & ... \\ b_{n1} & b_{n2} & ... \\ \end{pmatrix} A=a11a21...an1a12a22...an2............B=b11b21...bn1b12b22...bn2............
AB 为其乘积,则:
[ A B ] p q = ∑ i = 1 n a p i b i q [AB]_{pq} = \sum_{i=1}^{n}a_{pi}b_{iq} [AB]pq=i=1napibiq

很明显,它是一个 n 3 n^3 n3复杂度的算法,需要 n 3 n^3 n3 次乘法和 n 3 − n 2 n^3-n^2 n3n2次加法。

矩阵乘表示

C = A B C = AB C=AB,A 为 e ∗ l e*l el的矩阵,B 为 l ∗ h l*h lh的矩阵,则称这个矩阵乘是一个 [ e , l , h ] [e, l, h] [e,l,h] 的矩阵乘。

快速矩阵乘法的初步探索

Winograd 算法

请注意,这个不是我们通常所说的卷积优化算法,只是同一个人(Winograd大神)在 1968 年提出一种减少乘法数的矩阵乘算法。

其思路是通过两次 n 2 n^2 n2 的乘法预处理,将规模大的矩阵乘法减少一半,但相应的加法增加一半。为了说明简单,这里假定 n n n为偶数。
θ p = ∑ j = 1 ⌊ n / 2 ⌋ ( a p , 2 j − 1 a p , 2 j ) γ q = ∑ j = 1 ⌊ n / 2 ⌋ ( b 2 j − 1 , q b 2 j , q ) [ A B ] p q = ∑ j = 1 ⌊ n / 2 ⌋ ( a p , 2 j − 1 + b 2 j , q ) ( a p , 2 j + b 2 j − 1 , q ) − θ p − γ q \theta_p = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1} a_{p, 2j}) \\\gamma_q = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(b_{2j-1, q}b_{2j, q}) \\ [AB]_{pq} = \sum_{j=1}^{\left \lfloor n/2 \right \rfloor}(a_{p, 2j-1}+b_{2j, q})(a_{p, 2j}+b_{2j-1, q}) - \theta_p - \gamma_q θp=j=1n/2(ap,2j1ap,2j)γq=j=1n/2(b2j1,qb2j,q)[AB]pq=j=1n/2(ap,2j1+b2j,q)(ap,2j+b2j1,q)θpγq

这个算法没有降低矩阵乘法的阶(还是 n 3 n^3 n3),只是以廉价计算(加法)替代昂贵运算(乘法),需要根据具体的硬件去判断是否可应用。ARM 架构的 CPU,对量化矩阵乘有帮助,但对浮点矩阵乘没有用。

Strassen 矩阵乘算法

Strassen 矩阵乘的思路是通过加减变换,将一个 [ 2 , 2 , 2 ] [2, 2, 2] [2,2,2]的矩阵乘法所用的乘法数由8降到7,并且递归使用,降低矩阵乘法的阶数: n 3 n^3 n3变成 n 2.81 n^{2.81} n2.81
A = ( a 11 a 12 a 21 a 22 ) B = ( b 11 b 12 b 21 b 22 ) A B = ( c 11 c 12 c 21 c 22 ) A=\begin{pmatrix} a_{11} &a_{12} \\ a_{21} &a_{22} \\ \end{pmatrix} B=\begin{pmatrix} b_{11} &b_{12} \\ b_{21} &b_{22} \\ \end{pmatrix} AB=\begin{pmatrix} c_{11} &c_{12} \\ c_{21} &c_{22} \\ \end{pmatrix} A=(a11a21a12a22)B=(b11b21b12b22)AB=(c11c21c12c22)

v 1 = ( a 11 + a 22 ) ( b 11 + b 22 ) v 2 = ( a 21 + a 22 ) ( b 11 ) v 3 = ( a 11 ) ( b 12 − b 22 ) v 4 = ( a 22 ) ( b 21 − b 11 ) v 5 = ( a 11 + a 12 ) ( b 22 ) v 6 = ( a 21 − a 11 ) ( b 11 + b 12 ) v 7 = ( a 12 − a 22 ) ( b 21 + b 22 ) v_1 = (a_{11}+a_{22})(b_{11}+b_{22})\\ v_2 = (a_{21}+a_{22})(b_{11})\\v_3 = (a_{11})(b_{12}-b_{22})\\v_4 = (a_{22})(b_{21}-b_{11})\\v_5 = (a_{11}+a_{12})(b_{22})\\v_6 = (a_{21}-a_{11})(b_{11}+b_{12})\\v_7 = (a_{12}-a_{22})(b_{21}+b_{22}) v1=(a11+a22)(b11+b22)v2=(a21+a22)(b11)v3=(a11)(b12b22)v4=(a22)(b21b11)v5=(a11+a12)(b22)v6=(a21a11)(b11+b12)v7=(a12a22)(b21+b22)

c 11 = v 1 + v 4 − v 5 + v 7 c 21 = v 2 + v 4 c 12 = v 3 + v 5 c 22 = v 1 + v 3 − v 2 + v 6 c_{11} = v_1+v_4-v_5+v_7\\c_{21} = v_2+v_4\\c_{12} = v_3+v_5\\c_{22} = v_1+v_3-v_2+v_6 c11=v1+v4v5+v7c21=v2+v4c12=v3+v5c22=v1+v3v2+v6

请注意,其中每个元素( a 11 , b 12 , c 22 a_{11}, b_{12}, c_{22} a11,b12,c22等等)不限于实数,可以是一个矩阵。因为矩阵乘法满足分配率与结合率。这样算法就有了脱离硬件的普适价值,因为矩阵加减的复杂度( n 2 n^2 n2)远低于矩阵乘( n 3 n^3 n3)

Winograd 在 Strassen 的基础上对它的算法进行了改进,减少了加减数(18->15),这个也成为最常用的 Strassen 矩阵乘法应用。

三线性表示

为了方便矩阵乘算法的研究,人们提出一种表示矩阵乘算法的形式,叫“Trilinear-form”,即三线性形式。
我们先以 Strassen 算法为例,它的三线性形式是:
∑ i = 1 2 ∑ j = 1 2 ∑ k = 1 2 a i j b j k c i k = ( a 11 ) ( b 12 − b 22 ) ( c 12 + c 22 ) + ( a 11 + a 12 ) ( b 22 ) ( − c 11 + c 12 ) + ( a 21 + a 22 ) ( b 11 ) ( c 21 − c 22 ) + ( a 22 ) ( b 21 + b 11 ) ( c 11 + c 21 ) + ( a 11 + a 22 ) ( b 11 + b 22 ) ( c 11 + c 22 ) + ( a 12 − a 22 ) ( b 21 + b 22 ) ( c 11 ) + ( a 11 − a 21 ) ( b 11 + b 12 ) ( − c 22 ) \sum_{i=1}^2\sum_{j=1}^2\sum_{k=1}^2 a_{ij}b_{jk}c_{ik} = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}) i=12j=12k=12aijbjkcik=(a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22)

怎么看这个公式呢,它其实是按 T r a c e ( A B C ) = A B Trace(ABC) = AB Trace(ABC)=AB 的原理去表示的。两个矩阵的乘积,等效于三个矩阵乘积的迹。在上面公式中,如果我们要算出 c 11 c_{11} c11 的解法,就将 c 11 c_{11} c11 设成 1,其他的 c 值, c 12 , c 21 , c 22 c_{12}, c_{21}, c_{22} c12,c21,c22 全设成 0 ,然后将对应的项相加即可。

这个算式总共有7项,这个 7 我们称之为 Rank (阶)

APA——矩阵乘算法的突破

APA,即 Any Precision Algorithm,是把矩阵乘法阶数继续往下降的重要思想,基本思路是先给出近似的矩阵乘法表达式,然后在多阶张量积之后转换为准确的矩阵乘法。

张量积

我们来看 Strassen 矩阵乘法的表达式:
λ = ( a 11 ) ( b 12 − b 22 ) ( c 12 + c 22 ) + ( a 11 + a 12 ) ( b 22 ) ( − c 11 + c 12 ) + ( a 21 + a 22 ) ( b 11 ) ( c 21 − c 22 ) + ( a 22 ) ( b 21 + b 11 ) ( c 11 + c 21 ) + ( a 11 + a 22 ) ( b 11 + b 22 ) ( c 11 + c 22 ) + ( a 12 − a 22 ) ( b 21 + b 22 ) ( c 11 ) + ( a 11 − a 21 ) ( b 11 + b 12 ) ( − c 22 ) \lambda = (a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}) λ=(a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22)

对其平方:
λ 2 = ( ( a 11 ) ( b 12 − b 22 ) ( c 12 + c 22 ) + ( a 11 + a 12 ) ( b 22 ) ( − c 11 + c 12 ) + ( a 21 + a 22 ) ( b 11 ) ( c 21 − c 22 ) + ( a 22 ) ( b 21 + b 11 ) ( c 11 + c 21 ) + ( a 11 + a 22 ) ( b 11 + b 22 ) ( c 11 + c 22 ) + ( a 12 − a 22 ) ( b 21 + b 22 ) ( c 11 ) + ( a 11 − a 21 ) ( b 11 + b 12 ) ( − c 22 ) ) 2 \lambda^2 = ((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}) +(a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}) +(a_{21}+a_{22})(b_{11})(c_{21}-c_{22})+(a_{22})(b_{21}+b_{11})(c_{11}+c_{21})+(a_{11}+a_{22})(b_{11}+b_{22})(c_{11}+c_{22})+(a_{12}-a_{22})(b_{21}+b_{22})(c_{11})+(a_{11}-a_{21})(b_{11}+b_{12})(-c_{22}))^2 λ2=((a11)(b12b22)(c12+c22)+(a11+a12)(b22)(c11+c12)+(a21+a22)(b11)(c21c22)+(a22)(b21+b11)(c11+c21)+(a11+a22)(b11+b22)(c11+c22)+(a12a22)(b21+b22)(c11)+(a11a21)(b11+b12)(c22))2

这是个多项式乘法,不难知 λ 2 \lambda^2 λ2 7 2 = 49 7^2=49 72=49 项,我们来看其中一项:
( ( a 11 ) ( b 12 − b 22 ) ( c 12 + c 22 ) ) ( ( a 11 + a 12 ) ( b 22 ) ( − c 11 + c 12 ) ) = ( a 11 a 11 + a 11 a 12 ) ( b 12 b 22 − b 22 b 22 ) ( − c 12 c 11 + c 12 c 12 − c 22 c 11 + c 22 c 12 ) ((a_{11})(b_{12}-b_{22})(c_{12}+c_{22}))((a_{11}+a_{12})(b_{22})(-c_{11}+c_{12}))=(a_{11}a_{11}+a_{11}a_{12})(b_{12}b_{22}-b_{22}b_{22})(-c_{12}c_{11}+c_{12}c_{12}-c_{22}c_{11}+c_{22}c_{12}) ((a11)(b12b22)(c12+c22))((a11+a12)(b22)(c11+c12))=(a11a11+a11a12)(b12b22b22b22)(c12c11+c12c12c22c11+c22c12)
(依然是将a, b, c 分别组合在一起)

a , b , c a, b, c a,b,c间的相乘,如 a 11 a 12 a_{11}a_{12} a11a12,我们将其替代为直和: a 1112 a_{1112} a1112,其含义可以这么理解,在 a 11 a_{11} a11的区域(左上角)中,再划分为四块,取其 a 12 a_{12} a12的区域(右上角)。
不难证明,我们通过这个多项式平方后得到的三线性形式,等效于一个 [ 4 , 4 , 4 ] [4, 4, 4] [4,4,4] 的矩阵乘法。

类似地,我们可以对矩阵乘法的三线性形式进行立方,n次方,以及两个不同的三线性形式乘积,这一系列操作可由“张量积”概括。

APA

Any Precision Algorithm(APA),即任意精度算法,通过在算式中引入一个可配置的实数 λ \lambda λ,得到更好的简化效果。

下面的式子近似用21项表示了一个 [ 3 , 3 , 3 ] [3, 3, 3] [3,3,3]的矩阵乘法

F 1 ( λ ) = ( a 11 + λ 2 a 12 ) ( λ 2 b 11 + b 21 ) c 11 + ( a 21 + λ 2 a 22 ) ( λ 2 b 12 + b 22 ) c 22 + ( a 31 + λ 2 a 32 ) ( λ 2 b 13 + b 23 ) c 33 − a 11 ( b 21 + b 31 ) ( c 11 + c 12 + c 13 ) − a 21 ( b 22 + b 32 ) ( c 21 + c 22 + c 23 ) − a 31 ( b 23 + b 33 ) ( c 31 + c 32 + c 33 ) + ( a 11 + λ 2 a 22 ) ( b 21 − λ b 12 ) c 12 + ( a 21 + λ 2 a 12 ) ( b 22 − λ b 11 ) c 21 + ( a 11 + λ 2 a 32 ) ( b 21 − λ b 13 ) c 13 + ( a 31 + λ 2 a 12 ) ( b 23 − λ b 11 ) c 31 + ( a 21 + λ 2 a 32 ) ( b 22 − λ b 13 ) c 23 + ( a 31 + λ 2 a 22 ) ( b 23 − λ b 12 ) c 32 + ( a 11 + λ 2 a 23 ) ( b 31 + λ b 12 ) ( c 12 + λ c 21 ) + ( a 21 + λ 2 a 13 ) ( b 32 + λ b 11 ) ( c 21 + λ c 12 ) + ( a 11 + λ 2 a 33 ) ( b 31 + λ b 13 ) ( c 13 + λ c 31 ) + ( a 31 + λ 2 a 13 ) ( b 33 + λ b 12 ) ( c 31 + λ c 13 ) + ( a 21 + λ 2 a 33 ) ( b 32 + λ b 13 ) ( c 23 + λ c 32 ) + ( a 31 + λ 2 a 23 ) ( b 33 + λ b 12 ) ( c 32 + λ c 23 ) + ( a 11 + λ 2 a 13 ) b 31 ( c 11 − λ c 31 − λ c 21 ) + ( a 21 + λ 2 a 23 ) b 32 ( c 22 − λ c 32 − λ c 12 ) + ( a 31 + λ 2 a 33 ) b 33 ( c 33 − λ c 13 − λ c 23 ) = λ 2 ( T r a c e ( A B C ) + λ G ( λ ) ) F_1(\lambda) = (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11}\\+(a_{21}+\lambda^2a_{22})(\lambda^2b_{12}+b_{22})c_{22}+(a_{31}+\lambda^2a_{32})(\lambda^2b_{13}+b_{23})c_{33}-a_{11}(b_{21}+b_{31})(c_{11}+c_{12}+c_{13})-a_{21}(b_{22}+b_{32})(c_{21}+c_{22}+c_{23})-a_{31}(b_{23}+b_{33})(c_{31}+c_{32}+c_{33})+(a_{11}+\lambda^2a_{22})(b_{21}-\lambda b_{12})c_{12}+(a_{21}+\lambda^2a_{12})(b_{22}-\lambda b_{11})c_{21}+(a_{11}+\lambda^2a_{32})(b_{21}-\lambda b_{13})c_{13}+(a_{31}+\lambda^2a_{12})(b_{23}-\lambda b_{11})c_{31}+(a_{21}+\lambda^2a_{32})(b_{22}-\lambda b_{13})c_{23}+(a_{31}+\lambda^2a_{22})(b_{23}-\lambda b_{12})c_{32}+(a_{11}+\lambda^2a_{23})(b_{31}+\lambda b_{12})(c_{12}+\lambda c_{21})+(a_{21}+\lambda^2a_{13})(b_{32}+\lambda b_{11})(c_{21}+\lambda c_{12})+(a_{11}+\lambda^2a_{33})(b_{31}+\lambda b_{13})(c_{13}+\lambda c_{31})+(a_{31}+\lambda^2a_{13})(b_{33}+\lambda b_{12})(c_{31}+\lambda c_{13})+(a_{21}+\lambda^2a_{33})(b_{32}+\lambda b_{13})(c_{23}+\lambda c_{32})+(a_{31}+\lambda^2a_{23})(b_{33}+\lambda b_{12})(c_{32}+\lambda c_{23})+(a_{11}+\lambda^2a_{13})b_{31}(c_{11}-\lambda c_{31}-\lambda c_{21})+(a_{21}+\lambda^2a_{23})b_{32}(c_{22}-\lambda c_{32}-\lambda c_{12})+(a_{31}+\lambda^2a_{33})b_{33}(c_{33}-\lambda c_{13}-\lambda c_{23}) = \lambda^2 (Trace(ABC)+\lambda G(\lambda)) F1(λ)=(a11+λ2a12)(λ2b11+b21)c11+(a21+λ2a22)(λ2b12+b22)c22+(a31+λ2a32)(λ2b13+b23)c33a11(b21+b31)(c11+c12+c13)a21(b22+b32)(c21+c22+c23)a31(b23+b33)(c31+c32+c33)+(a11+λ2a22)(b21λb12)c12+(a21+λ2a12)(b22λb11)c21+(a11+λ2a32)(b21λb13)c13+(a31+λ2a12)(b23λb11)c31+(a21+λ2a32)(b22λb13)c23+(a31+λ2a22)(b23λb12)c32+(a11+λ2a23)(b31+λb12)(c12+λc21)+(a21+λ2a13)(b32+λb11)(c21+λc12)+(a11+λ2a33)(b31+λb13)(c13+λc31)+(a31+λ2a13)(b33+λb12)(c31+λc13)+(a21+λ2a33)(b32+λb13)(c23+λc32)+(a31+λ2a23)(b33+λb12)(c32+λc23)+(a11+λ2a13)b31(c11λc31λc21)+(a21+λ2a23)b32(c22λc32λc12)+(a31+λ2a33)b33(c33λc13λc23)=λ2(Trace(ABC)+λG(λ))

λ \lambda λ趋于无穷小时,其误差也趋于无穷小,因此我们可以设定任意的精度去使用它,这就是 APA 的由来。

对于 APA 算法,多项式的个数我们称之为 Border Rank,上述算式表示了一个 [ 3 , 3 , 3 ] [3, 3, 3] [3,3,3]的矩阵乘法,在 λ 3 \lambda ^3 λ3的基础上分出误差,我们称之为一个降解: [ 3 , 3 , 3 ] ⊴ 3 21 [3, 3, 3] \unlhd_3 21 [3,3,3]321

现在我们来看怎么把上面的 APA 算法变成准确算法。

直观的做法就是把 λ 2 \lambda^2 λ2项取出来,如: ( a 11 + λ 2 a 12 ) ( λ 2 b 11 + b 21 ) c 11 (a_{11}+\lambda^2a_{12})(\lambda^2b_{11}+b_{21})c_{11} (a11+λ2a12)(λ2b11+b21)c11,取出 λ 2 a 11 b 11 c 11 + λ 2 a 12 b 21 c 11 \lambda^2a_{11}b_{11}c_{11}+\lambda^2a_{12}b_{21}c_{11} λ2a11b11c11+λ2a12b21c11,代价就是增加了多项式,不难证明,我们最多会增加到 2 ( 2 + 1 ) / 2 = 3 2(2+1)/2=3 2(2+1)/2=3倍的多项式个数。

无疑,这样做肯定亏了, 3 ∗ 21 = 63 > 3 ∗ 3 ∗ 3 = 27 3*21=63 > 3*3*3=27 321=63>333=27,我们需要施个魔法,就是张量积。

对上面APA 算法进行n次张量积之后,我们可以得到 3 n 3^n 3n大小的矩阵乘算法的降解: [ 3 n , 3 n , 3 n ] ⊴ 2 n + 1 2 1 n [3^n, 3^n, 3^n] \unlhd_{2n+1} 21^n [3n,3n,3n]2n+121n

这时候我们再来取,就不一样了,其阶数变成了:
n ( 2 n + 1 ) 2 1 n n(2n+1)21^n n(2n+1)21n
很明显,当 n 足够大时, n ( 2 n + 1 ) n(2n+1) n(2n+1) 和指数项相比可忽略,这样我们就得到了更好的准确算法,其阶数为:
3 l n ( 21 ) / l n ( 27 ) ≈ 2.77 3ln(21)/ln(27)\approx2.77 3ln(21)/ln(27)2.77

下篇内容:
1、组合矩阵乘
2、渐近和定理
3、Strassen构造
4、Coppersmith–Winograd 算法

你可能感兴趣的:(异构计算/算法优化)