GEMM矩阵计算中共享内存的存取

对于4x4矩阵。计算一次 FMA(乘累加)为一次运算,而各读取 A 和B中一个元素为1+1=2次运算。访存比为1/2。

而若一个 thread 并不只计算一个结果,而是计算 4x4=16个结果,就要从A和B中分别取出4个数据,共8个数据。访存比变为16/8=2,是上面的4倍。

上面是使用一个block来计算一个完整矩阵的情况,对于更大的矩阵,需要用多个block:

GEMM矩阵计算中共享内存的存取_第1张图片

A的每个block的大小为TILE_M*TILE_K,B的每个block的大小为TILE_K*TILE_N:

 GEMM矩阵计算中共享内存的存取_第2张图片

 

注意,这里将全局内存中的A矩阵存入共享内存smemA中时进行了矩阵转置。

GEMM矩阵计算中共享内存的存取_第3张图片

 

然后,需要从共享内存中取出A和B矩阵用于计算,每个线程分别从A和B中取出4*1的矩阵来进行计算,得到C矩阵:

GEMM矩阵计算中共享内存的存取_第4张图片

 

 

C矩阵为128*128大小的矩阵。C矩阵被分成了四份,每份的尺寸都为4*4,使用同一个线程计算这四份4*4大小区域的FMA计算。

最后,利用 Prefetch 的思想,隐藏 Global Memory 读入中间寄存器、将来自 Global Memory 的数据块写入 Shared Memory、从 Shared Memory 中读出数据块的访存延迟,以免计算单元因为 stall 而空闲太久,最终的伪代码如下所示:

#define TILE_K 16
    __shared__ float4 smemA[2][TILE_K * 128 / 4];
    __shared__ float4 smemB[2][TILE_K * 128 / 4];
    float4 c[8][2] = {{make_float4(0.f, 0.f, 0.f, 0.f)}};
    float4 ldg_a_reg[2];
    float4 ldg_b_reg[2];
    float4 a_reg[2][2];
    float4 b_reg[2][2];

    // transfer first tile from global mem to shared mem
    load_gmem_tile_to_reg(A, 0, ldg_a_reg);
    load_gmem_tile_to_reg(B, 0, ldg_b_reg);

    store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[0]);
    store_reg_to_smem_tile(ldg_b_reg, 0, smemB[0]);
    __syncthreads();

    // load first tile from shared mem to register 
    load_smem_tile_to_reg(smemA[0], 0, a_reg[0]);
    load_smem_tile_to_reg(smemB[0], 0, b_reg[0]);

    int write_stage_idx = 1; //ping pong switch
    do {
        i += TILE_K;
        // load next tile from global mem
        load_gmem_tile_to_reg(A, i, ldg_a_reg);
        load_gmem_tile_to_reg(B, i, ldg_b_reg);

        int load_stage_idx = write_stage_idx ^ 1;

    #pragma unroll
        for(int j = 0; j < TILE_K - 1; ++j) {
            // load next tile from shared mem to register 
            load_smem_tile_to_reg(smemA[load_stage_idx], j + 1, a_reg[(j + 1) % 2]);
            load_smem_tile_to_reg(smemB[load_stage_idx], j + 1, b_reg[(j + 1) % 2]);
            // compute matrix multiply accumulate 8x8
            mma8x8(a_reg[j % 2], b_reg[j % 2], c);
        }

        if(i < K) {
            // store next tile to shared mem
            store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[write_stage_idx]);
            store_reg_to_smem_tile(ldg_b_reg, 0, smemB[write_stage_idx]);
            // use double buffer, only need one sync
            __syncthreads();
            // switch
            write_stage_idx ^= 1;
        }

        // load first tile from shared mem to register of next iter
        load_smem_tile_to_reg(smemA[load_stage_idx ^ 1], 0, a_reg[0]);
        load_smem_tile_to_reg(smemB[load_stage_idx ^ 1], 0, b_reg[0]);
        // compute last tile mma 8x8
        mma8x8(a_reg[1], b_reg[1], c);
    } while (i < K);

    store_c(c, C);

你可能感兴趣的:(矩阵,算法,python)