数据结构 稀疏矩阵乘法

【数据结构】稀疏矩阵乘法

1.传统矩阵相乘的算法使用三个嵌套循环实现,算法复杂度为O(m * n1 * n2)
2.使用三元组顺序表存储稀疏矩阵时,实现 Q= M * N,对于M中M(i,j)元素来说,只需要与N中第j行元素N(j,q)相乘,再存入Q(i,q)中。为了实现这一操作,增加一个向量rpos,表示每一行的第一个非零元在三元组中的位置,rpos作用相当于快速转置中的cpot向量。
这种结构叫做 行链接的顺序表

typedef struct {
	Triple data[MAXSIZE + 1]; //非零三元组表,data0未用
	int rpos[MAXRC + 1];	  //各行第一个非零元的位置表
	int mu, nu, tu;			  //三元组行数,列数,非零元素值
}RLSMatrix;

在创建矩阵时,需要求得矩阵的rpos向量,求法与cpot一致。

Status CreateRLSMatrix(RLSMatrix &M) {
	cout << "输入矩阵行数,列数,非零元素个数:" << endl;
	cin >> M.mu >> M.nu >> M.tu;
	cout << "依次输入" << M.tu << "个元素的行数,列数,元素值:" << endl;
	for (int i = 1; i <= M.tu; i++)
	{
		cin >> M.data[i].i >> M.data[i].j >> M.data[i].e;
	}
	int* num = (int*)malloc(M.mu * sizeof(int));
	int arow;
	int q;
	if (M.tu) {
		for ( arow = 1; arow <= M.mu; ++arow) num[arow] = 0;
		for (int t = 1; t <= M.tu; ++t) ++num[M.data[t].i];//求M中每一行含非零元个数
		M.rpos[1] = 1;
		//求第col列中第一个非零元在b.data 中的序号
		for (arow = 2; arow <= M.mu; ++arow)M.rpos[arow] = M.rpos[arow - 1] + num[arow - 1];
	}
		cout << "矩阵构造完成" << endl;
	return OK;
}

稀疏矩阵乘法
算法思路:
大循环是对M中的每一行进行。
rpos向量用来确定每行元素个数
再对该行每个元素M(i,j)进行处理,找到N中对应的第 j 行,将这一行所有元素与M(i,j)相乘,结果存入ctemp累加器中,该行元素都处理完成后,ctemp中所存储的便是Q中第 i 行数据

Status MultSMatrix(RLSMatrix M, RLSMatrix N, RLSMatrix &Q) {
	//求稀疏矩阵乘积 Q = M x N  采用行逻辑链接存储表示
	if (M.nu != N.mu)return ERROR;
	Q.mu = M.mu; Q.nu = N.nu; Q.tu = 0;
	int arow,tp,p,brow,t,q,ccol;
	int* ctemp = (int*)malloc((M.mu+1) * sizeof(int));
	if (M.tu * N.tu != 0) {
		for (arow = 1; arow <= M.mu; ++arow) {
		//对M的每行进行处理
			for (int i = 0; i <= M.mu; i++)ctemp[i] = 0;//累加器清零
			Q.rpos[arow] = Q.tu + 1;
			if (arow < M.mu)tp = M.rpos[arow + 1];
			else tp = M.tu + 1;//确定M中该行元素个数
			for ( p = M.rpos[arow]; p < tp; p++)
			{//M中对该行的每个元素进行处理
				brow = M.data[p].j;
				if (brow < N.mu) t = N.rpos[brow + 1];
				else t = N.tu + 1;
				for (q = N.rpos[brow]; q < t; ++q) {
				//将该元素与矩阵N中对应行的元素进行相乘,结果存入累加器
					ccol = N.data[q].j;
					ctemp[ccol] += M.data[p].e *N.data[q].e;
				}
			}
			for(ccol =1;ccol<=Q.nu;++ccol)
				if (ctemp[ccol]) {
				//第arow行处理完毕,将累加器中值存入Q中
					if (++Q.tu > MAXSIZE)return ERROR;
					Q.data[Q.tu] = { arow, ccol, ctemp[ccol] };
				}//if
		}//for arow
	}//if
}

该算法总的时间复杂度为O(M.mu * N.nu + M.tu * N.tu/N.mu)
ctemp初始化 O(M.mu * N.nu)
求Q中所有非零元 O(M.tu * N.tu / N.mu)
压缩存储 O(M.mu * N.nu)

你可能感兴趣的:(数据结构严蔚敏)