实现矩阵乘法 C = A ∗ B C=A*B C=A∗B,其中, A A A, B B B, C C C 是 N ∗ N N*N N∗N 的单精度稠密矩阵。本实验中矩阵均为column major。
华为鲲鹏920:aarch64架构,64核CPU,CPU最高工作频率2600MHz。
L1d cache:64KB
L1i cache:64KB
L2 cache:512KB
L3 cache:32768KB
Page size:65536Byte
测试用例中我们选取的矩阵规模为 n ∈ { 32 ∗ k ± 1 , 32 ∣ 1 ≤ k ≤ 32 } n \in \{32*k\pm1, 32|1\le k\le 32\} n∈{32∗k±1,32∣1≤k≤32} 。
最简单粗暴的算法就是先按行遍历再按列遍历,分别计算 C i j C_{ij} Cij。在编译过程中,我们设置编译器不做任何优化。
void square_gemm (int n, float* A, float* B, float* C)
{
/* For each row i of A */
for (int i = 0; i < n; ++i)
/* For each column j of B */
for (int j = 0; j < n; ++j)
{
/* Compute C(i,j) */
float cij = C[i+j*n];
for( int k = 0; k < n; k++ )
cij += A[i+k*n] * B[k+j*n];
C[i+j*n] = cij;
}
}
该程序的性能如下图所示,其平均性能为0.33Gflops/s。
在Step 0的基础上加上了O3优化,以及-fomit-frame-pointer -march=armv8-a -ffast-math -mtune=tsv110编译选项。在编译器对代码进行自动优化后,程序的性能有了明显提升,如下图所示,平均浮点运算速度为2.47Gflops/s。但是程序的性能不太稳定,尤其是在矩阵规模是32的倍数的时候,性能反而下降明显。
在ARM-v8中有32个128位定长寄存器,每个寄存器可以存4个单精度浮点数,支持SIMD向量化操作。利用这一特性,我们可以四个四个地计算矩阵 C C C中的元素。
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
void solution_1 (int n, float* a, float* b, float* c){
int i, j;
for (j = 0; j < n; j++){
for (i = 0; i < ((n) & (~3)); i+=4){
float32x4_t buf = vld1q_f32(&C(i, j));
for (int k = 0; k < n; k++){
float32x4_t va = vld1q_f32(&A(i, k));
register float vb = B(k, j);
buf = vmlaq_n_f32(buf, va, vb);
}
vst1q_f32(&C(i, j), buf);
}
for (; i < n; i++){//deal with boundaries
register float temp = C(i, j);
for (int k = 0; k < n; k++){
temp += A(i, k) * B(k, j);
}
C(i, j) = temp;
}
}
}
加入SIMD向量化操作之后,程序的性能如下图所示,平均浮点运算速度达到3.68Gflops/s。
在Step2中,矩阵 B B B中的每个元素在被load后只被使用了一次,为了提高矩阵B中元素的使用率,我们可以每次load矩阵 B B B中相邻4列的元素,进而通过对矩阵 A A A中的 4 × k 4\times k 4×k 的子矩阵和矩阵 B B B中的 k × 4 k\times 4 k×4 的子矩阵进行相乘,得到矩阵 C C C中的大小为 4 × 4 4\times 4 4×4的子矩阵。
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
//computing (4xk)x(kx4) dot product
void add_dot_4x4 (int n, int k, float* a, float* b, float* c){
float *b_ptr_0, *b_ptr_1, *b_ptr_2, *b_ptr_3;
b_ptr_0 = &B(0, 0);
b_ptr_1 = &B(0, 1);
b_ptr_2 = &B(0, 2);
b_ptr_3 = &B(0, 3);
float32x4_t c_sum_0 = {0};
float32x4_t c_sum_1 = {0};
float32x4_t c_sum_2 = {0};
float32x4_t c_sum_3 = {0};
register float b_reg_0, b_reg_1, b_reg_2, b_reg_3;
for (int p = 0; p < k; p++){
float32x4_t a_reg = vld1q_f32(&A(0, p));
b_reg_0 = *(b_ptr_0++);
b_reg_1 = *(b_ptr_1++);
b_reg_2 = *(b_ptr_2++);
b_reg_3 = *(b_ptr_3++);
c_sum_0 = vmlaq_n_f32(c_sum_0, a_reg, b_reg_0);
c_sum_1 = vmlaq_n_f32(c_sum_1, a_reg, b_reg_1);
c_sum_2 = vmlaq_n_f32(c_sum_2, a_reg, b_reg_2);
c_sum_3 = vmlaq_n_f32(c_sum_3, a_reg, b_reg_3);
}
float *c_ptr = 0;
c_ptr = &C(0, 0);
float32x4_t c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_0);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 1);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_1);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 2);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_2);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 3);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_3);
vst1q_f32(c_ptr, c_reg);
}
void solution_2 (int n, float* a, float* b, float* c){
int i, j, k;
for (j = 0; j < ((n) & (~3)); j+=4){
for (i = 0; i < ((n) & (~3)); i+=4){
add_dot_4x4(n, n, &A(i, 0), &B(0, j), &C(i, j));
}
for (; i < n; i++){
register float c_0, c_1, c_2, c_3;
c_0 = C(i, j);
c_1 = C(i, j + 1);
c_2 = C(i, j + 2);
c_3 = C(i, j + 3);
for (int k = 0; k < n; k++){
c_0 += A(i, k) * B(k, j);
c_1 += A(i, k) * B(k, j + 1);
c_2 += A(i, k) * B(k, j + 2);
c_3 += A(i, k) * B(k, j + 3);
}
C(i, j) = c_0;
C(i, j + 1) = c_1;
C(i, j + 2) = c_2;
C(i, j + 3) = c_3;
}
}
for (; j < n; j++){
for (i = 0; i < ((n) & (~3)); i+=4){
float32x4_t buf = vld1q_f32(&C(i, j));
for (int k = 0; k < n; k++){
float32x4_t va = vld1q_f32(&A(i, k));
register float vb = B(k, j);
buf = vmlaq_n_f32(buf, va, vb);
}
vst1q_f32(&C(i, j), buf);
}
for (; i < n; i++){
float temp = C(i, j);
for (int k = 0; k < n; k++){
temp += A(i, k) * B(k, j);
}
C(i, j) = temp;
}
}
}
在提高了矩阵 B B B元素的访问效率后,程序的性能再一次得到大幅度的提升,平均浮点运算速度达到7.72Gflops/s。
将add_dot_4x4
函数中对 k k k的循环四个四个进行展开,改进后的程序性能变化不明显,平均浮点运算速度为7.84Gflops/s。
考虑到如果每次计算 C C C中 4 × 4 4\times 4 4×4大小的子矩阵,那么只需要用到 4 + 1 = 5 4+1=5 4+1=5个128bit定长寄存器,这对于定长128bit寄存器而言是一种浪费。为了能用上更多的寄存器,我们改为每次计算 8 × 8 8\times 8 8×8大小的子矩阵,这样就需要使用 16 + 2 = 18 16+2=18 16+2=18个128bit定长寄存器。
其核心部分代码如下:
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
void add_dot_8x8 (int n, int k, float* a, float* b, float* c){
float *b_ptr_0, *b_ptr_1, *b_ptr_2, *b_ptr_3;
float *b_ptr_4, *b_ptr_5, *b_ptr_6, *b_ptr_7;
b_ptr_0 = &B(0, 0);
b_ptr_1 = &B(0, 1);
b_ptr_2 = &B(0, 2);
b_ptr_3 = &B(0, 3);
b_ptr_4 = &B(0, 4);
b_ptr_5 = &B(0, 5);
b_ptr_6 = &B(0, 6);
b_ptr_7 = &B(0, 7);
float32x4_t c_sum_00 = {0};
float32x4_t c_sum_01 = {0};
float32x4_t c_sum_02 = {0};
float32x4_t c_sum_03 = {0};
float32x4_t c_sum_04 = {0};
float32x4_t c_sum_05 = {0};
float32x4_t c_sum_06 = {0};
float32x4_t c_sum_07 = {0};
float32x4_t c_sum_40 = {0};
float32x4_t c_sum_41 = {0};
float32x4_t c_sum_42 = {0};
float32x4_t c_sum_43 = {0};
float32x4_t c_sum_44 = {0};
float32x4_t c_sum_45 = {0};
float32x4_t c_sum_46 = {0};
float32x4_t c_sum_47 = {0};
register float b_reg_0, b_reg_1, b_reg_2, b_reg_3;
register float b_reg_4, b_reg_5, b_reg_6, b_reg_7;
for (int p = 0; p < k; p++){
float32x4_t a_reg_0, a_reg_4;
a_reg_0 = vld1q_f32(&A(0, p));
a_reg_4 = vld1q_f32(&A(4, p));
b_reg_0 = *(b_ptr_0++);
b_reg_1 = *(b_ptr_1++);
b_reg_2 = *(b_ptr_2++);
b_reg_3 = *(b_ptr_3++);
b_reg_4 = *(b_ptr_4++);
b_reg_5 = *(b_ptr_5++);
b_reg_6 = *(b_ptr_6++);
b_reg_7 = *(b_ptr_7++);
c_sum_00 = vmlaq_n_f32(c_sum_00, a_reg_0, b_reg_0);
c_sum_01 = vmlaq_n_f32(c_sum_01, a_reg_0, b_reg_1);
c_sum_02 = vmlaq_n_f32(c_sum_02, a_reg_0, b_reg_2);
c_sum_03 = vmlaq_n_f32(c_sum_03, a_reg_0, b_reg_3);
c_sum_04 = vmlaq_n_f32(c_sum_04, a_reg_0, b_reg_4);
c_sum_05 = vmlaq_n_f32(c_sum_05, a_reg_0, b_reg_5);
c_sum_06 = vmlaq_n_f32(c_sum_06, a_reg_0, b_reg_6);
c_sum_07 = vmlaq_n_f32(c_sum_07, a_reg_0, b_reg_7);
c_sum_40 = vmlaq_n_f32(c_sum_40, a_reg_4, b_reg_0);
c_sum_41 = vmlaq_n_f32(c_sum_41, a_reg_4, b_reg_1);
c_sum_42 = vmlaq_n_f32(c_sum_42, a_reg_4, b_reg_2);
c_sum_43 = vmlaq_n_f32(c_sum_43, a_reg_4, b_reg_3);
c_sum_44 = vmlaq_n_f32(c_sum_44, a_reg_4, b_reg_4);
c_sum_45 = vmlaq_n_f32(c_sum_45, a_reg_4, b_reg_5);
c_sum_46 = vmlaq_n_f32(c_sum_46, a_reg_4, b_reg_6);
c_sum_47 = vmlaq_n_f32(c_sum_47, a_reg_4, b_reg_7);
}
float *c_ptr = 0;
c_ptr = &C(0, 0);
float32x4_t c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_00);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 1);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_01);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 2);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_02);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 3);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_03);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 4);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_04);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 5);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_05);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 6);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_06);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(0, 7);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_07);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 0);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_40);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 1);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_41);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 2);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_42);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 3);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_43);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 4);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_44);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 5);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_45);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 6);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_46);
vst1q_f32(c_ptr, c_reg);
c_ptr = &C(4, 7);
c_reg = vld1q_f32(c_ptr);
c_reg = vaddq_f32(c_reg, c_sum_47);
vst1q_f32(c_ptr, c_reg);
}
改进后的程序平均浮点运算速度为10.44Gflops/s。虽然性能相较之前有较大提升,但是随着矩阵规模的增大,程序的性能下降明显。这是由于在对矩阵 C C C的同一行不同列的子矩阵进行计算时,矩阵 A A A中相同的 8 × k 8\times k 8×k 大小的block会被加载两次,并且 8 × k 8\times k 8×k 大小的block在内存中是不连续的,因此每一次加载都会造成一定数量的cache miss,这一现象随着矩阵规模的增大越来越明显,降低了程序运行的效率。对于矩阵 B B B也存在类似的问题,不过由于矩阵 B B B中 k × 8 k\times 8 k×8 大小的block在内存中是连续的,因此问题会小很多。
为了解决上述问题,我们在首次访问矩阵 A A A中 8 × k 8\times k 8×k 大小的block以及矩阵 B B B中 k × 8 k\times 8 k×8 大小的block时,会将其packing至某片连续的内存区域,这样在下次访问的时候,访问的就是一片连续的内存,理论上可以降低cache miss的次数。核心代码如下(只对矩阵 A A A进行了packing):
#include "arm_neon.h"
#define A(i,j) a[ (j)*n + (i) ]
#define B(i,j) b[ (j)*n + (i) ]
#define C(i,j) c[ (j)*n + (i) ]
void PackMatrixA_8x8( int k, float *a, int n, float *a_to ){
int j;
for(j = 0; j < k; j++){ /* loop over columns of A */
float *a_ij_pntr = &A(0, j);
*a_to++ = *a_ij_pntr;
*a_to++ = *(a_ij_pntr + 1);
*a_to++ = *(a_ij_pntr + 2);
*a_to++ = *(a_ij_pntr + 3);
*a_to++ = *(a_ij_pntr + 4);
*a_to++ = *(a_ij_pntr + 5);
*a_to++ = *(a_ij_pntr + 6);
*a_to++ = *(a_ij_pntr + 7);
}
}
void solution_3_packed (int n, float* a, float* b, float* c){
int i, j;
float packedA[((n) & (~7)) * n];
for (j = 0; j < ((n) & (~7)); j+=8){
for (i = 0; i < ((n) & (~7)); i+=8){
if (j == 0)
PackMatrixA_8x8(n, &A(i, 0), n, &packedA[i * n]);
add_dot_8x8_packed(n, K, &packedA[i * n], &B(0, j), &C(i, j));
}
for (; i < n; i++){
register float c_0, c_1, c_2, c_3, c_4, c_5, c_6, c_7;
c_0 = C(i, j);
c_1 = C(i, j + 1);
c_2 = C(i, j + 2);
c_3 = C(i, j + 3);
c_4 = C(i, j + 4);
c_5 = C(i, j + 5);
c_6 = C(i, j + 6);
c_7 = C(i, j + 7);
for (int k = 0; k < K; k++){
c_0 += A(i, k) * B(k, j);
c_1 += A(i, k) * B(k, j + 1);
c_2 += A(i, k) * B(k, j + 2);
c_3 += A(i, k) * B(k, j + 3);
c_4 += A(i, k) * B(k, j + 4);
c_5 += A(i, k) * B(k, j + 5);
c_6 += A(i, k) * B(k, j + 6);
c_7 += A(i, k) * B(k, j + 7);
}
C(i, j) = c_0;
C(i, j + 1) = c_1;
C(i, j + 2) = c_2;
C(i, j + 3) = c_3;
C(i, j + 4) = c_4;
C(i, j + 5) = c_5;
C(i, j + 6) = c_6;
C(i, j + 7) = c_7;
}
for (; j < n; j++){
for (i = 0; i < ((n) & (~7)); i+=8){
float32x4_t buf_0, buf_1;
buf_0 = vld1q_f32(&C(i, j));
buf_1 = vld1q_f32(&C(i + 4, j));
for (int k = 0; k < n; k++){
float32x4_t va_0, va_1;
va_0 = vld1q_f32(&A(i, k));
va_1 = vld1q_f32(&A(i + 4, k));
register float vb = B(k, j);
buf_0 = vmlaq_n_f32(buf_0, va_0, vb);
buf_1 = vmlaq_n_f32(buf_1, va_1, vb);
}
vst1q_f32(&C(i, j), buf_0);
vst1q_f32(&C(i + 4, j), buf_1);
}
for (; i < n; i++){
float temp = C(i, j);
for (int k = 0; k < n; k++){
temp += A(i, k) * B(k, j);
}
C(i, j) = temp;
}
}
}
packing后的性能如下图所示。第一,程序的总性能得到了提高,达到了12.55Gflops/s。第二,由于对不连续的区域进行了packing使其连续,因此程序对于矩阵规模的敏感度下降了,不会出现由于矩阵规模的细微变化造成cache miss显著增加,进而严重影响性能的情况。第三,当矩阵规模较小时,程序的性能有细微下降,这是由于packing的开销造成的;而当矩阵规模增大时,packing的好处显现出来,矩阵规模增大,程序性能不降反增。
为了方便叙述,我们把之前的计算矩阵相乘的过程称为macro kernel,把其中计算 4 × 4 4 \times 4 4×4或 8 × 8 8\times 8 8×8的子矩阵的过程称为micro kernel。当求解的矩阵规模很大时,直接对原矩阵采用macro kernel会造成比较差的程序局域性,因此需要对矩阵先进行blocking,即将矩阵 C C C划分成若干个大小为 M C × N C M_C\times N_C MC×NC的子矩阵,将矩阵 A A A划分成若干个大小为 M C × K C M_C\times K_C MC×KC的子矩阵,将矩阵 B B B划分成若干个大小为 K C × N C K_C\times N_C KC×NC的子矩阵。在子矩阵上用macro kernel进行计算,然后进行累加拼接,得到完整的 C C C矩阵。算法示意图如下所示:
除此以外,对于packing的方法我们也进行了优化:如果遇到矩阵规模不能被8或者4整除的情况,我们会通过补0点方式进行padding,以保证内存地址是对齐的。在整体代码架构上我也进行了一些修改,以增强其可读性和可复用性,完整代码见文件。在代码中分块矩阵的大小十分关键,经过一些简单的调参后发现当 K C = 384 KC=384 KC=384,$ MC=NC=256$时,程序的性能达到最优,平均浮点运算速度为13.07Gflops/s。程序性能相较于没有使用block提升不是很明显,这主要是因为测试用例规模不够大所致。
下图是naive版本的GEMM,本文加速后实现的GEMM以及OpenBLAS中GEMM的性能对比图。本文通过向量化,提高cache利用率等方法对naive版本的GEMM进行加速,使性能提升了接近40倍,但是相较于BLAS的31.27Gflops/s的运算速度还有较大的差距。差距的主要来源应当在于micro kernel的实现部分。本实验仅仅采用了neon intrinsic对micro kernel进行了比较粗糙的优化,经过编译后micro kernel的实现效率远远未达到极限。如果要进行更细致的优化,还需要用内联汇编代码编写micro kernel,更加仔细地操作内存的读写,prefetch,以及寄存器的读写运算。
在与同学交流的过程中发现neon intrinsic函数可以做进一步优化,完整的代码可见https://github.com/xiaoyi-jason/simple_gemm。优化后的性能可以达到接近85%BLAS的水平。接下来我还将从并行计算的角度对GEMM进行进一步的优化,可以期待之后的博客哦。