[数据结构]稀疏矩阵乘法算法实现


作者 zhonglihao
算法名 稀疏矩阵乘法 Sparse Matrix Multiplication
分类 数据结构
复杂度 O(n^2)
形式与数据结构 C++代码 一维结构体存储
特性 极简封装 不使用链表 不需要转置 计算过程容易理解
具体参考出处 《算法导论》(写的不想看)
备注

// ConsoleApplication1.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include "stdio.h"
#include "stdlib.h"

//稀疏矩阵存储结构体 第一个元素为矩阵头,包含行列长度,元素总个数
typedef struct
{
	int row;
	int col;
	int element;
}sparse_mat;

void SparseMatrixRectPrint(sparse_mat* s_mat);
void SparseMatrixTriPrint(sparse_mat* s_mat);
sparse_mat* SparseMatrixMul(sparse_mat* s_mat_A, sparse_mat* s_mat_B);

int _tmain(int argc, _TCHAR* argv[])
{
	int i, j, k;
	const int mat_A_row = 4;
	const int mat_A_col = 4;
	const int mat_B_row = 4;
	const int mat_B_col = 4;

	//原矩阵
	int mat_A[mat_A_row][mat_A_col] = { 1, 1, 0, 0,
										0, 0, 1, 0, 
										0, 1, 0, 0, 
										0, 0, 1, 0 };

	int mat_B[mat_B_row][mat_B_col] = { 1, 0, 0, 0,
										0, 1, 0, 0, 
										0, 0, 1, 0, 
										0, 0, 0, 1 };

	//计算有效元素数量
	int mat_A_ele_count = 0;
	int mat_B_ele_count = 0;

	for (i = 0; i < mat_A_row; i++)
	{
		for (j = 0; j < mat_A_col; j++)
		{
			if (mat_A[i][j] != 0) mat_A_ele_count++;
		}
	}

	for (i = 0; i < mat_B_row; i++)
	{
		for (j = 0; j < mat_B_col; j++)
		{
			if (mat_B[i][j] != 0) mat_B_ele_count++;
		}
	}

	//动态分配
	sparse_mat* sparse_m_A = (sparse_mat*)malloc((mat_A_ele_count + 1)*sizeof(sparse_mat));
	sparse_mat* sparse_m_B = (sparse_mat*)malloc((mat_B_ele_count + 1)*sizeof(sparse_mat));

	//存入稀疏矩阵信息
	sparse_m_A[0].row		= mat_A_row;
	sparse_m_A[0].col		= mat_A_col;
	sparse_m_A[0].element	= mat_A_ele_count;
	sparse_m_B[0].row		= mat_B_row;
	sparse_m_B[0].col		= mat_B_col;
	sparse_m_B[0].element	= mat_B_ele_count;

	for (i = 0, mat_A_ele_count = 0; i < mat_A_row; i++)
	{
		for (j = 0; j < mat_A_col; j++)
		{
			if (mat_A[i][j] != 0)
			{
				mat_A_ele_count++;
				sparse_m_A[mat_A_ele_count].element = mat_A[i][j];
				sparse_m_A[mat_A_ele_count].row = i;
				sparse_m_A[mat_A_ele_count].col = j;
			}
		}
	}

	for (i = 0, mat_B_ele_count = 0; i < mat_B_row; i++)
	{
		for (j = 0; j < mat_B_col; j++)
		{
			if (mat_B[i][j] != 0)
			{
				mat_B_ele_count++;
				sparse_m_B[mat_B_ele_count].element = mat_B[i][j];
				sparse_m_B[mat_B_ele_count].row = i;
				sparse_m_B[mat_B_ele_count].col = j;
			}
		}
	}

	//打印原数组
	SparseMatrixRectPrint(sparse_m_A);
	SparseMatrixRectPrint(sparse_m_B);
	//SparseMatrixTriPrint(sparse_m_A); 
	//SparseMatrixTriPrint(sparse_m_B);

	//计算稀疏矩阵乘法
	sparse_mat* sparse_m_C = (sparse_mat*)SparseMatrixMul(sparse_m_A, sparse_m_B);
	SparseMatrixRectPrint(sparse_m_C);

	system("Pause");
	return 0;
}

//三元组稀疏矩阵乘法函数 极简封装 需要花费一点时间计算申请的内存 但是肯定比链表省空间啦
//Method Written By Zhonglihao
sparse_mat* SparseMatrixMul(sparse_mat* s_mat_A, sparse_mat* s_mat_B)
{
	int i, j, k;
	int s_mat_C_row			= s_mat_A[0].row;
	int s_mat_C_col			= s_mat_B[0].col;
	int s_mat_A_ele_count	= s_mat_A[0].element;
	int s_mat_B_ele_count   = s_mat_B[0].element;

	//判断是否能够相乘 或 有一个全为0 那就不用乘啦
	if (s_mat_A[0].col != s_mat_B[0].row) return NULL;
	if (s_mat_A_ele_count == 0 || s_mat_B_ele_count == 0)
	{
		sparse_mat* s_mat_C	= (sparse_mat*)malloc((1)*sizeof(sparse_mat));
		s_mat_C[0].row		= s_mat_C_row;
		s_mat_C[0].col		= s_mat_C_col;
		s_mat_C[0].element	= 0;
		return s_mat_C;
	}

	//申请一个长度为B列宽的缓存 两个用途 计算输出大小时做列封禁,计算相乘时做和缓存
	int* col_buffer = (int*)malloc(s_mat_C_col*sizeof(int));
	//清空缓存区
	for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;

	//判断需要输出的三元大小申请内存
	int malloc_element_count = 0;
	for (i = 1; i <= s_mat_A_ele_count; i++)
	{
		if (i >= 2 && s_mat_A[i].row != s_mat_A[i - 1].row) //换行解禁
		{
			for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;
		}

		for (j = 1; j <= s_mat_B_ele_count; j++)
		{
			if ((s_mat_A[i].col == s_mat_B[j].row) && col_buffer[s_mat_B[j].col] != 1)//没有列封禁
			{
				col_buffer[s_mat_B[j].col] = 1;//列封禁
				malloc_element_count++;
			}
		}
	}

	sparse_mat* s_mat_C		= (sparse_mat*)malloc((malloc_element_count + 1)*sizeof(sparse_mat));
	s_mat_C[0].row			= s_mat_C_row;
	s_mat_C[0].col			= s_mat_C_col;
	s_mat_C[0].element		= malloc_element_count;
	int s_mat_C_ele_count	= 0;//用于存入元素时做指针

	//开始进行乘法相乘
	for (k = 0; k < s_mat_C_col; k++) col_buffer[k] = 0;//清理列缓存
	for (i = 1; i <= s_mat_A_ele_count; i++)
	{
		for (j = 1; j <= s_mat_B_ele_count; j++)
		{
			if (s_mat_A[i].col == s_mat_B[j].row)//有效用 压入缓存区
				col_buffer[s_mat_B[j].col] += s_mat_A[i].element * s_mat_B[j].element;
		}

		//如果要换行或者是最后一行
		if (((i != s_mat_A_ele_count) && (s_mat_A[i].row != s_mat_A[i + 1].row)) || i == s_mat_A_ele_count)
		{
			//扫描缓存组
			for (k = 0; k < s_mat_C_col; k++)
			{
				//如果该点不是0 压入三元组 清零缓存
				if (col_buffer[k] != 0)
				{
					s_mat_C_ele_count++;
					s_mat_C[s_mat_C_ele_count].row = s_mat_A[i].row;
					s_mat_C[s_mat_C_ele_count].col = k;
					s_mat_C[s_mat_C_ele_count].element = col_buffer[k];
					col_buffer[k] = 0;
				}
			}
		}
	}

	//释放缓存 返回结果
	free(col_buffer);
	return s_mat_C;
}

//稀疏矩阵打印 按矩形打印 需要确定三元组按Z排列有序
void SparseMatrixRectPrint(sparse_mat* s_mat)
{
	//获取行列信息
	int i, j;
	int row = s_mat[0].row;
	int col = s_mat[0].col;

	//打印元素递增 前提是三元组按照行列顺序排好,就只需要递增下标
	int ele_count = 1;

	//按矩阵扫描打印
	for (i = 0; i < row; i++)
	{
		for (j = 0; j < col; j++)
		{
			if (i == s_mat[ele_count].row && j == s_mat[ele_count].col)
			{
				printf("%d\t", s_mat[ele_count].element);
				ele_count++;
			}
			else
			{
				printf("0\t");
			}
		}//for
		printf("\n");
	}//for
	
	//跳空换行 返回
	printf("\n");
	return;
}

//稀疏矩阵打印 按三元组结构打印
void SparseMatrixTriPrint(sparse_mat* s_mat)
{
	int i, j;
	int ele_count = s_mat[0].element;

	//按顺序打印
	for (i = 1; i <= ele_count; i++)
	{
		printf("%d\t%d\t%d\n", s_mat[i].row, s_mat[i].col, s_mat[i].element);
	}

	//跳空换行 返回
	printf("\n");
	return;
}


你可能感兴趣的:(数据结构)