// 设备端并行求偏移数组 不判断当前列号是否出现过也添加进去
// 列号排序
// 然后计算
#include
// 核函数每个线程负责一行 计算当前行中有多少个元素 并先存入相应的偏移量数组中行号的+1位置 (不判断列号是否重复的版本)
__global__ void getRowNnz(const int *dptr_offset_A, const int *dptr_offset_B,
const int *dptr_colindex_A, const int *dptr_colindex_B, int *dptr_offset_C, int m)
{
int rowindex = threadIdx.x + blockDim.x * blockIdx.x;
if (rowindex < m)
{
int row_nnz = 0; // row_nnz记录当前第i行一共有多少个元素 初始设为0
int A_begin = dptr_offset_A[rowindex];
int A_end = dptr_offset_A[rowindex + 1];
for (int jj = A_begin; jj < A_end; jj++)
{ // jj为当前第rowindex行的非0元素在value数组与col数组中的起始位置
int j = dptr_colindex_A[jj]; // j为当前A的第rowindex行中非0元素所处于的列号 然后找B中第j行的非0元素
int B_begin = dptr_offset_B[j];
int B_end = dptr_offset_B[j + 1];
row_nnz += B_end - B_begin;
}
dptr_offset_C[rowindex + 1] = row_nnz; // 得到每行有多少个元素先存入相应的偏移量数组中行号的+1位置
}
}
// 核函数每个线程负责一行 标识当前元素的行号 并且对每行列索引数组中的相应区域进行排序
__global__ void SortAndRow(const int *dptr_offset_A, const int *dptr_offset_B,
const int *dptr_colindex_A, const int *dptr_colindex_B,
int *dptr_colindex_C, int *dptr_rowindex_C, int *dptr_offset_C, int m)
{
int rowindex = threadIdx.x + blockDim.x * blockIdx.x;
if (rowindex < m)
{
// 找到当前行元素再col和val数组的下标范围
int left = dptr_offset_C[rowindex];
int right = dptr_offset_C[rowindex + 1];
// 将遍历得到的列数组的列号存储到相应位置
// 先设置初始插入的位置
// 所有元素插入完成后再排序
int pos = left; // 插入位置初始为left位置
int i = left;
int A_begin = dptr_offset_A[rowindex];
int A_end = dptr_offset_A[rowindex + 1];
for (int jj = A_begin; jj < A_end; jj++)
{ // jj为当前第i行的非0元素在value数组与col数组中的起始位置
int j = dptr_colindex_A[jj]; // j为当前A的第i行中非0元素所处于的列号 然后找B中第j行的非0元素
int B_begin = dptr_offset_B[j];
int B_end = dptr_offset_B[j + 1];
for (int kk = B_begin; kk < B_end; kk++)
{ // kk为当前B中第j行中的非0元素在value数组与col数组中的起始位置
int k = dptr_colindex_B[kk]; // k为当前B的第j行中非0元素所处的列号 即最终结果元素所处的列号
dptr_colindex_C[pos] = k;
pos++;
}
}
// 排序算法
// // 插入排序
// for (int i = left + 1; i < right; i++)
// {
// int key = dptr_colindex_C[i];
// int j = i - 1;
// while (j >= left && dptr_colindex_C[j] > key)
// {
// dptr_colindex_C[j + 1] = dptr_colindex_C[j];
// j--;
// }
// dptr_colindex_C[j + 1] = key;
// }
// 希尔排序
int n = right - left;
int p, q, gap;
for (gap = n / 2; gap > 0; gap /= 2)
{
for (p = 0; p < gap; p++)
{
for (q = p + gap + left; q < n + left; q += gap)
{
if (dptr_colindex_C[q] < dptr_colindex_C[q - gap])
{
int tmp = dptr_colindex_C[q];
int k = q - gap;
while (k >= left && dptr_colindex_C[k] > tmp)
{
dptr_colindex_C[k + gap] = dptr_colindex_C[k];
k -= gap;
}
dptr_colindex_C[k + gap] = tmp;
}
}
}
}
// 归并排序 待补充
// 初始化行号数组的值
for (int i = left; i < right; i++)
{
dptr_rowindex_C[i] = rowindex;
}
}
}
// 核函数每个线程负责结果C中的每个位置 通过此位置对用的行号和列号 去遍历A和B中相应的元素乘积再相加 得到的结果存到当前位置
__global__ void cal(int *dptr_rowindex_C, int *dptr_colindex_C, double *dptr_value_C,
const int *dptr_offset_A, const int *dptr_offset_B,
const int *dptr_colindex_A, const int *dptr_colindex_B,
const double *dptr_value_A, const double *dptr_value_B,
int nonzero, double alpha)
{
int idx = threadIdx.x + blockDim.x * blockIdx.x; // 对应在value_C数组的下标
if (idx < nonzero)
{
if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1])
{
dptr_value_C[idx] = 0.0;
}
else
{
int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号
int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素
double sum = 0; // 记录当前位置存入的最终结果
double value_A = 0;
double value_B = 0;
int A_begin = dptr_offset_A[row];
int A_end = dptr_offset_A[row + 1];
for (int jj = A_begin; jj < A_end; jj++)
{ // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置
value_A = dptr_value_A[jj]; // 当前A的值
int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号
// 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果
int left = dptr_offset_B[j];
int right = dptr_offset_B[j + 1] - 1;
int mid = 0;
while (left <= right)
{
int mid = left + (right - left) / 2;
if (dptr_colindex_B[mid] < col)
{
left = mid + 1;
}
else if (dptr_colindex_B[mid] > col)
{
right = mid - 1;
}
else if (dptr_colindex_B[mid] == col)
{
value_B = dptr_value_B[mid];
sum = sum + value_A * value_B;
break;
}
}
}
dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数
}
}
}
void call_device_spgemm(const int transA,
const int transB,
const dtype alpha,
const size_t m,
const size_t n,
const size_t k,
const size_t nnz_A,
const csrIdxType *dptr_offset_A,
const csrIdxType *dptr_colindex_A,
const dtype *dptr_value_A,
const size_t nnz_B,
const csrIdxType *dptr_offset_B,
const csrIdxType *dptr_colindex_B,
const dtype *dptr_value_B,
size_t *ptr_nnz_C,
csrIdxType *dptr_offset_C,
csrIdxType **pdptr_colindex_C,
dtype **pdptr_value_C)
// device_valueC的值指向设备端中的存储位置首地址 传入进来的&device_valueC是指向device_valueC指针存储位置的指针
{
dim3 dimBlock(256);
dim3 dimGrid((m + dimBlock.x - 1) / dimBlock.x);
getRowNnz<<>>(dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_offset_C, m);
// 主机端申请m+1大小的C偏移量数组将设备端的内容传回 得到dptr_offset_C[m]即nnz
int *hptr_offset_C;
HIP_CHECK(hipHostMalloc(&hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipHostRegisterDefault));
HIP_CHECK(hipMemcpy(hptr_offset_C, dptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyDeviceToHost));
// 求前缀和
hptr_offset_C[0] = 0;
for (int i = 1; i <= m; i++)
{
hptr_offset_C[i] = hptr_offset_C[i] + hptr_offset_C[i - 1];
}
// 求完前缀和再传回设备
HIP_CHECK(hipMemcpy(dptr_offset_C, hptr_offset_C, (m + 1) * sizeof(csrIdxType), hipMemcpyHostToDevice));
// 得到结果非0元总数
int nnz = hptr_offset_C[m];
int nonzero = nnz;
*ptr_nnz_C = nnz;
// 释放主机端空间
HIP_CHECK(hipHostFree(hptr_offset_C));
// Malloc pdptr_colindex_C and pdptr_value_C C有多少个非0元跟对应CSR格式中 列索引数组和值数组的大小相关
HIP_CHECK(hipMalloc((void **)pdptr_colindex_C, nonzero * sizeof(csrIdxType)));
HIP_CHECK(hipMalloc((void **)pdptr_value_C, nonzero * sizeof(dtype)));
HIP_CHECK(hipMemset(*pdptr_value_C, 0.0, nonzero * sizeof(dtype)));
// 核函数每个线程负责一行 对每行列索引数组中的相应区域进行排序
// 并且将额外辅助标识行号的行数组初始化 得到每个元素对应在哪一行的位置
int *dptr_rowindex_C;
HIP_CHECK(hipMalloc((void **)&dptr_rowindex_C, nonzero * sizeof(int)));
SortAndRow<<>>(dptr_offset_A, dptr_offset_B,
dptr_colindex_A, dptr_colindex_B, *pdptr_colindex_C, dptr_rowindex_C, dptr_offset_C, m);
// 核函数 每个线程负责一个元素 计算最终结果
dim3 dimBlocks(256);
dim3 dimGrids((nonzero + dimBlocks.x - 1) / dimBlocks.x);
cal<<>>(dptr_rowindex_C, *pdptr_colindex_C, *pdptr_value_C,
dptr_offset_A, dptr_offset_B, dptr_colindex_A, dptr_colindex_B, dptr_value_A, dptr_value_B, nonzero, alpha);
// 释放额外的行号数组空间
HIP_CHECK(hipFree(dptr_rowindex_C));
}
1、核函数 getRowNnz
: 这个核函数的作用是计算结果矩阵 C 的每一行中非零元素的数量,并将这些数量存储在偏移量数组 dptr_offset_C
中。这个数组实际上用于构建结果矩阵 C 的压缩行存储格式(CSR格式)。每个线程处理一个行,计算出对应行的非零元素个数,然后将这些个数存储在偏移量数组中。
2、核函数 SortAndRow
: 这个核函数的主要作用是构建结果矩阵 C 的压缩行存储格式。它为矩阵 C 的列索引数组 dptr_colindex_C
赋值,同时为行索引数组 dptr_rowindex_C
赋值。首先,它计算出每行在结果矩阵 C 中的存储位置范围(起始位置和结束位置),然后将遍历得到的矩阵 A 的非零元素所在的列索引存储在 dptr_colindex_C
数组中,并在 dptr_rowindex_C
数组中为每个元素记录对应的行号。这些操作将结果矩阵 C 转化为了压缩行存储格式。
通过上面这个方法可以过滤掉某一行为0或者某一列为0的情况
3、核函数 cal
: 这个核函数的作用是执行稀疏矩阵乘法的计算,并将结果存储在结果矩阵 C 的值数组 dptr_value_C
中。每个线程处理结果矩阵中的一个位置,计算出对应位置的值。它在计算之前会检查相邻位置的行列索引是否相同,以避免重复计算。对于每个位置,它通过遍历矩阵 A 的非零元素和矩阵 B 的对应元素,按照稀疏矩阵乘法规则进行相乘累加操作,并将最终结果存储在 dptr_value_C
中。
if (idx != 0 && dptr_colindex_C[idx] == dptr_colindex_C[idx - 1] && dptr_rowindex_C[idx] == dptr_rowindex_C[idx - 1]) { dptr_value_C[idx] = 0.0; }
如果是这样的话,那就代表有空行或者有空列,那就代表是0嘛,相当于提前判断了一下。
之后就开始正式的计算:
else { int row = dptr_rowindex_C[idx]; // 当前位置所对应的行号与列号 int col = dptr_colindex_C[idx]; // 通过行号确定遍历A的非0元素所在的列号 通过列号确定寻找B的列号为col的元素 double sum = 0; // 记录当前位置存入的最终结果 double value_A = 0; double value_B = 0; int A_begin = dptr_offset_A[row]; int A_end = dptr_offset_A[row + 1]; for (int jj = A_begin; jj < A_end; jj++) { // jj为当前第row行的非0元素在value_A数组与col_A数组中的起始位置 value_A = dptr_value_A[jj]; // 当前A的值 int j = dptr_colindex_A[jj]; // j为当前A的第row行中非0元素所处于的列号 // 折半查找 寻找B中第j行的列号为col的非0元素 与A位置的元素相乘再相加得到最终结果 int left = dptr_offset_B[j]; int right = dptr_offset_B[j + 1] - 1; int mid = 0; while (left <= right) { int mid = left + (right - left) / 2; if (dptr_colindex_B[mid] < col) { left = mid + 1; } else if (dptr_colindex_B[mid] > col) { right = mid - 1; } else if (dptr_colindex_B[mid] == col) { value_B = dptr_value_B[mid]; sum = sum + value_A * value_B; break; } } } dptr_value_C[idx] = sum * alpha; // 最终结果需要乘以一个系数 }
上面这个程序也不难理解,就是找到对应行找到对应列进行相乘相加最后得到结果,但是得益于之前进行了判断,所以会减少很多的计算量。
综合来看,这三个核函数一起完成了稀疏矩阵乘法的计算过程。首先,getRowNnz
核函数确定了结果矩阵 C 的压缩行存储格式所需的偏移量信息。然后,SortAndRow
核函数构建了矩阵 C 的压缩行存储格式,并在其中执行了排序操作。最后,cal
核函数进行了实际的稀疏矩阵乘法计算,并将结果存储在矩阵 C 的值数组中。