SPGEMM_example解析

// 设备端并行求偏移数组 不判断当前列号是否出现过也添加进去
// 列号排序
// 然后计算
#include 

// 核函数每个线程负责一行 计算当前行中有多少个元素 并先存入相应的偏移量数组中行号的+1位置 (不判断列号是否重复的版本)
__global__ void getRowNnz(const int *dptr_offset_A, const int *dptr_offset_B,
						  const int *dptr_colindex_A, const int *dptr_colindex_B, int *dptr_offset_C, int m)
{
	int rowindex = threadIdx.x + blockDim.x * blockIdx.x;
	if (rowindex < m)
	{
		int row_nnz = 0; // row_nnz记录当前第i行一共有多少个元素 初始设为0
		int A_begin = dptr_offset_A[rowindex];
		int A_end = dptr_offset_A[rowindex + 1];
		for (int jj = A_begin; jj < A_end; jj++)
		{								 // jj为当前第rowindex行的非0元素在value数组与col数组中的起始位置
			int j = dptr_colindex_A[jj]; // j为当前A的第rowindex行中非0元素所处于的列号  然后找B中第j行的非0元素
			int B_begin = dptr_offset_B[j];
			int B_end = dptr_offset_B[j + 1];
			row_nnz += B_end - B_begin;
		}
		dptr_offset_C[rowindex + 1] = row_nnz; // 得到每行有多少个元素先存入相应的偏移量数组中行号的+1位置
	}
}

// 核函数每个线程负责一行 标识当前元素的行号 并且对每行列索引数组中的相应区域进行排序
__global__ void SortAndRow(const int *dptr_offset_A, const int *dptr_offset_B,
						   const int *dptr_colindex_A, const int *dptr_colindex_B,
						   int *dptr_colindex_C, int *dptr_rowindex_C, int *dptr_offset_C, int m)
{
	int rowindex = threadIdx.x + blockDim.x * blockIdx.x;
	if (rowindex < m)
	{
		// 找到当前行元素再col和val数组的下标范围
		int left = dptr_offset_C[rowindex];
		int right = dptr_offset_C[rowindex + 1];
		// 将遍历得到的列数组的列号存储到相应位置
		// 先设置初始插入的位置
		// 所有元素插入完成后再排序
		int pos = left; // 插入位置初始为left位置
		int i = left;
		int A_begin = dptr_offset_A[rowindex];
		int A_end = dptr_offset_A[rowindex + 1];
		for (int jj = A_begin; jj < A_end; jj++)
		{								 // jj为当前第i行的非0元素在value数组与col数组中的起始位置
			int j = dptr_colindex_A[jj]; // j为当前A的第i行中非0元素所处于的列号  然后找B中第j行的非0元素
			int B_begin = dptr_offset_B[j];
			int B_end = dptr_offset_B[j + 1];
			for (int kk = B_begin; kk < B_end; kk++)
			{								 // kk为当前B中第j行中的非0元素在value数组与col数组中的起始位置
				int k = dptr_colindex_B[kk]; // k为当前B的第j行中非0元素所处的列号 即最终结果元素所处的列号
				dptr_colindex_C[pos] = k;
				pos++;
			}
		}
		// 排序算法

		// // 插入排序
		// for (int i = left + 1; i < right; i++)
		// {
		// 	int key = dptr_colindex_C[i];
		// 	int j = i - 1;
		// 	while (j >= left && dptr_colindex_C[j] > key)
		// 	{
		// 		dptr_colindex_C[j + 1] = dptr_colindex_C[j];
		// 		j--;
		// 	}
		// 	dptr_colindex_C[j + 1] = key;
		// }

		// 希尔排序
		int n = right - left;
		int p, q, gap;
		for (gap = n / 2; gap > 0; gap /= 2)
		{
			for (p = 0; p < gap; p++)
			{
				for (q = p + gap + left; q < n + left; q += gap)
				{
					if (dptr_colindex_C[q] < dptr_colindex_C[q - gap])
					{
						int tmp = dptr_colindex_C[q];
						int k = q - gap;
						while (k >= left && dptr_colindex_C[k] > tmp)
						{
							dptr_colindex_C[k + gap] = dptr_colindex_C[k];
							k -= gap;
						}
						dptr_colindex_C[k + gap] = tmp;
					}
				}
			}
		}

		// 归并排序 待补充

		//  初始化行号数组的值
		for (int i = left; i < right; i++)
		{
			dptr_rowindex_C[i] = rowindex;
		}
	}
}
// 核函数每个线程负责结果C中的每个位置 通过此位置对用的行号和列号 去遍历A和B中相应的元素乘积再相加 得到的结果存到当前位置
__global__ void cal(int *dptr_rowindex_C, int *dptr_colindex_C, double *dptr_value_C,
					const int *dptr_offset_A, const int *dptr_offset_B,
					const int *dptr_colindex_A, const int *dptr_colindex_B,
					const double *dptr_value_A, const double *dptr_value_B,
					int nonzero, double alpha)
{
	int idx = threadIdx.x + blockDim.x * blockIdx.x; // 对应在value_C数组的下标
	if (idx < nonzero)
	{
		if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1])
		{
			dptr_value_C[idx] = 0.0;
		}
		else
		{

			int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号
			int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素
			double sum = 0;					// 记录当前位置存入的最终结果
			double value_A = 0;
			double value_B = 0;
			int A_begin = dptr_offset_A[row];
			int A_end = dptr_offset_A[row + 1];
			for (int jj = A_begin; jj < A_end; jj++)
			{								 // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置
				value_A = dptr_value_A[jj];	 // 当前A的值
				int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号
				// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果
				int left = dptr_offset_B[j];
				int right = dptr_offset_B[j + 1] - 1;
				int mid = 0;
				while (left <= right)
				{
					int mid = left + (right - left) / 2;
					if (dptr_colindex_B[mid] < col)
					{
						left = mid + 1;
					}
					else if (dptr_colindex_B[mid] > col)
					{
						right = mid - 1;
					}
					else if (dptr_colindex_B[mid] == col)
					{
						value_B = dptr_value_B[mid];
						sum = sum + value_A * value_B;
						break;
					}
				}
			}
			dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数
		}
	}
}

void call_device_spgemm(const int transA,
						const int transB,
						const dtype alpha,
						const size_t m,
						const size_t n,
						const size_t k,
						const size_t nnz_A,
						const csrIdxType *dptr_offset_A,
						const csrIdxType *dptr_colindex_A,
						const dtype *dptr_value_A,
						const size_t nnz_B,
						const csrIdxType *dptr_offset_B,
						const csrIdxType *dptr_colindex_B,
						const dtype *dptr_value_B,
						size_t *ptr_nnz_C,
						csrIdxType *dptr_offset_C,
						csrIdxType **pdptr_colindex_C,
						dtype **pdptr_value_C)
// device_valueC的值指向设备端中的存储位置首地址 传入进来的&device_valueC是指向device_valueC指针存储位置的指针
{

	dim3 dimBlock(256);
	dim3 dimGrid((m + dimBlock.x - 1) / dimBlock.x);
	getRowNnz<<>>(dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_offset_C, m);

	// 主机端申请m+1大小的C偏移量数组将设备端的内容传回 得到dptr_offset_C[m]即nnz
	int *hptr_offset_C;
	HIP_CHECK(hipHostMalloc(&hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipHostRegisterDefault));
	HIP_CHECK(hipMemcpy(hptr_offset_C, dptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyDeviceToHost));
	// 求前缀和
	hptr_offset_C[0] = 0;
	for (int i = 1; i <= m; i++)
	{
		hptr_offset_C[i] = hptr_offset_C[i] + hptr_offset_C[i - 1];
	}
	// 求完前缀和再传回设备
	HIP_CHECK(hipMemcpy(dptr_offset_C, hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyHostToDevice));

	// 得到结果非0元总数
	int nnz = hptr_offset_C[m];
	int nonzero = nnz;
	*ptr_nnz_C = nnz;

	// 释放主机端空间
	HIP_CHECK(hipHostFree(hptr_offset_C));

	// Malloc pdptr_colindex_C and pdptr_value_C    C有多少个非0元跟对应CSR格式中 列索引数组和值数组的大小相关
	HIP_CHECK(hipMalloc((void **)pdptr_colindex_C, nonzero * sizeof(csrIdxType)));
	HIP_CHECK(hipMalloc((void **)pdptr_value_C, nonzero * sizeof(dtype)));
	HIP_CHECK(hipMemset(*pdptr_value_C, 0.0, nonzero * sizeof(dtype)));

	// 核函数每个线程负责一行 对每行列索引数组中的相应区域进行排序
	// 并且将额外辅助标识行号的行数组初始化 得到每个元素对应在哪一行的位置
	int *dptr_rowindex_C;
	HIP_CHECK(hipMalloc((void **)&dptr_rowindex_C, nonzero * sizeof(int)));

	SortAndRow<<>>(dptr_offset_A, dptr_offset_B,
									  dptr_colindex_A, dptr_colindex_B, *pdptr_colindex_C, dptr_rowindex_C, dptr_offset_C, m);
	// 核函数 每个线程负责一个元素 计算最终结果
	dim3 dimBlocks(256);
	dim3 dimGrids((nonzero + dimBlocks.x - 1) / dimBlocks.x);

	cal<<>>(dptr_rowindex_C, *pdptr_colindex_C, *pdptr_value_C,
								 dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_value_A, dptr_value_B, nonzero, alpha);
	// 释放额外的行号数组空间
	HIP_CHECK(hipFree(dptr_rowindex_C));
}

1、核函数 getRowNnz: 这个核函数的作用是计算结果矩阵 C 的每一行中非零元素的数量,并将这些数量存储在偏移量数组 dptr_offset_C 中。这个数组实际上用于构建结果矩阵 C 的压缩行存储格式(CSR格式)。每个线程处理一个行,计算出对应行的非零元素个数,然后将这些个数存储在偏移量数组中。

2、核函数 SortAndRow: 这个核函数的主要作用是构建结果矩阵 C 的压缩行存储格式。它为矩阵 C 的列索引数组 dptr_colindex_C 赋值,同时为行索引数组 dptr_rowindex_C 赋值。首先,它计算出每行在结果矩阵 C 中的存储位置范围(起始位置和结束位置),然后将遍历得到的矩阵 A 的非零元素所在的列索引存储在 dptr_colindex_C 数组中,并在 dptr_rowindex_C 数组中为每个元素记录对应的行号。这些操作将结果矩阵 C 转化为了压缩行存储格式。

通过上面这个方法可以过滤掉某一行为0或者某一列为0的情况

3、核函数 cal: 这个核函数的作用是执行稀疏矩阵乘法的计算,并将结果存储在结果矩阵 C 的值数组 dptr_value_C 中。每个线程处理结果矩阵中的一个位置,计算出对应位置的值。它在计算之前会检查相邻位置的行列索引是否相同,以避免重复计算。对于每个位置,它通过遍历矩阵 A 的非零元素和矩阵 B 的对应元素,按照稀疏矩阵乘法规则进行相乘累加操作,并将最终结果存储在 dptr_value_C 中。

if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1])
		{
			dptr_value_C[idx] = 0.0;
		}

如果是这样的话,那就代表有空行或者有空列,那就代表是0嘛,相当于提前判断了一下。

之后就开始正式的计算:

else
		{

			int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号
			int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素
			double sum = 0;					// 记录当前位置存入的最终结果
			double value_A = 0;
			double value_B = 0;
			int A_begin = dptr_offset_A[row];
			int A_end = dptr_offset_A[row + 1];
			for (int jj = A_begin; jj < A_end; jj++)
			{								 // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置
				value_A = dptr_value_A[jj];	 // 当前A的值
				int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号
				// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果
				int left = dptr_offset_B[j];
				int right = dptr_offset_B[j + 1] - 1;
				int mid = 0;
				while (left <= right)
				{
					int mid = left + (right - left) / 2;
					if (dptr_colindex_B[mid] < col)
					{
						left = mid + 1;
					}
					else if (dptr_colindex_B[mid] > col)
					{
						right = mid - 1;
					}
					else if (dptr_colindex_B[mid] == col)
					{
						value_B = dptr_value_B[mid];
						sum = sum + value_A * value_B;
						break;
					}
				}
			}
			dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数
		}

上面这个程序也不难理解,就是找到对应行找到对应列进行相乘相加最后得到结果,但是得益于之前进行了判断,所以会减少很多的计算量。

综合来看,这三个核函数一起完成了稀疏矩阵乘法的计算过程。首先,getRowNnz 核函数确定了结果矩阵 C 的压缩行存储格式所需的偏移量信息。然后,SortAndRow 核函数构建了矩阵 C 的压缩行存储格式,并在其中执行了排序操作。最后,cal 核函数进行了实际的稀疏矩阵乘法计算,并将结果存储在矩阵 C 的值数组中。

你可能感兴趣的:(SPGEMM,c++)