SIMD加速矩阵运算

一、SIMD指令简介

  • SIMD的全称叫做,单指令集多数据(Single Instruction Multiple Data)。最直观的理解就是,向量计算。比如一个加法指令周期只能算一组数(一维向量相加),使用SIMD的话,一个加法指令周期可以同时算多组数(n维向量相加),二者用时基本相等,极大地提高了运算效率。
  • SIMD (Single Instruction Multiple Data)指令集,指单指令多数据流技术,可用一组指令对多组数据通进行并行操作。SIMD指令可以在一个控制器上控制同时多个平行的处理微元,一次指令运算执行多个数据流,这样在很多时候可以提高程序的运算速度。SIMD指令在本质上非常类似一个向量处理器,可对控制器上的一组数据(又称“数据向量”) 同时分别执行相同的操作从而实现空间上的并行。SIMD是CPU实现DLP(Data Level Parallelism)的关键,DLP就是按照SIMD模式完成计算的。SSE和较早的MMX和 AMD的3DNow!都是SIMD指令集。它可以通过单指令多数据技术和单时钟周期并行处理多个浮点来有效地提高浮点运算速度
  • 可以使用CPU-Z程序查看自己设备的CPU支持哪些SIMD运算指令集。
    SIMD加速矩阵运算_第1张图片

二、核心代码

  • 矩阵的声明如下:
pragma once
#include //AVX(include wmmintrin.h)
#include 
#include "Vector3f.h"
#include "Myth.h"

class Matrix4
{
public:

	Matrix4();
	Matrix4(const std::initializer_list<float>& list);
	~Matrix4();

	Matrix4 operator+(const Matrix4& right)const;
	Matrix4 operator-(const Matrix4& right)const;
	Matrix4 operator*(const Matrix4& right)const;
	Vector3f operator*(const Vector3f& v)const;
	Matrix4 operator*(float k)const;
	Matrix4 operator-()const;

	void Identity();
	Vector3f MultiplyVector3(const Vector3f& v) const;
	Matrix4 transpose()const;
	void Print();

public:
	union
	{
		__m256 m[2];
		float data[16];
		float ptr[4][4];
	};
};
  • 我们要实现矩阵和矩阵加减运算、矩阵和矩阵乘法运算、矩阵和向量乘法运算、矩阵和常数乘法运算的加速。

2.1 矩阵的构造

  • 使用__m256来存储矩阵数据节省了每次运算时将数据加载到__m256变量所需的时间,但也因此带来获取矩阵元素的不变,对此我们使用union共用体来解决这个问题。
  • 使用union共用体定义的变量共用同一块内存区域,由于m、data和ptr所需内存字节数相同,因此无论使用哪一种方式索引数据都会得到正确的结果。
  • 注意当使用容器来存储此矩阵时,可能发生出乎意料的结果,比如 vector< Matrix4 > 会在push_back时改变容器中所有元素的值为新添加元素的值,原因未知。

2.2 矩阵和矩阵加减法运算

  • 重载矩阵和矩阵之间的加法运算符,我们使用_mm256_add_ps函数一次计算8个float变量的求和,因此两次调用函数花费两次时钟周期即可完成矩阵加法的运算。
Matrix4 Matrix4::operator + (const Matrix4& right) const
{
	Matrix4 res;
	
	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_add_ps(m[i], right.m[i]);
	return res;
}
  • 对于矩阵间减法运算,和加法模板相同,将加法函数 _mm256_add_ps 改为减法函数 _mm256_sub_ps即可。
Matrix4 Matrix4::operator - (const Matrix4& right) const
{
	Matrix4 res;

	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_sub_ps(m[i], right.m[i]);
	return res;
}

2.3 矩阵和常数相乘

  • 对于矩阵和常数的乘法,我们首先使用_mm256_set_ps函数构造一个分量都为k的8维向量,使用乘法函数 _mm256_mul_ps令矩阵和8维向量对应相乘即可。
  • 当我们使用SIMD时会涉及很多运算以意外的操作,比如常见的我们需要构造SIMD类型的变量,这时加载数据就需要用额外的时间,因此SIMD不能达到理论上超过普通运算几倍的速度。
Matrix4 Matrix4::operator*(float k)const
{
	Matrix4 res;
	__m256 mt = _mm256_set_ps(k, k, k, k, k, k, k, k);
	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_mul_ps(m[i], mt);
	return res;
}

2.4 矩阵和向量相乘

  • 对于矩阵和4维向量相乘,我们首先构造一个8维向量来存储向量4维向量,这里的四维向量是一个齐次坐标即其本质是空间中的一个三维向量。我们将8维向量的分量依次设置为xyzwxyzw,这里的xyzw指4维向量的分量,即将4维向量按顺序重复的平铺到8维向量中。
  • 我们使用_mm256dp_ps函数进行乘法并求和运算。_mm256dp_ps函数的前两个参数为8维向量__m256,第三个参数指定运算的规则。此函数将每个__m256变量分为前后两个部分,每个部分占4个float字节,函数依据设置的运算规对两个__m256变量的前后两部分分别实现相乘并求和。0b11110001中前四位1111表示要将两个__m256的前部分各个float对应相乘求和,后部分同样如此。而后四位0001表示将前4个float相乘求和的结果存储到返回值temp[0],将后4个float相乘求和的结果存储到返回值temp[4]中。
  • 现在你也许理解了我们为什么要将4维向量平铺到__m256中了,这样一次我们可以计算出矩阵的两行和向量列的乘法并求和,因此我们调用两次_mm256_dp_ps即可完成所有运算。
Vector3f Matrix4::operator*(const Vector3f& v)const
{
	Vector3f res;
	__declspec(align(16))	__m256 temp;
	__declspec(align(16))	__m256 mt = _mm256_set_ps(v.x, v.y, v.z, v.w, v.x, v.y, v.z, v.w);

	temp = _mm256_dp_ps(m[0], mt, 0b11110001);
	res.x = temp.m256_f32[0];
	res.y = temp.m256_f32[4];

	temp = _mm256_dp_ps(m[1], mt, 0b11110001);
	res.z = temp.m256_f32[0];
	res.w = temp.m256_f32[4];
	
	return res;
}

2.5 矩阵和矩阵相乘

  • 矩阵乘法是在计算机图形学中使用最多的,无论任何物体需要渲染,它的每一个顶点都需要进行数次的矩阵相乘运算,在现代游戏中需要渲染的物体成千上万,顶点更是数不胜数,因此矩阵相乘的速度很大程度上决定了渲染的速度。
  • __declspec(align(16))可以保证字节对齐,建立在定义任何SIMD变量时进行使用。下文中的gather都是取出矩阵的固定位置元素。gatherA12表示取出左矩阵的第一行和第二行,要注意_mm256_set_epi32中索引的顺序是逆序的。
  • 下列代码中gatherA12取出矩阵M的第7, 6, 5, 4, 3, 2, 1, 0个元素,因为是逆序因此其返回的是M(0),M(1)…,M(7)。我们为定义矩阵的union共用体中包含一个一维数组data,直接使用其作为参数即可。因此 __m256 a12 = _mm256_i32gather_ps(this->data, gatherA12, sizeof(float)); 表示从一维数组data中按照gatherA12索引取出元素,每个元素的大小为sizeof(float),这样可以构造的构造矩阵行列对应的SIMD变量。矩阵乘法就是行列相乘求和,因此如矩阵和向量相乘一样,调用 _mm256_dp_ps 将构造的行列相乘求和即可。
__declspec(align(16)) __m256i gatherA12 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
__declspec(align(16)) __m256i gatherA34 = _mm256_set_epi32(15, 14, 13, 12, 11, 10, 9, 8);

__declspec(align(16)) __m256i gatherB11 = _mm256_set_epi32(12, 8, 4, 0, 12, 8, 4, 0);
__declspec(align(16)) __m256i gatherB22 = _mm256_set_epi32(13, 9, 5, 1, 13, 9, 5, 1);
__declspec(align(16)) __m256i gatherB33 = _mm256_set_epi32(14, 10, 6, 2, 14, 10, 6, 2);
__declspec(align(16)) __m256i gatherB44 = _mm256_set_epi32(15, 11, 7, 3, 15, 11, 7, 3);

Matrix4 Matrix4::operator*(const Matrix4& right)const
{
	Matrix4 ret;
	__declspec(align(16)) __m256 temp;
	__declspec(align(16)) __m256 a12, a34;
	__declspec(align(16)) __m256 b11, b22, b33, b44;

	a12 = _mm256_i32gather_ps(this->data, gatherA12, sizeof(float));
	a34 = _mm256_i32gather_ps(this->data, gatherA34, sizeof(float));

	b11 = _mm256_i32gather_ps(right.data, gatherB11, sizeof(float));
	b22 = _mm256_i32gather_ps(right.data, gatherB22, sizeof(float));
	b33 = _mm256_i32gather_ps(right.data, gatherB33, sizeof(float));
	b44 = _mm256_i32gather_ps(right.data, gatherB44, sizeof(float));

	temp = _mm256_dp_ps(a12, b11, 0b11110001);
	ret.data[0] = temp.m256_f32[0];
	ret.data[4] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b11, 0b11110001);
	ret.data[8] = temp.m256_f32[0];
	ret.data[12] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b22, 0b11110001);
	ret.data[1] = temp.m256_f32[0];
	ret.data[5] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b22, 0b11110001);
	ret.data[9] = temp.m256_f32[0];
	ret.data[13] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b33, 0b11110001);
	ret.data[2] = temp.m256_f32[0];
	ret.data[6] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b33, 0b11110001);
	ret.data[10] = temp.m256_f32[0];
	ret.data[14] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b44, 0b11110001);
	ret.data[3] = temp.m256_f32[0];
	ret.data[7] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b44, 0b11110001);
	ret.data[11] = temp.m256_f32[0];
	ret.data[15] = temp.m256_f32[4];

	return ret;
}

三、完整代码

  • __m256可以存储256个字节,即8个float变量,我们使用__m256的二维数组即可表示一个4X4的浮点矩阵。使用到的指令集为AVX指令集,在头文件中。
  • 如何 #define SIMD 1 则启动SIMD矩阵加速运算,否则使用普通运算。
  • 矩阵的完整定义如下:
#pragma once
#include //AVX(include wmmintrin.h)
#include 
#include "Vector3f.h"
#include "Myth.h"

class Matrix4
{
public:

	Matrix4();
	Matrix4(const std::initializer_list<float>& list);
	~Matrix4();

	Matrix4 operator+(const Matrix4& right)const;
	Matrix4 operator-(const Matrix4& right)const;
	Matrix4 operator*(const Matrix4& right)const;
	Vector3f operator*(const Vector3f& v)const;
	Matrix4 operator*(float k)const;
	Matrix4 operator-()const;

	void Identity();
	Vector3f MultiplyVector3(const Vector3f& v) const;
	Matrix4 transpose()const;
	void Print();

public:
	union
	{
		__m256 m[2];
		float data[16];
		float ptr[4][4];
	};
};

  • 使用SIMD加速的矩阵的运算操作定义如下。
  • 当定义宏SIMD的值为1时使用SIMD加速运算,否则使用普通运算。
  • 矩阵的加减法直接使用AVX指令一次运算8个float变量,因此加减法的纯运算时间为两个时钟周期。如果您的硬件支持AVX5,那么可以一次运算16个float变量,即可再次缩短运算时间。
#include "Matrix.h"

Matrix4::Matrix4()
{
	Identity();
}
Matrix4::Matrix4(const std::initializer_list<float>& list)
{
	auto begin = list.begin();
	auto end = list.end();
	int i = 0;
	while (begin != end)
	{
		data[i++] = *begin;
		++begin;
	}
}
Matrix4::~Matrix4()
{
}

#if SIMD

Matrix4 Matrix4::operator + (const Matrix4& right) const
{
	Matrix4 res;
	
	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_add_ps(m[i], right.m[i]);
	return res;
}
Matrix4 Matrix4::operator - (const Matrix4& right) const
{
	Matrix4 res;

	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_sub_ps(m[i], right.m[i]);
	return res;
}
Matrix4 Matrix4::operator*(float k)const
{
	Matrix4 res;
	__m256 mt = _mm256_set_ps(k, k, k, k, k, k, k, k);
	for (int i = 0; i < 2; i++)
		res.m[i] = _mm256_mul_ps(m[i], mt);
	return res;
}
Vector3f Matrix4::operator*(const Vector3f& v)const
{
	Vector3f res;
	__declspec(align(16))	__m256 temp;
	__declspec(align(16))	__m256 mt = _mm256_set_ps(v.x, v.y, v.z, v.w, v.x, v.y, v.z, v.w);

	temp = _mm256_dp_ps(m[0], mt, 0b11110001);
	res.x = temp.m256_f32[0];
	res.y = temp.m256_f32[4];

	temp = _mm256_dp_ps(m[1], mt, 0b11110001);
	res.z = temp.m256_f32[0];
	res.w = temp.m256_f32[4];
	
	return res;
}


__declspec(align(16)) __m256i gatherA12 = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
__declspec(align(16)) __m256i gatherA34 = _mm256_set_epi32(15, 14, 13, 12, 11, 10, 9, 8);

__declspec(align(16)) __m256i gatherB11 = _mm256_set_epi32(12, 8, 4, 0, 12, 8, 4, 0);
__declspec(align(16)) __m256i gatherB22 = _mm256_set_epi32(13, 9, 5, 1, 13, 9, 5, 1);
__declspec(align(16)) __m256i gatherB33 = _mm256_set_epi32(14, 10, 6, 2, 14, 10, 6, 2);
__declspec(align(16)) __m256i gatherB44 = _mm256_set_epi32(15, 11, 7, 3, 15, 11, 7, 3);

Matrix4 Matrix4::operator*(const Matrix4& right)const
{
	Matrix4 ret;
	__declspec(align(16)) __m256 temp;
	__declspec(align(16)) __m256 a12, a34;
	__declspec(align(16)) __m256 b11, b22, b33, b44;

	a12 = _mm256_i32gather_ps(this->data, gatherA12, sizeof(float));
	a34 = _mm256_i32gather_ps(this->data, gatherA34, sizeof(float));

	b11 = _mm256_i32gather_ps(right.data, gatherB11, sizeof(float));
	b22 = _mm256_i32gather_ps(right.data, gatherB22, sizeof(float));
	b33 = _mm256_i32gather_ps(right.data, gatherB33, sizeof(float));
	b44 = _mm256_i32gather_ps(right.data, gatherB44, sizeof(float));

	temp = _mm256_dp_ps(a12, b11, 0b11110001);
	ret.data[0] = temp.m256_f32[0];
	ret.data[4] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b11, 0b11110001);
	ret.data[8] = temp.m256_f32[0];
	ret.data[12] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b22, 0b11110001);
	ret.data[1] = temp.m256_f32[0];
	ret.data[5] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b22, 0b11110001);
	ret.data[9] = temp.m256_f32[0];
	ret.data[13] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b33, 0b11110001);
	ret.data[2] = temp.m256_f32[0];
	ret.data[6] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b33, 0b11110001);
	ret.data[10] = temp.m256_f32[0];
	ret.data[14] = temp.m256_f32[4];

	temp = _mm256_dp_ps(a12, b44, 0b11110001);
	ret.data[3] = temp.m256_f32[0];
	ret.data[7] = temp.m256_f32[4];
	temp = _mm256_dp_ps(a34, b44, 0b11110001);
	ret.data[11] = temp.m256_f32[0];
	ret.data[15] = temp.m256_f32[4];

	return ret;
}

#else

Matrix4 Matrix4::operator + (const Matrix4& right) const
{
	Matrix4 res;
	for (int i = 0; i < 4; i++)
		for (int j = 0; j < 4; j++)
			res.ptr[i][j] = ptr[i][j] + right.ptr[i][j];
	return res;
}
Matrix4 Matrix4::operator - (const Matrix4& right) const
{
	Matrix4 res;
	for (int i = 0; i < 4; i++)
		for (int j = 0; j < 4; j++)
			res.ptr[i][j] = ptr[i][j] - right.ptr[i][j];
	return res;
}

Matrix4 Matrix4::operator*(float k)const
{
	Matrix4 res;
	for (int i = 0; i < 4; ++i)
	{
		for (int j = 0; j < 4; ++j)
		{
			res.ptr[i][j] = ptr[i][j] * k;
		}
	}
	return res;
}


Vector3 Matrix4::operator*(const Vector3& v)const
{
	float x = v.x * ptr[0][0] + v.y * ptr[0][1] + v.z * ptr[0][2] + v.w * ptr[0][3];
	float y = v.x * ptr[1][0] + v.y * ptr[1][1] + v.z * ptr[1][2] + v.w * ptr[1][3];
	float z = v.x * ptr[2][0] + v.y * ptr[2][1] + v.z * ptr[2][2] + v.w * ptr[2][3];
	float w = v.x * ptr[3][0] + v.y * ptr[3][1] + v.z * ptr[3][2] + v.w * ptr[3][3];
	Vector3 returnValue(x, y, z);
	returnValue.w = w;
	return returnValue;
}

Matrix4 Matrix4::operator * (const Matrix4& right) const
{
	Matrix4 res;
	for (int i = 0; i < 4; i++)
	{
		for (int j = 0; j < 4; j++)
		{
			res.ptr[i][j] = 0;//temp
			for (int k = 0; k < 4; k++)
			{
				res.ptr[i][j] += this->ptr[i][k] * right.ptr[k][j];
			}
		}
	}
	return res;
}

#endif // SIMD


Matrix4 Matrix4::operator-()const
{
	Matrix4 trans;
	for (int i = 0; i < 4; ++i)
		for (int j = 0; j < 4; ++j)
			trans.ptr[i][j] = ptr[j][i];
	return trans;
}
Matrix4 Matrix4::transpose()const
{
	return -(*this);
}


void Matrix4::Identity()
{
	for (int i = 0; i < 4; ++i)
	{
		for (int j = 0; j < 4; ++j)
		{
			if (i != j)
				ptr[i][j] = 0;
			else
				ptr[i][j] = 1;
		}
	}
}

Vector3f Matrix4::MultiplyVector3(const Vector3f& v) const
{
	return (*this) * v;
}


void Matrix4::Print()
{
	std::cout << "-----------------Matrix Begin--------------" << std::endl;
	for (int i = 0; i < 4; ++i)
	{
		for (int j = 0; j < 4; ++j)
		{
			std::cout << "[" << ptr[i][j] << "]   ";
		}
		std::cout << std::endl;
	}
	std::cout << "-----------------Matrix End----------------" << std::endl;
}

你可能感兴趣的:(C++,矩阵,算法,SIMD,矩阵运算,AVX)