This article optimizes matrix multiplication (GEMM) on an Intel platform, mainly by adjusting the memory layout (for cache friendliness), using SIMD (SSE), and using multiple threads. The matrices A, B, and C have sizes M×K, K×N, and M×N respectively. All performance numbers in this article are the average over T iterations with M = N = K = 1024; the complete code is linked at the end. The defining formula is given below and can be read alongside the code.
$C_{m,n}=\sum_{k=1}^{K}A_{m,k}\cdot B_{k,n}$
The baseline (v0) follows the mathematical definition exactly, with no optimization at all. It is the starting point from which we optimize step by step. Note that the function suffix MNK indicates the order of the three nested loops.
void gemm_v0_MNK(const float *A, const float *B, float *C, int M, int N, int K)
{
memset(C, 0, M*N*sizeof(float));
for (int m = 0; m < M; m++)
{
for (int n = 0; n < N; n++)
{
for (int k = 0; k < K; k++)
{
C[m*N + n] += A[m*K + k] * B[k*N + n];
}
}
}
return;
}
On my i5-7400 this takes 4289 ms.
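For reference, a minimal sketch of the kind of timing harness behind these numbers (the use of std::chrono and the choice of T here are my assumptions, not necessarily the exact setup used for the measurements):

#include <chrono>
#include <vector>

// Average wall-clock time (ms) of one GEMM variant over T runs.
double benchmark_gemm(void (*gemm)(const float*, const float*, float*, int, int, int),
                      int M, int N, int K, int T)
{
    std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);
    auto t0 = std::chrono::steady_clock::now();
    for (int t = 0; t < T; t++)
        gemm(A.data(), B.data(), C.data(), M, N, K);
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(t1 - t0).count() / T;
}

// e.g. benchmark_gemm(gemm_v0_MNK, 1024, 1024, 1024, 10);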
In v0 the innermost loop walks down a column of B; when B is wide this causes frequent cache misses. Reordering the loops lets us access B row by row.
void gemm_v1_MKN(const float *A, const float *B, float *C, int M, int N, int K)
{
memset(C, 0, M*N*sizeof(float));
for (int m = 0; m < M; m++)
{
for (int k = 0; k < K; k++)
{
float a = A[m*K + k];
for (int n = 0; n < N; n++)
{
C[m*N + n] += a* B[k*N + n];
}
}
}
return;
}
After the reordering, the innermost loop multiplies a single value of A by a row of B and accumulates into the corresponding row of C. Notice that each position of C is now written back K times. This version takes 1790 ms.
v1 still has the problem of C being written back many times. A different approach is to transpose B, so that what used to be a column read of B becomes a row read. The loop order then goes back to MNK.
void transpose(const float *A, float *B, int M, int N)
{
for (int n = 0; n < N; n++)
{
for (int m = 0; m < M; m++)
{
B[n*M + m] = A[N*m + n];
}
}
}
void gemm_v2_MNK_transposeB(const float *A, const float *B, float *C, int M, int N, int K)
{
for (int m = 0; m < M; m++)
{
for (int n = 0; n < N; n++)
{
float sum = 0.0f;
for (int k = 0; k < K; k++)
{
sum += A[m*K + k] * B[n*K + k];
}
C[m*N + n] = sum;
}
}
return;
}
This version (transpose plus multiply-accumulate together) takes 1620 ms in total.
Still aiming at the cache hit rate, we can split the matrices into small blocks, each small enough to fit in the L1 cache. The best block size depends on the machine's cache sizes and even on the matrix dimensions; see the referenced blog post for more discussion. In fact, when the problem cannot be modeled exactly, brute-force search over the parameter under the given constraints (here, the block size) is a common optimization approach; a sketch of such a sweep appears after the blocked implementation below.
inline void do_block(const float *A, const float *B, float *C, int K, int N, int BLOCKSIZE)
{
for (int m = 0; m < BLOCKSIZE; m++)
{
for (int n = 0; n < BLOCKSIZE; n++)
{
float c = C[m*N + n];
for (int k = 0; k < BLOCKSIZE; k++)
c += A[m*K + k] * B[k*N + n];
C[m*N + n] = c;
}
}
}
// blocked (tiled) matrix multiplication
void dgemm_block(const float *A, const float *B, float *C, int M, int N, int K)
{
const int BLOCKSIZE = 64;
memset(C, 0, M*N*sizeof(float));
for (int m = 0; m < M; m += BLOCKSIZE)
{
for (int n = 0; n < N; n += BLOCKSIZE)
{
for (int k = 0; k < K; k += BLOCKSIZE)
{
do_block(A + m*K + k, B + k*N + n, C + m*N + n, K, N, BLOCKSIZE);
}
}
}
return;
}
This version takes 1633 ms. Everything so far has optimized memory access in the hope of improving the cache hit rate. The code above is still not very efficient, though: indices are recomputed repeatedly, there is no loop unrolling, and so on. Later versions will take care of these details. The blocked multiply could also be written without a separate function to avoid the call overhead.
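As a sketch of that brute-force block-size search (the candidate list is arbitrary, and dgemm_block_bs is a hypothetical variant of dgemm_block that takes BLOCKSIZE as a parameter):

#include <chrono>

// Hypothetical: identical to dgemm_block above, but with BLOCKSIZE passed in.
void dgemm_block_bs(const float *A, const float *B, float *C,
                    int M, int N, int K, int BLOCKSIZE);

// Time each candidate block size once and keep the fastest.
int pick_blocksize(const float *A, const float *B, float *C, int M, int N, int K)
{
    const int candidates[] = { 16, 32, 64, 128, 256 };
    int best_bs = candidates[0];
    double best_ms = 1e30;
    for (int bs : candidates)
    {
        auto t0 = std::chrono::steady_clock::now();
        dgemm_block_bs(A, B, C, M, N, K, bs);
        auto t1 = std::chrono::steady_clock::now();
        double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
        if (ms < best_ms) { best_ms = ms; best_bs = bs; }
    }
    return best_bs;
}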
In v1, the innermost loop multiplies a single value of A by an entire row of B, which maps quite naturally onto vectorization. See the code and comments for details.
void gemm_v1_MKN_SSE(const float *A, const float *B, float *C, int M, int N, int K)
{
memset(C, 0, M*N*sizeof(float));
int m, n, k;
for (m = 0; m < M; m++)
{
for (k = 0; k < K; k++)
{
__m128 v4_a = _mm_set1_ps(*(A + m*K + k));// Am,k Am,k Am,k Am,k
for (n = 0; n < N - 3; n += 4)
{
__m128 v4_b = _mm_loadu_ps(B + k*N + n); // Bk,n Bk,n+1 Bk,n+2 Bk,n+3
__m128 v4_c = _mm_loadu_ps(C + m*N + n);
_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a, v4_b)));
}
for (; n < N; n++)
{
C[m*N + n] += A[m*K + k] * B[k*N + n];
}
}
}
return;
}
This runs in 794 ms, roughly a 2x improvement over v1's 1790 ms. Even this straightforward application of SSE gives a considerable speedup.
To give the compiler more room for software pipelining, reduce the number of loop-condition checks, and cut down the write-backs to the C matrix, we build on v4 and make the innermost loop process 4 rows of B at a time.
void gemm_v1_MKN_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
memset(C, 0, M*N*sizeof(float));
int m, n, k;
for (m = 0; m < M; m++)
{
for (k = 0; k < K - 3; k += 4)
{
__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));
__m128 v4_a1 = _mm_set1_ps(*(A + m*K + k + 1));
__m128 v4_a2 = _mm_set1_ps(*(A + m*K + k + 2));
__m128 v4_a3 = _mm_set1_ps(*(A + m*K + k + 3));
for (n = 0; n < N - 3; n += 4)
{
__m128 v4_b0 = _mm_loadu_ps(B + k*N + n);
__m128 v4_b1 = _mm_loadu_ps(B + k*N + n + N);
__m128 v4_b2 = _mm_loadu_ps(B + k*N + n + 2 * N);
__m128 v4_b3 = _mm_loadu_ps(B + k*N + n + 3 * N);
__m128 v4_c = _mm_loadu_ps(C + m*N + n);
v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b0));
v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a1, v4_b1));
v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a2, v4_b2));
v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a3, v4_b3));
_mm_storeu_ps(C + m*N + n, v4_c);
}
for (; n < N; n++)
{
C[m*N + n] += A[m*K + k] * B[k*N + n];
C[m*N + n] += A[m*K + k + 1] * B[(k + 1)*N + n];
C[m*N + n] += A[m*K + k + 2] * B[(k + 2)*N + n];
C[m*N + n] += A[m*K + k + 3] * B[(k + 3)*N + n];
}
}
for (; k < K; k++)
{
__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));
for (n = 0; n < N - 3; n += 4)
{
__m128 v4_b = _mm_loadu_ps(B + k*N + n);
__m128 v4_c = _mm_loadu_ps(C + m*N + n);
_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b)));
}
float a = A[m*K + k];
for (; n < N; n++)
{
C[m*N + n] += a* B[k*N + n];
}
}
}
return;
}
Note that the implementation has to handle the case where K (the unrolled dimension) is not a multiple of 4. The number of **_mm_storeu_ps** instructions drops to a quarter of what it was, and **this version takes 463 ms**, another roughly 2x improvement. Let's keep going.
Although v5 already brings a decent speedup, as analyzed earlier this computation order still causes many write-backs to C. Let's see what SSE does on top of the transposed version.
void gemm_v2_MNK_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
int k = 0, n = 0;
__m128 v4_1_ps = _mm_set1_ps(1.0f);
__m128 v4_sum_tmp_ps, v4_sumv_tmp_ps;
for (int m = 0; m < M; m++)
{
for (n = 0; n < N - 3; n += 4)
{
float sum0, sum1, sum2, sum3;
__m128 v4_sum0 = _mm_setzero_ps();
__m128 v4_sum1 = _mm_setzero_ps();
__m128 v4_sum2 = _mm_setzero_ps();
__m128 v4_sum3 = _mm_setzero_ps();
sum0 = sum1 = sum2 = sum3 = 0.0f;
for (k = 0; k < K - 3; k += 4)
{
__m128 a = _mm_loadu_ps(A + m*K + k);
__m128 b0 = _mm_loadu_ps(B + n*K + k);
__m128 b1 = _mm_loadu_ps(B + n*K + k + K);
__m128 b2 = _mm_loadu_ps(B + n*K + k + 2 * K);
__m128 b3 = _mm_loadu_ps(B + n*K + k + 3 * K);
v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
v4_sum1 = _mm_add_ps(v4_sum1, _mm_mul_ps(a, b1));
v4_sum2 = _mm_add_ps(v4_sum2, _mm_mul_ps(a, b2));
v4_sum3 = _mm_add_ps(v4_sum3, _mm_mul_ps(a, b3));
}
for (; k < K; k++)
{
sum0 += A[m*K + k] * B[n*K + k];
sum1 += A[m*K + k] * B[n*K + k + K];
sum2 += A[m*K + k] * B[n*K + k + 2 * K];
sum3 += A[m*K + k] * B[n*K + k + 3 * K];
}
v4_sum_tmp_ps = _mm_setr_ps(sum0, sum1, sum2, sum3);
//v4_sumv_tmp_ps.m128_f32[0] = v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3];
v4_sumv_tmp_ps = _mm_dp_ps(v4_sum0, v4_1_ps, 0xF1);
v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);
v4_sumv_tmp_ps = _mm_dp_ps(v4_sum1, v4_1_ps, 0xF2);
v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);
v4_sumv_tmp_ps = _mm_dp_ps(v4_sum2, v4_1_ps, 0xF4);
v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);
v4_sumv_tmp_ps = _mm_dp_ps(v4_sum3, v4_1_ps, 0xF8);
v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);
_mm_storeu_ps(C + m*N + n, v4_sum_tmp_ps);
}//end for n=0~N-3
for (; n < N; n++)
{
float sum0;
__m128 v4_sum0 = _mm_setzero_ps();
sum0 = 0.0f;
for (k = 0; k < K - 3; k += 4)
{
__m128 a = _mm_loadu_ps(A + m*K + k);
__m128 b0 = _mm_loadu_ps(B + n*K + k);
v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
}
for (; k < K; k++)
{
sum0 += A[m*K + k] * B[n*K + k];
}
C[m*N + n] = sum0 + v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3];
}//end for n=N-3~N
}// end for m
return;
}
This runs in 451 ms. The uncomfortable part of this code is that after the k loop the four lanes of each accumulator (v4_sum0 through v4_sum3) have to be summed horizontally. That horizontal reduction is inefficient (often even more noticeably on DSP platforms), and most SIMD-optimized code tries to avoid it. Here it is done with the _mm_dp_ps dot-product instruction: the vector is multiplied by an all-ones vector and the lane sum is written to the single lane selected by the mask.
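For clarity, here is that reduction in isolation, plus an SSE3 hadd-based alternative for comparison (the alternative and the helper names are my additions, not what the code above uses; _mm_dp_ps requires SSE4.1):

#include <smmintrin.h>   // SSE4.1: _mm_dp_ps
#include <pmmintrin.h>   // SSE3: _mm_hadd_ps

// Sum the four lanes of v and place the result in lane 0 (other lanes become 0).
// Mask 0xF1: high nibble F = include all four lanes in the dot product,
// low nibble 1 = write the sum only to lane 0.
static inline __m128 hsum_dp(__m128 v)
{
    return _mm_dp_ps(v, _mm_set1_ps(1.0f), 0xF1);
}

// SSE3 alternative: two horizontal adds collapse the four lanes.
static inline __m128 hsum_hadd(__m128 v)
{
    v = _mm_hadd_ps(v, v);    // {a+b, c+d, a+b, c+d}
    return _mm_hadd_ps(v, v); // every lane now holds a+b+c+d
}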
Before the next optimization, let's redefine the transpose: we transpose with 1x4 vectors as the basic unit.
For example, in the layout below each number stands for a 1x4 vector. Say the original (packed) matrix is
1 5
2 6
3 7
4 8
after the transpose it becomes
1 2 3 4
5 6 7 8
but the four elements inside each numbered vector keep the same order as in the original matrix. With this memory layout, multiplying a single value of A by one such row of B yields four results at once; compared with v6, the in-vector (horizontal) addition is eliminated.
Another way to look at it: the multiplication processes four columns of B at a time (those four columns form a 1x4 vector that gets multiplied by the same single value of A), so we simply transpose in units of those four columns. Obviously this transpose requires N to be a multiple of 4.
// vector4 transpose; note how the matrix width and height change:
// M*N -> (N/4) * (4M)
void transpose_vec4(const float *A, float *B, int M, int N)
{
int m, n;
for (m = 0; m < M; m++)
{
for (n = 0; n < N; n += 4)
{
__m128 a = _mm_loadu_ps(A + m*N + n);
_mm_storeu_ps(B + n*M + (m << 2), a);
}
}
}
// GEMM with B repacked by the 1x4 vector transpose
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4(const float *A, const float *B, float *C, int M, int N, int K)
{
assert(0 == N % 4);
for (int m = 0; m < M; m++)
{
for (int n = 0; n < N; n += 4)
{
__m128 v4_sum = _mm_set1_ps(0.0f);
const float* pA = A + m*K;
const float* pB = B + n*K;
int k;
for (k = 0; k < K - 3; k += 4)
{
__m128 v4_a0 = _mm_load1_ps(pA);
__m128 v4_a1 = _mm_load1_ps(pA + 1);
__m128 v4_a2 = _mm_load1_ps(pA + 2);
__m128 v4_a3 = _mm_load1_ps(pA + 3);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_b1 = _mm_loadu_ps(pB + 4);
__m128 v4_b2 = _mm_loadu_ps(pB + 8);
__m128 v4_b3 = _mm_loadu_ps(pB + 12);
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum = _mm_add_ps(v4_sum, v4_c);
v4_c = _mm_mul_ps(v4_a1, v4_b1);
v4_sum = _mm_add_ps(v4_sum, v4_c);
v4_c = _mm_mul_ps(v4_a2, v4_b2);
v4_sum = _mm_add_ps(v4_sum, v4_c);
v4_c = _mm_mul_ps(v4_a3, v4_b3);
v4_sum = _mm_add_ps(v4_sum, v4_c);
pA += 4;
pB += 16;
}
for (; k < K; k++)
{
__m128 v4_a0 = _mm_load1_ps(pA);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum = _mm_add_ps(v4_sum, v4_c);
pA += 1;
pB += 4;
}
_mm_storeu_ps(C + m*N + n, v4_sum);
}
}
return;
}
Here we also unroll over k, handling four packed rows of B per iteration. The final time is 449 ms.
Simply add OpenMP on top of v7.
// 1x4 vector-transposed GEMM + OpenMP
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
assert(0 == N % 4);
#ifdef _OPENMP
omp_set_num_threads(4);
#pragma omp parallel for
#endif
for (int m = 0; m < M; m++)
{
for (int n = 0; n < N; n += 4)
{
__m128 v4_sum = _mm_set1_ps(0.0f);
const float* pA = A + m*K;
const float* pB = B + n*K;
int k;
for (k = 0; k < K - 3; k += 4)
{
__m128 v4_a0 = _mm_load1_ps(pA);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum = _mm_add_ps(v4_sum, v4_c);
__m128 v4_a1 = _mm_load1_ps(pA + 1);
__m128 v4_b1 = _mm_loadu_ps(pB + 4);
v4_c = _mm_mul_ps(v4_a1, v4_b1);
v4_sum = _mm_add_ps(v4_sum, v4_c);
__m128 v4_a2 = _mm_load1_ps(pA + 2);
__m128 v4_b2 = _mm_loadu_ps(pB + 8);
v4_c = _mm_mul_ps(v4_a2, v4_b2);
v4_sum = _mm_add_ps(v4_sum, v4_c);
__m128 v4_a3 = _mm_load1_ps(pA + 3);
__m128 v4_b3 = _mm_loadu_ps(pB + 12);
v4_c = _mm_mul_ps(v4_a3, v4_b3);
v4_sum = _mm_add_ps(v4_sum, v4_c);
pA += 4;
pB += 16;
}
for (; k < K; k++)
{
__m128 v4_a0 = _mm_load1_ps(pA);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum = _mm_add_ps(v4_sum, v4_c);
pA += 1;
pB += 4;
}
_mm_storeu_ps(C + m*N + n, v4_sum);
}
}
return;
}
This runs in 108 ms.
Let's see whether even more unrolling helps: compared with v8, this version additionally unrolls m in steps of 2.
void gemm_v2_MNK_SSE_UNROLL2_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
#define CAL_ROWX(x) \
v4_c = _mm_mul_ps(v4_a0, v4_b0); \
v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
v4_c = _mm_mul_ps(v4_a1, v4_b1); \
v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
v4_c = _mm_mul_ps(v4_a2, v4_b2); \
v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
v4_c = _mm_mul_ps(v4_a3, v4_b3); \
v4_sum##x = _mm_add_ps(v4_sum##x, v4_c);
assert(0 == N % 4);
int m = 0;
#ifdef _OPENMP
omp_set_num_threads(4);
#pragma omp parallel for lastprivate(m)
#endif
for (m = 0; m < M - 1; m += 2)
{
for (int n = 0; n < N; n += 4)
{
__m128 v4_sum0 = _mm_set1_ps(0.0f);
__m128 v4_sum1 = v4_sum0;
const float* pA0 = A + m*K;
const float* pA1 = A + m*K + K;
const float* pB = B + n*K;
int k;
for (k = 0; k < K - 3; k += 4)
{
__m128 v4_c;
// row0
__m128 v4_a0 = _mm_load1_ps(pA0);
__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
__m128 v4_a3 = _mm_load1_ps(pA0 + 3);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_b1 = _mm_loadu_ps(pB + 4);
__m128 v4_b2 = _mm_loadu_ps(pB + 8);
__m128 v4_b3 = _mm_loadu_ps(pB + 12);
CAL_ROWX(0)
// row1
v4_a0 = _mm_load1_ps(pA1);
v4_a1 = _mm_load1_ps(pA1 + 1);
v4_a2 = _mm_load1_ps(pA1 + 2);
v4_a3 = _mm_load1_ps(pA1 + 3);
CAL_ROWX(1)
pA0 += 4;
pA1 += 4;
pB += 16;
}
for (; k < K; k++)
{
__m128 v4_a0 = _mm_load1_ps(pA0);
__m128 v4_a1 = _mm_load1_ps(pA1);
__m128 v4_b0 = _mm_loadu_ps(pB);
// row0
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum0 = _mm_add_ps(v4_sum0, v4_c);
// row1
v4_c = _mm_mul_ps(v4_a1, v4_b0);
v4_sum1 = _mm_add_ps(v4_sum1, v4_c);
pA0++;
pA1++;
pB += 4;
}
_mm_storeu_ps(C + m*N + n, v4_sum0);
_mm_storeu_ps(C + m*N + N + n, v4_sum1);
}
}
// m now equals M (M even) or M-1 (M odd); handle the leftover row when M is odd
for (; m < M; m++)
{
for (int n = 0; n < N; n += 4)
{
__m128 v4_sum0 = _mm_set1_ps(0.0f);
__m128 v4_c;
const float* pA0 = A + m*K;
const float* pB = B + n*K;
int k;
for (k = 0; k < K - 3; k += 4)
{
// row0
__m128 v4_a0 = _mm_load1_ps(pA0);
__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
__m128 v4_a3 = _mm_load1_ps(pA0 + 3);
__m128 v4_b0 = _mm_loadu_ps(pB);
__m128 v4_b1 = _mm_loadu_ps(pB + 4);
__m128 v4_b2 = _mm_loadu_ps(pB + 8);
__m128 v4_b3 = _mm_loadu_ps(pB + 12);
CAL_ROWX(0)
pA0 += 4;
pB += 16;
}
for (; k < K; k++)
{
__m128 v4_a0 = _mm_load1_ps(pA0);
__m128 v4_b0 = _mm_loadu_ps(pB);
// row0
__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
v4_sum0 = _mm_add_ps(v4_sum0, v4_c);
pA0++;
pB += 4;
}
_mm_storeu_ps(C + m*N + n, v4_sum0);
}
}
return;
}
Note the OpenMP lastprivate clause on the loop variable m. Unfortunately the performance actually got worse: 150 ms. In most cases, a slowdown caused by too much unrolling comes down to register spilling, which can be verified by reading the disassembly.
So how do you know how much unrolling is right? Since the unrolled logic is mostly repeated and only the data differs, the variable names can be parameterized and the repeated code captured in macros like CAL_ROWX(x) in this version; even the variable declarations can be macros, taking the row or column index as a parameter.
If the code is factored well enough, unrolling a few more times is just invoking the macros a few more times. On top of that you can write a small program that auto-generates optimized code, emitting many implementations with different unroll combinations and exhaustively picking the best one (although the winner usually isn't much faster than the runner-up).
The downsides: readability drops and single-step debugging gets harder (you cannot step through logic inside a macro). That's why I was lazy and did not write v9 in the fully parameterized form; the goal here is to explain the principles. A rough sketch of what such parameterization could look like follows.
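A minimal, illustrative sketch of that kind of parameterization, assuming the same m/n/k loop structure and vec4-transposed B layout as v9 (the macro names DECL_ROWX, MUL_ADD_ROWX, and STORE_ROWX are hypothetical, not from the code above):

// Per-row accumulator and A-row pointer, indexed by the row offset x.
#define DECL_ROWX(x) \
    __m128 v4_sum##x = _mm_setzero_ps(); \
    const float *pA##x = A + (m + (x)) * K;
// One multiply-accumulate per k step; v4_b holds the current 1x4 packed slice of B.
#define MUL_ADD_ROWX(x) \
    v4_sum##x = _mm_add_ps(v4_sum##x, _mm_mul_ps(_mm_load1_ps(pA##x + k), v4_b));
// Write the finished 1x4 result back to row m+x of C.
#define STORE_ROWX(x) \
    _mm_storeu_ps(C + (m + (x)) * N + n, v4_sum##x);

With macros like these, unrolling m by 4 instead of 2 is just four invocations of each macro inside the same loops, which is also easy for a small generator script to emit.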
There is also an eleventh version: simply calling OpenBLAS's GEMM. It turns out to be another 2-3x faster than our fastest version, at roughly 40-50 ms. We did no assembly-level tuning here, and the blocked multiplication never got merged into the final version. For reference, the comparison call itself is sketched below.
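A minimal sketch of that call, assuming row-major float matrices and a build linked against OpenBLAS (e.g. -lopenblas); the wrapper name gemm_openblas is my own:

#include <cblas.h>

// C = 1.0f * A * B + 0.0f * C for row-major MxK, KxN, MxN float matrices.
// Leading dimensions: lda = K, ldb = N, ldc = N.
void gemm_openblas(const float *A, const float *B, float *C, int M, int N, int K)
{
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K,
                1.0f, A, K,
                B, N,
                0.0f, C, N);
}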
This article walked through a series of optimizations of matrix multiplication using fairly basic techniques. I may do some finer-grained optimization on top of it later. To study production-quality optimized code, look at the open-source deep-learning inference libraries released by the major vendors.
Complete code
Note: change the values of the select array in main.cpp to choose which GEMM versions to benchmark.
My understanding is limited; corrections of anything I got wrong are welcome so we can improve together.