从上图中我们可以看到三种处理方法。第一种是将A和B矩阵分块(竖切和横切),第二种方法是将C和B矩阵分块(竖切和竖切),第三种方法是将C和A矩阵分块(横切和横切):
GEMM的子任务是GEPP或GEMP;最小粒度的任务是GEBP或GEPB或点乘。
这里面M表示横向和纵向维度都很大的矩阵,P表示横向或纵向有一个维度很小的矩阵(或者就是一个向量),B表示横向和纵向维度都很大的矩阵(或者就是只有一个元素的矩阵或向量)。
其各自计算方法大致如下:
对于内存和线程结构:
有一点很难理解的,这个示例里面是一个包含了 8 个 warp、256 个线程的线程block:
在8个warp中,每个warp中包含了32个线程。单个线程thread和包含了多个线程的线程束Warp之间的关系,如下图所示:
左侧1x1的绿色小块的单个线程,灰色8x4=32个灰色块表示32个线程。右侧显示了左图中的单个绿色小块,即单个线程所计算的绿色块的计算过程:
根据 8×1的紫色块 (fragment A)和 1×8 黄色块(fragment B)的向量外积,计算出一个 8×8 的单个线程分块(8x8的右图中绿色块,对应于左图中1x1的绿色块)。
最核心的任务重点,是将C=A行xB列,变成C=A列xB行。将矩阵相乘(GEMM)转化为向量相乘。
这需要把A,B,C分别拆成一系列子矩阵来达到:
上图中完成4x4的矩阵C的矩阵乘法(GEMM)的核心计算如下所述:
将每个循环中的rank1更新:
将大小为4x1的a向量(A矩阵的第0列)和大小为1x4的b向量(B矩阵的第0行)相乘,得到4x4的矩阵结果,然后与C矩阵的初始值相加;
然后循环移动rank1,再次更新:
将大小为4x1的a向量(A矩阵的第1列)和大小为1x4的b向量(B矩阵的第1行)相乘,得到4x4的矩阵结果,然后与C矩阵的初始值相加。从而实现进行rank1更新。
继续循环移动rank1,再次更新。
直到A矩阵的最后一列和B矩阵的最后一行进行rank1更新,整个GEMM完成。
4.B1矩阵取kcxnr的矩阵B2,完成mcxnr的子矩阵C2的更新;
5.A2矩阵每次取MR行得到子矩阵A3,与B2矩阵进行微核心计算,更新MRxNR的C:
6.A3矩阵每次取一列,B2矩阵每次取一行进行一个ALU部件的计算。
相信这6个循环步骤就很直观的表现了GEMM的实现过程。
如果从存储的视角观察GEMM的执行过程。首先最基本的是所有数据都在主存中,所以要完成C=C+AB的任务(这一步不重要)。
然后在Cache中完成分块的子任务:A按列划分的子矩阵与B按行划分的子矩阵乘积,去更新C矩阵(中的一行,下图中蓝色的横行)的老值:
Cache中再将向量-向量乘继续细分(就在这一步进行拆分!拆分成亮蓝色的竖行!)到A列子矩阵的子行和B行子矩阵的子列进行乘积计算,得到更新后的C矩阵的一行:
接下来,在核心内部寄存器按先列后行顺序更新packed A矩阵读入,在Cache中按先行后列顺序更新packed B矩阵读入:
在其内部的小循环中,将可以单次的核心计算称为GEMM微核心,计算MRxNR的矩阵C=C+AB,如下图所示:
然后,一个MRxNR的计算需要KC次A和B进行乘积更新运算,这也是GEMM微核心的典型循环迭代计算称为块点乘:载入packed A的一列,载入packed B的一行,计算点乘,更新寄存器中的C的老值:
那为什么不用简单的点乘一个一个元素去更新,而要用块点乘每次更新整个块的方法呢?可以看到浮点操作和访存操作的比值,简单点乘的话约等于1;而块点乘的则远大于1。由于我们希望运算访存比越高越好(这样计算效率就高),所以采用块点乘的方式:
几乎所有level3的BLAS操作都是用块点乘这种思路实现的。
当矩阵更大包含多个MR或NR的A、B矩阵时,就有两层的外层循环,循环更新步长为MR和NR。外层称为宏核心或内核心,也是通过cache实现的主要方式:
再最后来回顾一下GEMM,完成GEMM将矩阵划分为6级:
1.划分为ncxnc的矩阵;
2.A矩阵划分为ncxkc的子矩阵A1,B矩阵划分为kcxnc的子矩阵B1,完成更新ncxnc的C矩阵;
3.A1矩阵划分为mcxkc的子矩阵A2进行计算,每次更新mcxnc的C子矩阵