Ten Implementations of General Matrix Multiplication (x86)

Preface

This post optimizes matrix multiplication on an Intel platform, mainly by adjusting the memory layout (for cache friendliness), using SIMD (SSE), and multi-threading. The A, B, and C matrices are M×K, K×N, and M×N respectively. All performance numbers in this post are averages over T iterations with M=N=K=1024; the complete code is linked at the end. The defining formula is given below so you can follow it while reading the code.

$C_{m,n}=\sum_{k=1}^{K}A_{m,k} \cdot B_{k,n}$
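All timings below were gathered with a simple averaged loop. The sketch below shows roughly how such a harness can be set up; the matrix sizes match the post, but the iteration count T, the random initialization, and the use of `clock()` are my assumptions (for the multi-threaded versions a wall-clock timer such as `omp_get_wtime()` would be more appropriate).

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

// Forward declaration of the variant under test (defined later in this post).
void gemm_v0_MNK(const float *A, const float *B, float *C, int M, int N, int K);

int main(void)
{
	const int M = 1024, N = 1024, K = 1024, T = 10; // T is an assumed iteration count
	float *A = (float*)malloc(sizeof(float) * M * K);
	float *B = (float*)malloc(sizeof(float) * K * N);
	float *C = (float*)malloc(sizeof(float) * M * N);
	for (int i = 0; i < M * K; i++) A[i] = (float)rand() / RAND_MAX;
	for (int i = 0; i < K * N; i++) B[i] = (float)rand() / RAND_MAX;

	clock_t start = clock();
	for (int t = 0; t < T; t++)
		gemm_v0_MNK(A, B, C, M, N, K);
	double avg_ms = 1000.0 * (double)(clock() - start) / CLOCKS_PER_SEC / T;
	printf("average: %.1f ms\n", avg_ms);

	free(A); free(B); free(C);
	return 0;
}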

v0. Implementing the definition literally

This follows the mathematical definition exactly, with no optimization at all, and serves as the starting point for the step-by-step optimizations. Note that the MNK suffix of the function name indicates the nesting order of the three loops.

void gemm_v0_MNK(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n++)
		{
			for (int k = 0; k < K; k++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
			}
		}
	}
	return;
}

On my i5-7400 this takes 4289 ms.

v1. Reordering the loops

In v0, the innermost loop walks the B matrix column-wise, which easily causes cache misses when B is wide. Reorder the loops so that B is accessed row-wise instead.

void gemm_v1_MKN(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m++)
	{
		for (int k = 0; k < K; k++)
		{
			float a = A[m*K + k];
			for (int n = 0; n < N; n++)
			{
				C[m*N + n] += a* B[k*N + n];
			}
		}
	}
	return;
}

After the reorder, the innermost loop multiplies a single value of A by one row of B and accumulates into the corresponding row of C. Note that each position of C is now written back K times. This version takes 1790 ms.

v2. Transposing the B matrix

v1 still writes each element of C back many times. Take a different approach: transpose B, so what used to be reading a column of B becomes reading a row. The loop order goes back to MNK.

void transpose(const float *A, float *B, int M, int N)
{
	for (int n = 0; n < N; n++)
	{
		for (int m = 0; m < M; m++)
		{
			B[n*M + m] = A[N*m + n];
		}
	}
}

void gemm_v2_MNK_transposeB(const float *A, const float *B, float *C, int M, int N, int K)
{
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n++)
		{
			float sum = 0.0f;
			for (int k = 0; k < K; k++)
			{
				sum += A[m*K + k] * B[n*K + k];
			}
			C[m*N + n] = sum;
		}
	}
	return;
}

In this version, the transpose plus the multiply-accumulate together take 1620 ms.

v3. Blocked matrix multiplication

Again, the goal is a better cache hit rate. We split the matrices into small blocks, each small enough to fit in the L1 cache. The best block size depends on the machine's cache sizes and even on the matrix dimensions; you can refer to this blog post. In fact, when a problem is too hard to model precisely enough to derive the optimum under all its constraints, brute-forcing the parameter (here, the choice of block size) is a common optimization approach in its own right; a small sweep sketch follows the timing note below.

inline void do_block(const float *A, const float *B, float *C, int K, int N, int BLOCKSIZE)
{
	
	for (int m = 0; m < BLOCKSIZE; m++)
	{
		for (int n = 0; n < BLOCKSIZE; n++)
		{
			float c = C[m*N + n];
			for (int k = 0; k < BLOCKSIZE; k++)
				c += A[m*K + k] * B[k*N + n];
			C[m*N + n] = c;
		}
	}
}


// Blocked matrix multiplication
void dgemm_block(const float *A, const float *B, float *C, int M, int N, int K)
{
	const int BLOCKSIZE = 64;
	memset(C, 0, M*N*sizeof(float));
	for (int m = 0; m < M; m += BLOCKSIZE)
	{
		for (int n = 0; n < N; n += BLOCKSIZE)
		{
			for (int k = 0; k < K; k += BLOCKSIZE)
			{
				do_block(A + m*K + k, B + k*N + n, C + m*N + n, K, N, BLOCKSIZE);
			}
		}
	}
	return;

}

This version runs in 1633 ms. Everything up to this point optimizes memory access in the hope of raising the cache hit rate. The code above is still not very efficient, though: indices are recomputed over and over, there is no loop unrolling, and so on. Later versions will take care of these details. The blocked multiplication could also be written without a helper function to avoid the call overhead.
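As mentioned above, the block size can simply be brute-forced. Below is a minimal sketch of such a sweep; `dgemm_block_bs` (a variant of `dgemm_block` that takes the block size as a parameter) and the wall-clock helper `wall_ms()` are assumptions, not code from this post.

// Hypothetical variant of dgemm_block that takes the block size as a parameter,
// plus an assumed wall-clock timer returning milliseconds.
void dgemm_block_bs(const float *A, const float *B, float *C,
                    int M, int N, int K, int blocksize);
double wall_ms(void);

// Brute-force the fastest block size from a list of candidates.
int pick_blocksize(const float *A, const float *B, float *C, int M, int N, int K)
{
	const int candidates[] = { 16, 32, 64, 128, 256 };
	const int count = sizeof(candidates) / sizeof(candidates[0]);
	int best = candidates[0];
	double best_ms = 1e30;
	for (int i = 0; i < count; i++)
	{
		double t0 = wall_ms();
		dgemm_block_bs(A, B, C, M, N, K, candidates[i]);
		double dt = wall_ms() - t0;
		if (dt < best_ms) { best_ms = dt; best = candidates[i]; }
	}
	return best;
}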

v4. A first pass with SSE

In v1's innermost loop, one value of A is multiplied by an entire row of B, which maps quite naturally onto vectorization. See the code and comments for details.

void gemm_v1_MKN_SSE(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	int m, n, k;
	for (m = 0; m < M; m++)
	{
		for (k = 0; k < K; k++)
		{
			__m128 v4_a = _mm_set1_ps(*(A + m*K + k));// Am,k Am,k Am,k Am,k
			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b = _mm_loadu_ps(B + k*N + n); // Bk,n Bk,n+1 Bk,n+2 Bk,n+3
				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a, v4_b)));
			}
			for (; n < N; n++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
			}
		}
	}
	return;
}

Performance: 794 ms, roughly a 2x improvement over v1 (1790 ms). Even a straightforward application of SSE brings a considerable speedup.

v5. Loop unrolling

To let the compiler schedule software pipelining better, to reduce the number of loop-condition checks, and to cut down on writes back to C, we modify v4 so that each inner iteration processes 4 values of k (i.e. 4 rows of B) at a time.

void gemm_v1_MKN_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
	memset(C, 0, M*N*sizeof(float));
	int m, n, k;
	for (m = 0; m < M; m++)
	{
		for (k = 0; k < K - 3; k += 4)
		{
			__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));
			__m128 v4_a1 = _mm_set1_ps(*(A + m*K + k + 1));
			__m128 v4_a2 = _mm_set1_ps(*(A + m*K + k + 2));
			__m128 v4_a3 = _mm_set1_ps(*(A + m*K + k + 3));
			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b0 = _mm_loadu_ps(B + k*N + n);
				__m128 v4_b1 = _mm_loadu_ps(B + k*N + n + N);
				__m128 v4_b2 = _mm_loadu_ps(B + k*N + n + 2 * N);
				__m128 v4_b3 = _mm_loadu_ps(B + k*N + n + 3 * N);

				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b0));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a1, v4_b1));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a2, v4_b2));
				v4_c = _mm_add_ps(v4_c, _mm_mul_ps(v4_a3, v4_b3));
				_mm_storeu_ps(C + m*N + n, v4_c);
			}
			for (; n < N; n++)
			{
				C[m*N + n] += A[m*K + k] * B[k*N + n];
				C[m*N + n] += A[m*K + k + 1] * B[(k + 1)*N + n];
				C[m*N + n] += A[m*K + k + 2] * B[(k + 2)*N + n];
				C[m*N + n] += A[m*K + k + 3] * B[(k + 3)*N + n];
			}
		}
		for (; k < K; k++)
		{
			__m128 v4_a0 = _mm_set1_ps(*(A + m*K + k));

			for (n = 0; n < N - 3; n += 4)
			{
				__m128 v4_b = _mm_loadu_ps(B + k*N + n);
				__m128 v4_c = _mm_loadu_ps(C + m*N + n);
				_mm_storeu_ps(C + m*N + n, _mm_add_ps(v4_c, _mm_mul_ps(v4_a0, v4_b)));
			}

			float a = A[m*K + k];
			for (; n < N; n++)
			{
				C[m*N + n] += a* B[k*N + n];
			}
		}
	}
	return;
}

Note that the implementation has to handle the case where K (the width of A) is not a multiple of 4. The number of `_mm_storeu_ps` instructions drops to 1/4 of what it was. **This version runs in 463 ms**, another 2x improvement. Let's keep going.

v6. SSE + unrolling on top of v2 (transposed B)

Although v5 already brings a decent speedup, as analyzed earlier this computation pattern writes C back many times. Let's see what SSE does on top of the transposed version.

void gemm_v2_MNK_SSE_UNROLL(const float *A, const float *B, float *C, int M, int N, int K)
{
	int k = 0, n = 0;
	__m128 v4_1_ps = _mm_set1_ps(1.0f);
	__m128 v4_sum_tmp_ps, v4_sumv_tmp_ps;
	for (int m = 0; m < M; m++)
	{
		for (n = 0; n < N - 3; n += 4)
		{
			float sum0, sum1, sum2, sum3;
			__m128 v4_sum0 = _mm_setzero_ps();
			__m128 v4_sum1 = _mm_setzero_ps();
			__m128 v4_sum2 = _mm_setzero_ps();
			__m128 v4_sum3 = _mm_setzero_ps();

			sum0 = sum1 = sum2 = sum3 = 0.0f;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 a = _mm_loadu_ps(A + m*K + k);

				__m128 b0 = _mm_loadu_ps(B + n*K + k);
				__m128 b1 = _mm_loadu_ps(B + n*K + k + K);
				__m128 b2 = _mm_loadu_ps(B + n*K + k + 2 * K);
				__m128 b3 = _mm_loadu_ps(B + n*K + k + 3 * K);

				v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
				v4_sum1 = _mm_add_ps(v4_sum1, _mm_mul_ps(a, b1));
				v4_sum2 = _mm_add_ps(v4_sum2, _mm_mul_ps(a, b2));
				v4_sum3 = _mm_add_ps(v4_sum3, _mm_mul_ps(a, b3));
			}
			for (; k < K; k++)
			{
				sum0 += A[m*K + k] * B[n*K + k];
				sum1 += A[m*K + k] * B[n*K + k + K];
				sum2 += A[m*K + k] * B[n*K + k + 2 * K];
				sum3 += A[m*K + k] * B[n*K + k + 3 * K];
			}
			v4_sum_tmp_ps = _mm_setr_ps(sum0, sum1, sum2, sum3);

			//v4_sumv_tmp_ps.m128_f32[0] = v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3];
			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum0, v4_1_ps, 0xF1);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum1, v4_1_ps, 0xF2);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum2, v4_1_ps, 0xF4);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			v4_sumv_tmp_ps = _mm_dp_ps(v4_sum3, v4_1_ps, 0xF8);
			v4_sum_tmp_ps = _mm_add_ps(v4_sum_tmp_ps, v4_sumv_tmp_ps);

			_mm_storeu_ps(C + m*N + n, v4_sum_tmp_ps);
		}//end for n=0~N-3
		for (; n < N; n++)
		{
			float sum0;
			__m128 v4_sum0 = _mm_setzero_ps();
			sum0 = 0.0f;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 a = _mm_loadu_ps(A + m*K + k);
				__m128 b0 = _mm_loadu_ps(B + n*K + k);
				v4_sum0 = _mm_add_ps(v4_sum0, _mm_mul_ps(a, b0));
			}
			for (; k < K; k++)
			{
				sum0 += A[m*K + k] * B[n*K + k];
			}
			C[m*N + n] = sum0 + v4_sum0.m128_f32[0] + v4_sum0.m128_f32[1] + v4_sum0.m128_f32[2] + v4_sum0.m128_f32[3];
		}//end for n=N-3~N
	}// end for m
	return;
}

Performance: 451 ms. One awkward part of this code is that after the k loop, the four lanes of each accumulator vector have to be summed into a single value. This horizontal reduction is inefficient (possibly even more noticeably so on DSP platforms), and most SIMD-optimized code tries to avoid it. Here it is implemented with the dot-product instruction `_mm_dp_ps`: the vector is multiplied by 1.0 and the sum is accumulated into the lane selected by the mask.
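For reference, a more common way to write that horizontal reduction is with the SSE3 `_mm_hadd_ps` instruction; a small sketch (not used in the code above) is shown below.

#include <pmmintrin.h> // SSE3, for _mm_hadd_ps

// Sum the four lanes of an __m128 into a single float.
static inline float hsum_ps(__m128 v)
{
	__m128 t = _mm_hadd_ps(v, v); // (x0+x1, x2+x3, x0+x1, x2+x3)
	t = _mm_hadd_ps(t, t);        // every lane now holds x0+x1+x2+x3
	return _mm_cvtss_f32(t);
}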

v7. Transposing B in 1x4 blocks

Before this optimization, let's first redefine the transpose: the basic unit being transposed is a 1x4 vector rather than a single element.
For example, in the layout below each number stands for a 1x4 vector. Suppose the original matrix is

1 5 
2 6
3 7
4 8

After the transpose it becomes

1 2 3 4 
5 6 7 8

but within each numbered vector, the order of the four elements stays the same as in the original matrix. With this memory layout, multiplying a single value of A by one "row" of the packed B yields four results at once, and compared with v6, the horizontal (in-vector) addition is eliminated.
Another way to look at it: the multiplication processes four columns of B at a time (those four columns form a 1x4 vector that is multiplied by the same single value of A), so we simply transpose with those four columns as the unit. Obviously, this transpose requires N to be a multiple of 4.

// Transpose in 1x4 vector units; note how the dimensions change:
// M*N -> (N/4) * (4M)
void transpose_vec4(const float *A, float *B, int M, int N)
{
	int m, n;
	for (m = 0; m < M; m++)
	{
		for (n = 0; n < N; n += 4)
		{
			__m128 a = _mm_loadu_ps(A + m*N + n);
			_mm_storeu_ps(B + n*M + (m << 2), a);
		}
	}
}

// GEMM with B transposed in 1x4 vector units
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4(const float *A, const float *B, float *C, int M, int N, int K)
{
	assert(0 == N % 4);
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum = _mm_set1_ps(0.0f);
			const float* pA = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);
				__m128 v4_a1 = _mm_load1_ps(pA + 1);
				__m128 v4_a2 = _mm_load1_ps(pA + 2);
				__m128 v4_a3 = _mm_load1_ps(pA + 3);

				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a1, v4_b1);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a2, v4_b2);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				v4_c = _mm_mul_ps(v4_a3, v4_b3);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 1;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum);
		}
	}
	return;
}

We unrolled the k loop to process four values of k per iteration. Final performance: 449 ms.

v8. Multi-threading with OpenMP

Simply add OpenMP on top of the v7 version.

// GEMM with B transposed in 1x4 vector units + OpenMP
void gemm_v2_MNK_SSE_UNROLL_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
	assert(0 == N % 4);
#ifdef _OPENMP 
	omp_set_num_threads(4);
#pragma omp parallel for 
#endif
	for (int m = 0; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum = _mm_set1_ps(0.0f);
			const float* pA = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{

				__m128 v4_a0 = _mm_load1_ps(pA);
				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				__m128 v4_a1 = _mm_load1_ps(pA + 1);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				v4_c = _mm_mul_ps(v4_a1, v4_b1);
				v4_sum = _mm_add_ps(v4_sum, v4_c);


				__m128 v4_a2 = _mm_load1_ps(pA + 2);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);

				v4_c = _mm_mul_ps(v4_a2, v4_b2);
				v4_sum = _mm_add_ps(v4_sum, v4_c);
				
				__m128 v4_a3 = _mm_load1_ps(pA + 3);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);
				v4_c = _mm_mul_ps(v4_a3, v4_b3);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum = _mm_add_ps(v4_sum, v4_c);

				pA += 1;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum);
		}
	}
	return;
}

Performance: 108 ms.

v9. More unrolling

Let's see whether even more unrolling can improve performance further. Compared with v8, this version additionally unrolls m by 2.

void gemm_v2_MNK_SSE_UNROLL2_TRANSPOSEV4_OMP(const float *A, const float *B, float *C, int M, int N, int K)
{
#define CAL_ROWX(x) \
	v4_c = _mm_mul_ps(v4_a0, v4_b0); \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a1, v4_b1);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a2, v4_b2);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c); \
	v4_c = _mm_mul_ps(v4_a3, v4_b3);	 \
	v4_sum##x = _mm_add_ps(v4_sum##x, v4_c);

	assert(0 == N % 4);
	int m = 0;
#ifdef _OPENMP 
	omp_set_num_threads(4);
#pragma omp parallel for lastprivate(m)
#endif
	for (m = 0; m < M - 1; m += 2)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum0 = _mm_set1_ps(0.0f);
			__m128 v4_sum1 = v4_sum0;
			const float* pA0 = A + m*K;
			const float* pA1 = A + m*K + K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				__m128 v4_c;
				// row0
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
				__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
				__m128 v4_a3 = _mm_load1_ps(pA0 + 3);


				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				CAL_ROWX(0)

				// row1
				v4_a0 = _mm_load1_ps(pA1);
				v4_a1 = _mm_load1_ps(pA1 + 1);
				v4_a2 = _mm_load1_ps(pA1 + 2);
				v4_a3 = _mm_load1_ps(pA1 + 3);

				CAL_ROWX(1)

				pA0 += 4;
				pA1 += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA1);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				// row0
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum0 = _mm_add_ps(v4_sum0, v4_c);

				// row1
				v4_c = _mm_mul_ps(v4_a1, v4_b0);
				v4_sum1 = _mm_add_ps(v4_sum1, v4_c);

				pA0++;
				pA1++;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum0);
			_mm_storeu_ps(C + m*N + N + n, v4_sum1);
		}
	}

	// m == (M & ~1): handle the remaining row when M is odd
	for (; m < M; m++)
	{
		for (int n = 0; n < N; n += 4)
		{
			__m128 v4_sum0 = _mm_set1_ps(0.0f);
			__m128 v4_c;
			const float* pA0 = A + m*K;
			const float* pB = B + n*K;
			int k;
			for (k = 0; k < K - 3; k += 4)
			{
				// row0
				__m128 v4_a0 = _mm_load1_ps(pA0);
				__m128 v4_a1 = _mm_load1_ps(pA0 + 1);
				__m128 v4_a2 = _mm_load1_ps(pA0 + 2);
				__m128 v4_a3 = _mm_load1_ps(pA0 + 3);


				__m128 v4_b0 = _mm_loadu_ps(pB);
				__m128 v4_b1 = _mm_loadu_ps(pB + 4);
				__m128 v4_b2 = _mm_loadu_ps(pB + 8);
				__m128 v4_b3 = _mm_loadu_ps(pB + 12);

				CAL_ROWX(0)

				pA0 += 4;
				pB += 16;
			}
			for (; k < K; k++)
			{
				__m128 v4_a0 = _mm_load1_ps(pA0);

				__m128 v4_b0 = _mm_loadu_ps(pB);

				// row0
				__m128 v4_c = _mm_mul_ps(v4_a0, v4_b0);
				v4_sum0 = _mm_add_ps(v4_sum0, v4_c);

				pA0++;
				pB += 4;
			}
			_mm_storeu_ps(C + m*N + n, v4_sum0);
		}
	}

	return;
}

Note the private (lastprivate) declaration of m in the OpenMP pragma. Unfortunately, performance got worse this time: 150 ms. In most cases, over-unrolling hurts because of register spilling, which you can verify by inspecting the disassembly.
So how do you know how much unrolling is appropriate? Since unrolled code mostly repeats the same logic on different data, you can parameterize the variable names and capture the repeated code in macros like CAL_ROWX(x) in this version; even the variable declarations can be generated by macros whose arguments are the row/column indices.
If the macros are written well enough, changing the unroll factor just means invoking the macros a different number of times. On top of that you can write a program that automatically generates implementations for many unrolling combinations and brute-forces which one is fastest (though the chosen optimum is usually not much better than the runner-up).
The downside: code readability suffers and single-step debugging becomes harder (you cannot step through logic inside a macro). So for v9 I didn't bother writing the fully parameterized version; the goal here is to explain the ideas.
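As a rough illustration of that idea (these macro names are hypothetical, not taken from the code above), even the declarations can be parameterized so that changing the unroll factor just means invoking the macros a different number of times:

// Hypothetical, fully parameterized unroll macros: one accumulator, one A
// broadcast and one multiply-add per row index r. Unrolling m by R rows then
// means writing (or generating) each macro R times with r = 0..R-1.
#define DECL_ROW(r)  __m128 v4_sum##r = _mm_setzero_ps(); \
                     const float *pA##r = A + (m + (r)) * K;
#define LOAD_A(r)    __m128 v4_a##r = _mm_load1_ps(pA##r);
#define MAC_ROW(r)   v4_sum##r = _mm_add_ps(v4_sum##r, _mm_mul_ps(v4_a##r, v4_b0));
#define STORE_ROW(r) _mm_storeu_ps(C + (m + (r)) * N + n, v4_sum##r);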

v10. OpenBLAS GEMM

There is one more version: simply calling OpenBLAS's GEMM. It turns out to be about 2-3x faster than our fastest version, roughly 40-50 ms. We did no assembly-level tuning here, and the blocked multiplication was never merged into the final version either.
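For reference, calling it through the CBLAS interface looks roughly like this (row-major, single precision, alpha=1 and beta=0; the wrapper name is mine):

#include <cblas.h>

// C = A * B with row-major single-precision matrices via OpenBLAS.
// A is MxK (lda = K), B is KxN (ldb = N), C is MxN (ldc = N).
void gemm_openblas(const float *A, const float *B, float *C, int M, int N, int K)
{
	cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
	            M, N, K,
	            1.0f, A, K,
	            B, N,
	            0.0f, C, N);
}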

Summary

This post applied a series of fairly basic optimization techniques to matrix multiplication. I may follow up with more fine-grained optimizations on top of it. If you want to study optimized code used in production, the open-source DL inference libraries from the major vendors are a good reference.

Code

Complete code
Note: modify the values in the `select` array in main.cpp to choose which GEMM versions to benchmark.

My knowledge here is limited; if anything above is mistaken, corrections are welcome.
