By LongLuo
Machine learning trains on large volumes of data and involves many expensive operations, such as convolution and matrix multiplication. These operations are not only numerous; each one also processes a large amount of data, so optimizing them can improve performance substantially.
Let $A$ be an $m \times p$ matrix and $B$ a $p \times n$ matrix. The $m \times n$ matrix $C$ is then called the product of $A$ and $B$, written $C = AB$ (the matrix product).
The entry in row $i$ and column $j$ of $C$ is:

$$(AB)_{ij} = \sum_{k=1}^{p} a_{ik} b_{kj} = a_{i1}b_{1j} + a_{i2}b_{2j} + \cdots + a_{ip}b_{pj}$$

As shown in the figure below:
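Suppose $m = p = n = N$ for matrices $A$ and $B$. How many multiplications does computing $C = AB$ take?
Each entry of $C$ needs $N$ multiplications and $C$ has $N^2$ entries, so the total is $N^3$. The complexity of straightforward matrix multiplication is therefore $\Theta(N^{3})$.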
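To make the count concrete, here is a minimal C++ sketch of the textbook triple loop. The `Matrix` alias, the function name `naiveMultiply`, and the `mulCount` counter are illustrative assumptions and are not taken from any library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Multiplies an m x p matrix by a p x n matrix with the textbook triple loop.
// `mulCount` is a hypothetical counter added for illustration; it is
// incremented once per scalar multiplication.
Matrix naiveMultiply(const Matrix& a, const Matrix& b, std::size_t& mulCount) {
    const std::size_t m = a.size();
    const std::size_t p = b.size();
    assert(m > 0 && p > 0 && a[0].size() == p);  // inner dimensions must agree
    const std::size_t n = b[0].size();
    Matrix c(m, std::vector<double>(n, 0.0));
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            for (std::size_t k = 0; k < p; ++k) {
                c[i][j] += a[i][k] * b[k][j];  // one term of the sum over k
                ++mulCount;
            }
        }
    }
    return c;
}
```

For two $2 \times 2$ inputs the counter ends at $2^3 = 8$, and in general at $m \cdot n \cdot p$, which is $N^3$ in the square case.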
Is there an algorithm faster than $\Theta(N^{3})$?
In 1969, Volker Strassen published the first matrix multiplication algorithm with time complexity below $\Theta(N^{3})$: it runs in $\Theta(N^{\log_2 7}) \approx \Theta(N^{2.807})$. As the figure below shows, Strassen's algorithm only gains a clear performance advantage for fairly large matrices ($N > 300$), where it saves a large number of multiplications.
Strassen's algorithm proved that matrix multiplication can be done in less than $\Theta(N^{3})$ time. Researchers have since kept finding faster algorithms; at the time of writing, the lowest known complexity belongs to extensions of the Coppersmith-Winograd method, at roughly $O(N^{2.37})$.
Suppose $A$ and $B$ are both $N \times N$ square matrices with $N = 2^{n}$, and we want $C = AB$. Partition each matrix into four blocks:

$$A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix}, \quad B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}, \quad C = \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}$$

where

$$\begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix} = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix} \cdot \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix}$$

The blocks of $C$ are given by:

$$\begin{aligned}
C_{11} &= A_{11} \cdot B_{11} + A_{12} \cdot B_{21}\\
C_{12} &= A_{11} \cdot B_{12} + A_{12} \cdot B_{22}\\
C_{21} &= A_{21} \cdot B_{11} + A_{22} \cdot B_{21}\\
C_{22} &= A_{21} \cdot B_{12} + A_{22} \cdot B_{22}
\end{aligned}$$
From these formulas, multiplying two $n \times n$ matrices takes 8 multiplications and 4 additions of $\frac{n}{2} \times \frac{n}{2}$ matrices. Let $T(n)$ denote the time complexity of $n \times n$ matrix multiplication. The decomposition above gives the recurrence:

$$T(n) = 8\,T\left(\frac{n}{2}\right) + \Theta(n^{2})$$

where $8\,T(\frac{n}{2})$ is the cost of the eight recursive block multiplications and $\Theta(n^{2})$ is the cost of the four block additions.
Solving it gives $T(n) = \Theta(n^{\log_2 8}) = \Theta(n^{3})$.
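One way to see where $n^3$ comes from is to unroll the recurrence. The sketch below assumes the $\Theta(n^2)$ term is exactly $c\,n^2$ for some constant $c$ and that $n = 2^k$; both are simplifications made only for illustration.

$$\begin{aligned}
T(n) &= 8\,T\left(\tfrac{n}{2}\right) + c\,n^2\\
     &= 8^2\,T\left(\tfrac{n}{4}\right) + c\,n^2\,(1 + 2)\\
     &= 8^k\,T(1) + c\,n^2 \sum_{i=0}^{k-1} 2^i\\
     &= n^3\,T(1) + c\,n^2\,(n - 1) = \Theta(n^3)
\end{aligned}$$

Level $i$ of the recursion has $8^i$ subproblems of size $n/2^i$, so its additions cost $8^i \cdot c\,(n/2^i)^2 = c\,n^2\,2^i$. The leaves, $8^k = n^3$ of them, dominate the total.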
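Every recursive step needs 8 block multiplications, and that is exactly where the bottleneck comes from. Matrix multiplication is far slower than matrix addition, so can we reduce the number of multiplications?
Yes, we can.
Strassen's algorithm lowers the complexity from exactly this angle.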
The algorithm has four steps:

1. Partition the matrices $A$, $B$, $C$ as described above (takes $\Theta(1)$ time).
2. Create ten $\frac{n}{2} \times \frac{n}{2}$ matrices $S_1, S_2, \ldots, S_{10}$ as follows (takes $\Theta(n^2)$ time).

$$\begin{aligned}
S_1 &= B_{12} - B_{22}\\
S_2 &= A_{11} + A_{12}\\
S_3 &= A_{21} + A_{22}\\
S_4 &= B_{21} - B_{11}\\
S_5 &= A_{11} + A_{22}\\
S_6 &= B_{11} + B_{22}\\
S_7 &= A_{12} - A_{22}\\
S_8 &= B_{21} + B_{22}\\
S_9 &= A_{11} - A_{21}\\
S_{10} &= B_{11} + B_{12}
\end{aligned}$$

3. Recursively compute seven matrix products $P_1, P_2, \ldots, P_7$, each of size $\frac{n}{2} \times \frac{n}{2}$.

$$\begin{aligned}
P_1 &= A_{11} \cdot S_1 &&= A_{11} \cdot B_{12} - A_{11} \cdot B_{22}\\
P_2 &= S_2 \cdot B_{22} &&= A_{11} \cdot B_{22} + A_{12} \cdot B_{22}\\
P_3 &= S_3 \cdot B_{11} &&= A_{21} \cdot B_{11} + A_{22} \cdot B_{11}\\
P_4 &= A_{22} \cdot S_4 &&= A_{22} \cdot B_{21} - A_{22} \cdot B_{11}\\
P_5 &= S_5 \cdot S_6 &&= A_{11} \cdot B_{11} + A_{11} \cdot B_{22} + A_{22} \cdot B_{11} + A_{22} \cdot B_{22}\\
P_6 &= S_7 \cdot S_8 &&= A_{12} \cdot B_{21} + A_{12} \cdot B_{22} - A_{22} \cdot B_{21} - A_{22} \cdot B_{22}\\
P_7 &= S_9 \cdot S_{10} &&= A_{11} \cdot B_{11} + A_{11} \cdot B_{12} - A_{21} \cdot B_{11} - A_{21} \cdot B_{12}
\end{aligned}$$

Note that only the middle column is actually computed; the right-hand column merely shows what each product expands to.

4. Compute $C_{11}, C_{12}, C_{21}, C_{22}$ from the $P_i$ (takes $\Theta(n^2)$ time).

$$\begin{aligned}
C_{11} &= P_5 + P_4 - P_2 + P_6\\
C_{12} &= P_1 + P_2\\
C_{21} &= P_3 + P_4\\
C_{22} &= P_5 + P_1 - P_3 - P_7
\end{aligned}$$
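The four steps are easy to check by hand in the smallest case, where every block is a single number. The following C++ sketch applies one level of the scheme to $2 \times 2$ integer matrices and compares it with the direct 8-multiplication product. The type `Mat2` and the function names are hypothetical and exist only for this illustration.

```cpp
#include <array>
#include <cassert>

using Mat2 = std::array<std::array<long long, 2>, 2>;

// One level of Strassen's scheme on a 2 x 2 matrix whose blocks are scalars.
// Exactly seven multiplications are performed (p1..p7).
Mat2 strassen2x2(const Mat2& a, const Mat2& b) {
    const long long s1 = b[0][1] - b[1][1], s2 = a[0][0] + a[0][1];
    const long long s3 = a[1][0] + a[1][1], s4 = b[1][0] - b[0][0];
    const long long s5 = a[0][0] + a[1][1], s6 = b[0][0] + b[1][1];
    const long long s7 = a[0][1] - a[1][1], s8 = b[1][0] + b[1][1];
    const long long s9 = a[0][0] - a[1][0], s10 = b[0][0] + b[0][1];
    const long long p1 = a[0][0] * s1, p2 = s2 * b[1][1], p3 = s3 * b[0][0];
    const long long p4 = a[1][1] * s4, p5 = s5 * s6, p6 = s7 * s8;
    const long long p7 = s9 * s10;
    Mat2 c{};
    c[0][0] = p5 + p4 - p2 + p6;
    c[0][1] = p1 + p2;
    c[1][0] = p3 + p4;
    c[1][1] = p5 + p1 - p3 - p7;
    return c;
}

// Direct product with eight multiplications, used as the reference.
Mat2 direct2x2(const Mat2& a, const Mat2& b) {
    Mat2 c{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            c[i][j] = a[i][0] * b[0][j] + a[i][1] * b[1][j];
    return c;
}

// Asserts that both routes agree for the given inputs.
void checkStrassen2x2(const Mat2& a, const Mat2& b) {
    assert(strassen2x2(a, b) == direct2x2(a, b));
}
```

With $A = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}$ and $B = \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix}$, the seven products are $-2, 24, 35, 8, 65, -30, -22$, and both routes give $\begin{bmatrix} 19 & 22 \\ 43 & 50 \end{bmatrix}$.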
Combining the steps gives the recurrence:

$$T(n) = \begin{cases}\Theta(1) & \text{if } n = 1\\ 7\,T\left(\frac{n}{2}\right) + \Theta(n^2) & \text{if } n > 1 \end{cases}$$

which solves to $T(n) = \Theta(n^{\log_2 7})$.
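The effect of going from 8 to 7 recursive products can be seen by counting scalar multiplications alone. The helper below is a hypothetical sketch: it ignores additions and assumes $n$ is a power of two and that the recursion runs all the way down to $1 \times 1$ blocks.

```cpp
#include <cassert>
#include <cstdint>

// Number of scalar multiplications performed by a divide-and-conquer scheme
// that splits one n x n product into `branches` products of size n/2.
// Additions are deliberately not counted.
std::uint64_t countMultiplications(std::uint64_t n, std::uint64_t branches) {
    assert(n >= 1 && (n & (n - 1)) == 0);  // n must be a power of two
    if (n == 1) {
        return 1;  // a 1 x 1 product is a single scalar multiplication
    }
    return branches * countMultiplications(n / 2, branches);
}
```

For $n = 1024$, `countMultiplications(1024, 8)` is $8^{10} = 1{,}073{,}741{,}824$ while `countMultiplications(1024, 7)` is $7^{10} = 282{,}475{,}249$, roughly 26% as many. The extra additions that Strassen's scheme introduces are what this simple count leaves out.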
We use the Strassen implementation in MNN as a case study: https://github.com/alibaba/MNN/blob/master/source/backend/cpu/compute/StrassenMatmulComputor.cpp
The class `StrassenMatrixComputor` provides three APIs:
API | Description |
---|---|
_generateTrivalMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT); | Ordinary matrix multiplication |
_generateMatMul(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth); | Matrix multiplication with Strassen's algorithm |
_generateMatMulConstB(const Tensor* AT, const Tensor* BT, const Tensor* CT, int currentDepth); | Matrix multiplication with Strassen's algorithm (differs from `_generateMatMul` in whether the memory buffer may be reused) |
We take `_generateMatMul` as the example to see how Strassen's algorithm is implemented. It can be broken into the following steps:
The matrix operations need to expand matrix dimensions, which involves many reads and writes, and each of those requires loops. If the read/write cost exceeds what Strassen multiplication saves, the split is not worth it and ordinary matrix multiplication is used instead.

    /*
     Compute the memory read / write cost for expand Matrix Mul need eSub*lSub*hSub*(1+1.0/CONVOLUTION_TILED_NUMBWR), Matrix Add/Sub need x*y*UNIT*3 (2 read 1 write)
     */
    float saveCost = (eSub * lSub * hSub) * (1.0f + 1.0f / CONVOLUTION_TILED_NUMBWR) - 4 * (eSub * lSub) * 3 - 7 * (eSub * hSub * 3);
    if (currentDepth >= mMaxDepth || e <= CONVOLUTION_TILED_NUMBWR || l % 2 != 0 || h % 2 != 0 || saveCost < 0.0f)
    {
        return _generateTrivialMatMul(AT, BT, CT);
    }
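The decision above can be isolated into a small standalone function to see how it behaves for different sizes. This is only a sketch under stated assumptions: the article does not show how `eSub`, `lSub` and `hSub` are derived, so the code below takes them to be half of `e`, `l` and `h`; `tile` stands in for `CONVOLUTION_TILED_NUMBWR`, whose actual value is not given here; and both function names are hypothetical.

```cpp
#include <cassert>

// Estimated saving of one Strassen split, mirroring the expression in the
// snippet above. ASSUMPTION: the *Sub sizes are simply half of e, l and h.
double estimateSaveCost(int e, int l, int h, int tile) {
    assert(e > 0 && l > 0 && h > 0 && tile > 0);
    const double eSub = e / 2;
    const double lSub = l / 2;
    const double hSub = h / 2;
    // Traffic of the one sub-multiplication that the split avoids.
    const double mulSaved = eSub * lSub * hSub * (1.0 + 1.0 / tile);
    // Extra add/sub passes, each costing 2 reads + 1 write per element.
    const double addCost = 4.0 * (eSub * lSub) * 3.0 + 7.0 * (eSub * hSub * 3.0);
    return mulSaved - addCost;
}

// Returns true when the recursion should split once more, false when it
// should fall back to ordinary multiplication (same tests as the snippet).
bool shouldSplit(int e, int l, int h, int tile, int depth, int maxDepth) {
    if (depth >= maxDepth || e <= tile || l % 2 != 0 || h % 2 != 0) {
        return false;
    }
    return estimateSaveCost(e, l, h, tile) >= 0.0;
}
```

With an arbitrary `tile` of 8, a $512 \times 512 \times 512$ problem yields a positive estimate and is split, while an $8 \times 8 \times 8$ problem yields a negative one and falls back. An odd `l` or `h` also forces the fallback, because the matrix cannot be halved evenly.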
Split each of the three matrices $A$, $B$, $C$ into four blocks:

    auto aStride = AT->stride(0);
    auto a11 = AT->host<float>() + 0 * aUnit * eSub + 0 * aStride * lSub;
    auto a12 = AT->host<float>() + 0 * aUnit * eSub + 1 * aStride * lSub;
    auto a21 = AT->host<float>() + 1 * aUnit * eSub + 0 * aStride * lSub;
    auto a22 = AT->host<float>() + 1 * aUnit * eSub + 1 * aStride * lSub;
    auto bStride = BT->stride(0);
    auto b11 = BT->host<float>() + 0 * bUnit * lSub + 0 * bStride * hSub;
    auto b12 = BT->host<float>() + 0 * bUnit * lSub + 1 * bStride * hSub;
    auto b21 = BT->host<float>() + 1 * bUnit * lSub + 0 * bStride * hSub;
    auto b22 = BT->host<float>() + 1 * bUnit * lSub + 1 * bStride * hSub;
    auto cStride = CT->stride(0);
    auto c11 = CT->host<float>() + 0 * aUnit * eSub + 0 * cStride * hSub;
    auto c12 = CT->host<float>() + 0 * aUnit * eSub + 1 * cStride * hSub;
    auto c21 = CT->host<float>() + 1 * aUnit * eSub + 0 * cStride * hSub;
    auto c22 = CT->host<float>() + 1 * aUnit * eSub + 1 * cStride * hSub;
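The key point of the snippet above is that the four blocks are obtained purely by pointer offsets, without copying any data. The sketch below shows the same idea on a plain row-major buffer. That layout is an assumption made for readability (MNN's tensors use their own packed layout with `aUnit` and `bUnit`), and `MatrixView` and `quadrant` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>

// Non-owning view of a sub-matrix inside a row-major buffer.
struct MatrixView {
    const float* data;   // pointer to the view's top-left element
    std::size_t rows;
    std::size_t cols;
    std::size_t stride;  // elements between the starts of consecutive rows

    float at(std::size_t r, std::size_t c) const { return data[r * stride + c]; }
};

// Returns quadrant (blockRow, blockCol) of `m`; each index is 0 or 1.
// No element is copied: only the base pointer and the extents change,
// while the stride of the parent buffer is kept.
MatrixView quadrant(const MatrixView& m, std::size_t blockRow, std::size_t blockCol) {
    assert(m.rows % 2 == 0 && m.cols % 2 == 0);
    assert(blockRow < 2 && blockCol < 2);
    const std::size_t halfRows = m.rows / 2;
    const std::size_t halfCols = m.cols / 2;
    const float* base = m.data + blockRow * halfRows * m.stride + blockCol * halfCols;
    return MatrixView{base, halfRows, halfCols, m.stride};
}
```

For a $4 \times 4$ buffer holding the values 0 to 15, `quadrant(m, 0, 1)` starts at the element 2 and `quadrant(m, 1, 0)` at the element 8. Because every quadrant keeps the parent's stride, the same function can be applied again to a quadrant, which is what makes the recursion cheap.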
The core of Strassen's algorithm is divide and conquer. This step can be written as the following pseudocode:

    If n = 1 Output A × B
    Else
        Compute A11, B11, ..., A22, B22    % by computing m = n/2
        P1 ← Strassen(A11, B12 − B22)
        P2 ← Strassen(A11 + A12, B22)
        P3 ← Strassen(A21 + A22, B11)
        P4 ← Strassen(A22, B21 − B11)
        P5 ← Strassen(A11 + A22, B11 + B22)
        P6 ← Strassen(A12 − A22, B21 + B22)
        P7 ← Strassen(A11 − A21, B11 + B12)
        C11 ← P5 + P4 − P2 + P6
        C12 ← P1 + P2
        C21 ← P3 + P4
        C22 ← P1 + P5 − P3 − P7
        Output C
    End If
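The pseudocode translates almost line by line into C++. The version below is a teaching sketch, not MNN's implementation: it copies blocks into new `std::vector` matrices instead of using pointer offsets, recurses all the way down to $1 \times 1$, and assumes square inputs whose size is a power of two. The helper names `combine` and `block` are made up for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Element-wise a + sign * b for equally sized square matrices.
Matrix combine(const Matrix& a, const Matrix& b, double sign) {
    const std::size_t n = a.size();
    Matrix r(n, std::vector<double>(n));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            r[i][j] = a[i][j] + sign * b[i][j];
    return r;
}

// Copies the h x h block whose top-left corner is (r0, c0).
Matrix block(const Matrix& m, std::size_t r0, std::size_t c0, std::size_t h) {
    Matrix r(h, std::vector<double>(h));
    for (std::size_t i = 0; i < h; ++i)
        for (std::size_t j = 0; j < h; ++j)
            r[i][j] = m[r0 + i][c0 + j];
    return r;
}

Matrix strassen(const Matrix& a, const Matrix& b) {
    const std::size_t n = a.size();
    assert(n >= 1 && (n & (n - 1)) == 0 && b.size() == n);  // power of two
    if (n == 1) {
        return Matrix(1, std::vector<double>(1, a[0][0] * b[0][0]));
    }
    const std::size_t h = n / 2;
    const Matrix a11 = block(a, 0, 0, h), a12 = block(a, 0, h, h);
    const Matrix a21 = block(a, h, 0, h), a22 = block(a, h, h, h);
    const Matrix b11 = block(b, 0, 0, h), b12 = block(b, 0, h, h);
    const Matrix b21 = block(b, h, 0, h), b22 = block(b, h, h, h);

    // The seven recursive products P1..P7 from the pseudocode.
    const Matrix p1 = strassen(a11, combine(b12, b22, -1.0));
    const Matrix p2 = strassen(combine(a11, a12, 1.0), b22);
    const Matrix p3 = strassen(combine(a21, a22, 1.0), b11);
    const Matrix p4 = strassen(a22, combine(b21, b11, -1.0));
    const Matrix p5 = strassen(combine(a11, a22, 1.0), combine(b11, b22, 1.0));
    const Matrix p6 = strassen(combine(a12, a22, -1.0), combine(b21, b22, 1.0));
    const Matrix p7 = strassen(combine(a11, a21, -1.0), combine(b11, b12, 1.0));

    // Assemble the four quadrants of C.
    Matrix c(n, std::vector<double>(n));
    for (std::size_t i = 0; i < h; ++i) {
        for (std::size_t j = 0; j < h; ++j) {
            c[i][j] = p5[i][j] + p4[i][j] - p2[i][j] + p6[i][j];
            c[i][j + h] = p1[i][j] + p2[i][j];
            c[i + h][j] = p3[i][j] + p4[i][j];
            c[i + h][j + h] = p5[i][j] + p1[i][j] - p3[i][j] - p7[i][j];
        }
    }
    return c;
}
```

A practical implementation would stop recursing below some cutoff size and switch to ordinary multiplication there, since for small blocks the extra additions and copies outweigh the saved multiplication.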
For example, the code for one of these steps is shown below. The names S1, T1 and P5 in its comment follow the MNN source, and they differ from the $S_i$ and $P_i$ defined earlier in this article.

    {
        // S1=A21+A22, T1=B12-B11, P5=S1T1
        auto f = [a22, a21, b11, b12, xAddr, yAddr, eSub, lSub, hSub, aStride, bStride]() {
            MNNMatrixAdd(xAddr, a21, a22, eSub * aUnit / 4, eSub * aUnit, aStride, aStride, lSub);
            MNNMatrixSub(yAddr, b12, b11, lSub * bUnit / 4, lSub * bUnit, bStride, bStride, hSub);
        };
        mFunctions.emplace_back(f);
        auto code = _generateMatMul(X.get(), Y.get(), C22.get(), currentDepth);
        if (code != NO_ERROR)
        {
            return code;
        }
    }
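Note that the snippet does not perform the addition and subtraction on the spot: it wraps them in a lambda and appends it to `mFunctions`, which suggests the work is recorded first and executed later. The following sketch illustrates that record-then-run pattern in isolation. `DeferredPlan` and `planSquareOfSum` are hypothetical names, and nothing here is claimed about how MNN actually runs its list.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical two-phase runner: add() records work, run() replays it.
class DeferredPlan {
public:
    // Records one step; nothing is computed yet.
    void add(std::function<void()> step) { mSteps.emplace_back(std::move(step)); }

    // Executes every recorded step in insertion order.
    void run() const {
        for (const auto& step : mSteps) {
            step();
        }
    }

    std::size_t size() const { return mSteps.size(); }

private:
    std::vector<std::function<void()>> mSteps;
};

// Plans "x = a + b; y = x * x" over caller-owned buffers of length n.
// The lambdas capture the pointers by value, so the buffers must outlive
// the plan, but their contents may change between runs.
void planSquareOfSum(DeferredPlan& plan, const float* a, const float* b,
                     float* x, float* y, std::size_t n) {
    assert(a != nullptr && b != nullptr && x != nullptr && y != nullptr);
    plan.add([=]() {
        for (std::size_t i = 0; i < n; ++i) x[i] = a[i] + b[i];
    });
    plan.add([=]() {
        for (std::size_t i = 0; i < n; ++i) y[i] = x[i] * x[i];
    });
}
```

One benefit of this design is that the plan, including decisions such as whether to split, is built once for a given shape, and `run()` can then be called repeatedly on new input data without redoing that analysis.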
Executing these steps recursively yields the final result.