CUDA矩阵乘法GEMM优化:全局内存-共享内存-寄存器优化,以及数据预存取优化

不使用任何优化的矩阵乘法,代码如下:

__global__ void matrixMul(const float *A, const float *B, float *C, 
                          int M, int N, int K) {
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    int ty = blockIdx.y * blockDim.y + threadIdx.y;
    if(ty < M && tx < N) {
        float c = 0;
        for(int i = 0; i < K; ++i){
            c += A[ty * K + i] * B[i * N + tx];
        }
        C[ty * N + tx] = c;
    }
}

计算一次 FMA(乘累加)之前需要读一次 A 和读一次 B,众所周知,读取 Global Memory 的代价很大,通常都需要几百个 cycle(时钟周期),而计算一次 FMA 通常只需要几个 cycle,大量的时间被花费在了访存上。

可以将 A 和 B 矩阵先搬运到 Shared Memory(SM 中低延迟的 on-chip memory,block 内线程共享,附 NVIDIA GPU 内存结构图)中降低访存的开销,这的确是一个很好的思路,但是这只能将访存代价从几百 cycle 降低到几十 cycle,并不改变问题的本质。

问题的关键在于主体循环由两条 Load 指令与一条 FMA 指令构成,计算指令只占总体的 1/3,计算访存比过低,最终导致了访存延迟不能被隐藏,从而性能不理想。

下面是只使用全局内存-共享内存-寄存器来优化矩阵乘法。让一个 thread 并不只计算一个结果,而是计算 4x4 个结果。其伪代码如下:

float c[4][4] = {{0}};
    float a_reg[4];
    float b_reg[4];
    for(int i = 0; i < K; i += TILE_K){
        __syncthreads();
        // transfer tile from global mem to shared mem
        load_gmem_tile_to_smem(A, i, smemA);
        load_gmem_tile_to_smem(B, i, smemB);
        __syncthreads();
    #pragma unroll
        for(int j = 0; j < TILE_K; ++j) {
            // load tile from shared mem to register 
            load_smem_tile_to_reg(smemA, j, a_reg);
            load_smem_tile_to_reg(smemB, j, b_reg);
            // compute matrix multiply accumulate 4x4
            mma4x4(a_reg, b_reg, c);
        }
    }

从 smemA 读取到寄存器 a_reg 中,需要进行 4 次访存操作,B 同理,那么主体的计算访存指令比例变成了 16/8。

相对于之前的情况,计算指令的占比大大提高了。足够大的计算访存比能提升计算单元的利用率,并能起到隐藏访存延迟的作用。

思考一下为什么能得到这样的提升,首先梳理一下不使用共享内存和寄存器优化的普通矩阵乘法(向量内积)逻辑:

# 数组 A:M行K列的行主序矩阵
# 数组 B:K行N列的行主序矩阵
# 数组 C:M行N列的行主序矩阵
# alpha:一个标量
# beta:一个标量
# 计算方法:
#    c=alpha*A*B+beta*C;
​
__global__ void matrixMul(const float *A, const float *B, float *C,
                          int M, int N, int K, float alpha, float beta)
{
    int tx = blockIdx.x * blockDim.x + threadIdx.x;
    int ty = blockIdx.y * blockDim.y + threadIdx.y;

    int baseX = blockIdx.x * blockDim.x;
    int baseY = blockIdx.y * blockDim.y;

    float c = 0;

    if (tx < M && ty < N)
    {
        for (int i = 0; i < K; i++)
        {
            c += A[tx * K + i] * B[i * N + ty];
        }
        C[tx * N + ty] = beta * C[tx * N + ty] + alpha * c; // we multiply alpha here to reduce the alpha cal num.
    }
}

其伪代码为:

M=N=K=8;
float a[M*K];
float b[N*K];
float c[M*N];
for i in range(M):
    for j in range(N):
        for k in range(K):
            c[i*N+j]+=a[i*K+k]*b[k*N+j];

而使用共享内存和寄存器优化,基于向量外积的矩阵伪代码为:

M=N=K=8;
float a[M*K];
float b[N*K];
float c[M*N];
for k in range(K):
    for i in range(M):
        for j in range(N):
            c[i*N+j]+=a[i*K+k]*b[k*N+j];

在向量外积中,编译器自动做了一些优化:

float a[M*K];
float b[N*K];
float c[M*N];
for k in range(K):
    regB[0:N] = b[k*N:(k+1)*N]
    for i in range(M):
        regA = a[i*K+k];
        for j in range(N):
            c[i*N+j]+=regA*regB[j];

    把A矩阵和B矩阵拆成K列和K行的小块,然后将M-N-K的循环改为K-M-N的循环。

其中 regA 和 regB 均为寄存器。我们不难发现,对于每一次循环 j ,使用的都是完全相同的 A 矩阵里的元素,因此可以用一个寄存器来缓存该值;对于每一次循环 k,使用的都是完全相同的一行 B 矩阵中的值,因此我们可以用 N 个寄存器缓存该值。

于是将原本M*N*2次访存(底下两层循环需要访问一次 A 矩阵和一次 B 矩阵),通过使用N+1个寄存器缓存(B使用N个,A使用一个),优化为M+N次访存。同时我们也注意到, M 和 N 越大的情况下,提升效果越发显著,这也是为什么我们希望每一个线程负责的分块大一点比较好。但同时 M 和 N 越大,每一个线程多使用的寄存器就越多,而在 GPU 的语境下,更高的寄存器占用意味着更低的 Occupancy。

第二种K-M-N循环体构造与第一种M-N-K循环最大的区别就在于它能在不展开 k 的情况下通过展开 m 和 n 处的循环就能自动的识别到重复访存,并使用相应的寄存器来避免重复访存。例如我们假定 M=N=2 ​,那么展开 m 和 n 处循环的结果如下。

M=N=2;
float a[M*K];
float b[N*K];
float c[M*N];
for k in range(K):
    c[0*N+0]+=a[0*K+k]*b[k*N+0]
    c[0*N+1]+=a[0*K+k]*b[k*N+1]
    c[1*N+0]+=a[1*K+k]*b[k*N+0]
    c[1*N+1]+=a[1*K+k]*b[k*N+1]

只要是稍微现代一点的编译器,都能一眼看出这四条指令的 8 次访存,有 4 次是可以合并的。同时现代一点的编译器也能在一定程度上根据生成的汇编交叉排列计算和访存达到延迟覆盖的目的。而向量内积的方案需要把整个 k 维度展开才能看到这些潜在的访存合并机会。在 CPU 矩阵乘的语境下,一般计算 kernel 的 Kblock 都比较大(好几百),而 Mblock 和 Nblock 都很小(一般取 6x16,根据架构来做具体确定),寄存器数量又非常少,因此基本上无法在 K 维上将循环完全展开并做优化。因为展开一个超长的循环不仅会带来额外的寄存器占用、优化难度,还会带来更多的汇编指令,使得最终的二进制文件臃肿不堪。但在 GPU 上,情况却恰恰相反。对于已知循环次数的小循环,即便你没有指定 #pragma unroll,nvcc 也会自动的展开这些循环。而对于一个 thread 所负责的小型矩阵乘,这三层循环的值均为 8,符合 nvcc 自动展开循环的条件。而在展开完成后,nvcc 会对所有的访存以及计算指令重排得到一个不错的汇编指令排列。

而加上了数据预存取,利用 Prefetch 的思想,隐藏 Global Memory 读入中间寄存器、将来自 Global Memory 的数据块写入 Shared Memory、从 Shared Memory 中读出数据块的访存延迟,以免计算单元因为 stall 而空闲太久,最终的伪代码如下所示::

#define TILE_K 16
    __shared__ float4 smemA[2][TILE_K * 128 / 4];
    __shared__ float4 smemB[2][TILE_K * 128 / 4];
    float4 c[8][2] = {{make_float4(0.f, 0.f, 0.f, 0.f)}};
    float4 ldg_a_reg[2];
    float4 ldg_b_reg[2];
    float4 a_reg[2][2];
    float4 b_reg[2][2];

    // transfer first tile from global mem to shared mem
    load_gmem_tile_to_reg(A, 0, ldg_a_reg);
    load_gmem_tile_to_reg(B, 0, ldg_b_reg);

    store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[0]);
    store_reg_to_smem_tile(ldg_b_reg, 0, smemB[0]);
    __syncthreads();

    // load first tile from shared mem to register 
    load_smem_tile_to_reg(smemA[0], 0, a_reg[0]);
    load_smem_tile_to_reg(smemB[0], 0, b_reg[0]);

    int write_stage_idx = 1; //ping pong switch
    do {
        i += TILE_K;
        // load next tile from global mem
        load_gmem_tile_to_reg(A, i, ldg_a_reg);
        load_gmem_tile_to_reg(B, i, ldg_b_reg);

        int load_stage_idx = write_stage_idx ^ 1;

    #pragma unroll
        for(int j = 0; j < TILE_K - 1; ++j) {
            // load next tile from shared mem to register 
            load_smem_tile_to_reg(smemA[load_stage_idx], j + 1, a_reg[(j + 1) % 2]);
            load_smem_tile_to_reg(smemB[load_stage_idx], j + 1, b_reg[(j + 1) % 2]);
            // compute matrix multiply accumulate 8x8
            mma8x8(a_reg[j % 2], b_reg[j % 2], c);
        }

        if(i < K) {
            // store next tile to shared mem
            store_reg_to_smem_tile_transpose(ldg_a_reg, 0, smemA[write_stage_idx]);
            store_reg_to_smem_tile(ldg_b_reg, 0, smemB[write_stage_idx]);
            // use double buffer, only need one sync
            __syncthreads();
            // switch
            write_stage_idx ^= 1;
        }

        // load first tile from shared mem to register of next iter
        load_smem_tile_to_reg(smemA[load_stage_idx ^ 1], 0, a_reg[0]);
        load_smem_tile_to_reg(smemB[load_stage_idx ^ 1], 0, b_reg[0]);
        // compute last tile mma 8x8
        mma8x8(a_reg[1], b_reg[1], c);
    } while (i < K);

    store_c(c, C);

你可能感兴趣的:(算法,c++,矩阵)