令人头秃的cudaTensorCoreGemm详解

  我曾被cudaTensorCoreGemm的代码苦苦折磨了一个月,官方文档也不够详细,代码虽然很多注释也不太好理解,于是决定自己写一个cudaTensorCore的详解。本文主要对cudaTensorCore中的compute_gemm函数进行分析。

  Compute_gemm函数流程图如下图所示:

令人头秃的cudaTensorCoreGemm详解_第1张图片

  Cuda wmma API主要用于使用SM中WARP和TensorCore加速矩阵乘加运算,标准公式如下(大写字母均为矩阵):

D=α×A×B+β×C

  其中A矩阵为M×K,B矩阵为K×N,D和C矩阵均为M×N矩阵。

1 基本概念

  Compute_gemm的计算维度为4096×4096×4096(M、K、N),即D=AB+C中的A:4096×4096,B:4096×4096,C:4096×4096,D:4096×4096。wmmaAPI可执行的最小单位tile为16x16大小。所以需将各个矩阵分为256×256个16×16的小块。

令人头秃的cudaTensorCoreGemm详解_第2张图片

图1  wmma加速的最小单位乘加运算16×16×16

2 核函数任务拆解  

  接下来我们需要对计算任务分配至Block、WARP和WARP内的线程组级别。核函数的griddim和blockdim如下所示:

compute_gemm<<>>(A, B, C, D, alpha, beta));

        Griddim:GPU的SM数量 

        Blockdim:256

  如下图2所示,一个WARP里有32个thread,所以一个BLOCK中含有256/32=8个WARP。Compute_gemm又将WARP里的线程分为8个线程组,一个线程组含有4个thread。一个线程组负责一个tile(16×16矩阵)的计算。

令人头秃的cudaTensorCoreGemm详解_第3张图片

 图2  CudaTensorCore任务拆解结构

  根据上述结构划分,一个Block负责8×8个Tile,总维度为128×128矩阵的全部计算,如下图3所示。

令人头秃的cudaTensorCoreGemm详解_第4张图片

 图3  CudaTensorCore中一个Block负责128×128矩阵的全部乘加计算

  因为D矩阵总共有256×256大小的tile需要计算,一个Block负责一个8×8大小的tile计算,而compute_gemm根据GPU(A10)的SM数总共分配了72个Block,无法一次性计算全部的D矩阵,所以核函数需要进行多次循环,一个Block需要计算多个8×8大小的tile矩阵,如下图4所示。在图4中不同颜色代表了不同轮次的循环,其中在最后一轮中将有56个Block闲置不进行计算。

令人头秃的cudaTensorCoreGemm详解_第5张图片

图4  Compute_gemm分15轮完成对最终矩阵的计算,下标表示blockidx.x

  上述Block循环处理的代码如下:

for (unsigned int block_pos = blockIdx.x;; block_pos += gridDim.x) {
    const unsigned int block_tile_i =
        ((block_pos * BLOCK_ROW_TILES) / N_TILES) * (BLOCK_COL_TILES);
    const unsigned int block_tile_j = (block_pos * BLOCK_COL_TILES) % N_TILES;

    // Stop when there are no more D matrix tiles to compute in this CTA.
    if (block_tile_i >= M_TILES) {
      break;
    }

3 数据准备

3.1 数据量  

  在完成Block、WARP和线程组的任务划分后,接下来在计算之前需要先将需要的矩阵数据从GlobalMemory中拷贝至ShareMemory中。

  因为一个线程组(4个thread)负责一个tile(16×16)的全部计算,所以我们将先从线程组的视角来解释数据是如何搬运的。

  首先根据公式D=AB+C,计算一个Tile需要的全部数据有:

  1. 16×16 float类型(4字节)C子矩阵
  2. 16行(一行256×16个数据)half类型(2字节)A子矩阵
  3. 16列(一列256×16个数据)half类型(2字节)B子矩阵

  所以一个线程组所需的所有数据如下图5所示(图5未按照比例绘制):

令人头秃的cudaTensorCoreGemm详解_第6张图片

 图5  一个线程组需要的A、B、C子矩阵

  在一个WARP中有8个线程组,按照2行4列的方式排列,同一行所需要的A矩阵的数据是相同的,同一列需要的B矩阵也是相同的。所以一个WARP只需要拷贝A矩阵的2×16行和B矩阵的4×16列,C矩阵则要拷贝2×4×(16×16)个元素。

  每个WARP使用wmma API中的数据结构fragment存储2行A和4列B(注意这里并不是将一行的数据全部拷贝,而是分批次拷贝然后计算,将在下文中讲到):

wmma::fragment
            a[WARP_COL_TILES];//[2]
wmma::fragment
            b[WARP_ROW_TILES];//[4]

  同理,一个Block有8个WARP,按照4行2列的方式排列,同样有些WARP需要拷贝的数据是重合的。所以一个Block总共需要8×16行A矩阵的数据和8×16列B矩阵的数据以及8×8×(16×16)的C矩阵数据,如图6所示。

令人头秃的cudaTensorCoreGemm详解_第7张图片

 图6  一个Block需要拷贝至ShareMemory的全部数据量

3.2 ShareMemory

  接下来我们需要将上述提到的数据拷贝至ShareMemory中。在Compute_gemm核函数分配中,一个SM对应一个Block对应一个ShareMemory空间,为了避免ShareMemory的BankConflict,在开辟ShareMemory空间时我们使用了一个偏移量SKEW_HALF。声明共享内存的代码如下:

extern __shared__ half shmem[][CHUNK_K * K + SKEW_HALF];

  由于共享内存和L1 Cache享有共同的内存空间,所以共享内存空间不能太大。Compute_gemm在分配时限制了ShareMemory的大小为64KB:

#if SHARED_MEMORY_LIMIT_64K
// With only 64 Kb shared memory available, we can fit two 8-tile chunks of
// the A and B matrix data, that are 16 * 16 * 8 * 8 * 2 = 32 Kb each
// (i.e. two 8x8 arrays of tiles of 16x16 half-typed elements per CTA).
// But we cannot account the 8 Kb total skew overhead, without which the
// performance would be severely impacted. So we choose to reduce the chunk size
// in half, i.e. the amount of A and B matrix data we cache in shared memory.
// Accordingly, this doubles the number of outer iterations across the global K
// dimension, which only slightly impacts the performance.
#define CHUNK_K 4
#else
#define CHUNK_K 8
#endif

  由于64KB无法存储一个Block所需的全部A、B、C的子矩阵,所以需分步依次将A、B、C拷贝至共享内存中,再将其存储到核函数的局部变量fragment中进行计算。

3.2.1 拷贝C矩阵

  首先是将C矩阵拷贝至ShareMemory中。需要拷贝的数据大小为8×8×16×16×sizeof(float)=64KB正好填满ShareMemory。

typedef int4 copy_t;// int4大小为16B

*((copy_t *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId) =
          *((copy_t *)(src_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) + laneId);

  再将C矩阵的数据分配至各个WARP的C_fragment数组中。

// These fragments will accumulate the result of A and B matrix fragment
// multiplications along the K_GLOBAL dimension.
// [2][4]
wmma::fragment c[WARP_COL_TILES]
                                                   [WARP_ROW_TILES];
wmma::load_matrix_sync(c[i][j], tile_ptr, SHMEM_STRIDE, C_LAYOUT);

3.2.2 拷贝A、B矩阵

  C矩阵拷贝结束后,接下来需要拷贝A矩阵和B矩阵到ShareMemory中。但是A矩阵1行tile就有256×16×16×sizeof(half)=128KB,8行就是1MB。ShareMemory无法存储如此大的数据,所以需要分批次拷贝至ShareMemory中,每个批次进行计算后将中间结果保存,待所有批次计算结束后返回最终结果。

  用D(k)表示第k个批次的中间结果,D(k)的计算公式如下式所示。其中k∈(1,256/CHUNK_K)。

D(k)=A(k)×B(k)+D(k-1)+C

  考虑到共享内存的大小,我们将一个批次的步长设置为CHUNK_K大小,用CHUNK_K步长遍历K维度(256)。在64KB共享内存的限制下这个CHUNK_K为4。

for (int tile_k = 0; tile_k < K_TILES; tile_k += CHUNK_K)

  ps:下段代码解释了为什么CHUNK_K选择4而不是8。因为我们共享内存的大小限制为64KB,若此时选择CHUNK_K=8的话,我们一个Block就需要8×16×16×CHUNK_K(8)×sizeof(half)×2=64KB刚好占满ShareMemory,此时就无法通过偏移存储数据避免BankConflict,会严重影响性能。所以我们将CHUNK_K的大小减半。

#if SHARED_MEMORY_LIMIT_64K
// With only 64 Kb shared memory available, we can fit two 8-tile chunks of
// the A and B matrix data, that are 16 * 16 * 8 * 8 * 2 = 32 Kb each
// (i.e. two 8x8 arrays of tiles of 16x16 half-typed elements per CTA).
// But we cannot account the 8 Kb total skew overhead, without which the
// performance would be severely impacted. So we choose to reduce the chunk size
// in half, i.e. the amount of A and B matrix data we cache in shared memory.
// Accordingly, this doubles the number of outer iterations across the global K
// dimension, which only slightly impacts the performance.
#define CHUNK_K 4
#else
#define CHUNK_K 8
#endif

  如下图7所示,每个批次中的每行/列需拷贝CHUNK_K个tile至ShareMemory。

令人头秃的cudaTensorCoreGemm详解_第8张图片

图7  一个批次中的一行/列需拷贝CHUNK_K个数据至ShareMemory

  由于一个WARP需要拷贝2×16行A矩阵和4×16列B矩阵,同样需要按CHUNK_K的步长分批次拷贝,如图8所示。

令人头秃的cudaTensorCoreGemm详解_第9张图片

图8  WARP一个批次需拷贝的数据

  同理一个Block总共需拷贝8×16行A矩阵和8×16列B矩阵到共享内存中,在一个批次中Block总共需往ShareMemory中拷贝8×16×16×CHUNK_K(4)×sizeof(half)=16KB的A矩阵数据,同样B矩阵也需拷贝16KB,总大小32KB小于共享内存的大小64KB。

  由于wmma API是在WARP级别上执行,所以拷贝A、B矩阵的操作也需要在WARP级别。由上述可知,一个Block在一个轮次中总共需要拷贝16KB的A矩阵数据和16KB的B矩阵数据。一个Block有8个WARP,所以一个WARP需拷贝(16+16)/8=4KB的数据。Compute_gemm分配WARP_ID 0-3的WARP拷贝A矩阵,WARP_ID 4-7的WARP拷贝B矩阵:

// Select what warp copies what matrix to shared memory.
// Warps 0-3 copy the A matrix, warps 4-7 copy the B matrix.
const half *warp_ptr = (warpId < 4) ? (&A[block_tile_i * M * K_GLOBAL] +
                                        M * K_GLOBAL * (warpId % 4) * 2)
                                      : (&B[block_tile_j * N * K_GLOBAL] +
                                        N * K_GLOBAL * (warpId % 4) * 2);

   拷贝A、B矩阵至ShareMemory的代码:

for (int i = 0; i < ((WARP_SIZE / 2) / CHUNK_COPY_LINES_PER_WARP) * 2; i++) {
        // Copy 16 bytes at once in each lane.
        *((int4 *)&shmem[shmem_idx][0] + (laneId % CHUNK_COPY_LINE_LANES)) =
            *lane_ptr;

        // Advance the global memory pointer and the shared memory index.
        lane_ptr =
            (int4 *)((half *)lane_ptr + K_GLOBAL * CHUNK_COPY_LINES_PER_WARP);
        shmem_idx += CHUNK_COPY_LINES_PER_WARP;
      }

  当前WARP从ShareMemory取出A、B矩阵:

wmma::fragment
            a[WARP_COL_TILES];//[2]
wmma::fragment
            b[WARP_ROW_TILES];//[4]
wmma::load_matrix_sync(a[i], tile_ptr, K * CHUNK_K + SKEW_HALF);
wmma::load_matrix_sync(b[j], tile_ptr, K * CHUNK_K + SKEW_HALF);

4 计算

  当数据都载入到fragment中后,就可以执行wmma API加速矩阵计算:

wmma::mma_sync(c[i][j], a[i], b[j], c[i][j]);

  其中每个批次的中间结果也就是D(k)都将保存到c矩阵的fragment中也就是c[i][j]中。

5 保存最终结果

  在全部批次计算结束后,最终结果保存在c矩阵的fragment中,最后一步操作就是将C矩阵的数据存储到GlobalMemory开辟的D矩阵空间中。

  首先得将fragment中的数据拷贝至ShareMemory中:

// Uniform, point-wise transformations of ALL fragment elements by ALL
// threads in the warp are well-defined even though element indices
// within fragment storage are not defined.
for (int t = 0; t < c[i][j].num_elements; t++) c[i][j].x[t] *= alpha;

    float *tile_ptr = shmem_warp_tile_ptr + i * SHMEM_STRIDE * K + j * N;

    wmma::store_matrix_sync(tile_ptr, c[i][j], SHMEM_STRIDE, C_LAYOUT);

  最后再从ShareMemory将计算结果拷贝至GlobalMemory中:

*((int4 *)(dst_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) + laneId) =
          *((int4 *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId);

6 实验

  在A10上分别测试simple_wmma_gemm(不使用ShareMemory只是用wmma API的简单版本)、CUBLAS和compute_gemm计算4096×4096×4096(M、N、K)维的矩阵乘加运算,结果如下图9、10、11所示。

图9  simple_wmma_gemm

图10  CUBLAS

图11  compute_gemm

  由此可见采用了ShareMemory技术使用TensorCore执行矩阵乘加运算达到了最快速度2.31ms且TFLOPS达到最高的59.45。CUBLAS虽然不如compute_gemm,但仍然比简单使用wamma API的方法要快出10倍以上,CUBLAS的优点在于速度快且使用简单快捷。

7 其他

  cudaTensorCoreGemm的逐行代码详解:

        cuda学习:学习nvcuda::wmma实现高效gemm - 知乎 (zhihu.com)

  其他参考:

        NVIDIA深度学习Tensor Core全面解析(上篇) (baidu.com)

        用 CUDA 9 编程 Tensor Core - NVIDIA 开发者博客

你可能感兴趣的:(cuda,cuda,gpu,c++,经验分享)