下面展示一些 `tirtuple.h文件。
// trituple.h
#pragma once
#include
using namespace std;
template<class T>
struct Trituple
{
int row, col;
T value;
Trituple<T>& operator=(Trituple<T>& x)
{
row = x.row;
col = x.col;
value = x.value;
return *this;
}
};
下面展示一些 "sparematrix.h
。
// An highlighted block
#pragma once
#include"trituple.h"
using std::ostream;
template<class T>
class sparsematrix
{
public:
sparsematrix<T>& transpose();
sparsematrix<T>& add(sparsematrix<T>& );
sparsematrix(int,int);
sparsematrix(sparsematrix<T>& );
~sparsematrix();
T& getnum(int, int)const ;
void insert(Trituple<T>& tmp);
sparsematrix<T>& operator=(sparsematrix<T>& SM);
sparsematrix<T>& multiply(sparsematrix<T>& b);
friend ostream& operator<<(ostream& ostr, sparsematrix<T>& SM);
friend istream& operator >> (istream& istr,sparsematrix<T>& SM);
private:
int Rows; //行数
int Cols; //列数
int Terms; //非0元素的个数
Trituple<T> *smArry; //存放非零元素的三元数组
int maxTerms; //能容纳的最大元素个数
};
下面展示一些 "sparematrix.cpp
。
// An highlighted block
#include"sparematrix.h"
template<class T>
inline sparsematrix<T>::sparsematrix(int maxcol,int maxrow) :maxTerms(maxcol*maxrow),Rows(maxrow),Cols(maxcol),Terms(0)
{
if (maxcol<1||maxrow<1)
{
cerr << "init ERROR" << endl;
return;
}
}
template<class T>
inline sparsematrix<T>::sparsematrix(sparsematrix<T>& SM)
{
if (SM.Terms==0)
{
smArry = nullptr;
Terms = 0;
return;
}
if (SM == *this)
{
return;
}
Rows = SM.Rows;
Cols = SM.Cols;
Terms = 0;
maxTerms = SM.maxTerms;
for (int i = 0; i < Terms; i++)
{
insert(SM.smArry[i]);
}
}
template<class T>
sparsematrix<T>& sparsematrix<T>::operator=(sparsematrix<T>& SM)
{
if (SM.Terms == 0)
{
smArry = nullptr;
Terms = 0;
return SM;
}
if (SM == *this)
{
return *this;
}
Rows = SM.Rows;
Cols = SM.Cols;
Terms = SM.Terms;
maxTerms = SM.maxTerms;
smArry = new Trituple<T>[Terms];
for (int i = 0; i < Terms; i++)
{
smArry[i] = SM.smArry[i];
}
}
template<class T>
inline sparsematrix<T>::~sparsematrix()
{
if (maxTerms!=0)
{
delete[] smArry;
Terms = 0;
}
}
template<class T>
inline T & sparsematrix<T>::getnum(int row1, int col1)const
{
for (int i = 0; i < Terms; i++)
{
if (smArry[i].col == col1&&smArry[i].row == row1)
{
return smArry[i].value;
}
}
return NULL;
}
template<class T>
void sparsematrix<T>::insert(Trituple<T>& tmp)
{
for (int i = 0; i < Terms; i++)
{
if (smArry[i].row = tmp.row&&smArry[i].col == tmp.col)
{
crr << "already exist" << endl;
return;
}
}
Terms++;
Trituple<T>* smarry = new Trituple<T>[Terms];
if(Terms>1&&tmp.col<=Cols&&tmp.row<=Rows)//方便后续加法操作,按照索引排序,如果之前就存在元素,那么:
{
for (int i = 0; i < Terms-2; i++)//遍历到倒数第二个就行,不然指针越界
{
//如果比最小的还要小
if (tmp.row*Cols*tmp.col < smArry[0].row*Cols + smArry[0].col)
{
smarry = tmp;
for (int j = 0; j < Terms - 1; j++)
{
smarry[j + 1] = smArry[j];
}
break;
}
//如果在中间
else if (smArry[i].row*Cols + smArry[i].col <= tmp.row*Cols*tmp.col
&&tmp.row*Cols*tmp.col <= smArry[i+1].row*Cols + smArry[i+1].col)
{
for (int j = 0; j < i; j++)
{
smarry[j] = smArry[j];
}
smarry[i] = tmp;
for (int j = i; j < Terms-1; j++)
{
smarry[j+1] = smArry[j];
}
break;
}
//如果在尾巴上
else if (smArry[Terms - 2].row*Cols + smArry[Terms - 2].col < tmp.row*Cols*tmp.col)
{
smarry[Terms - 1] = tmp;
break;
}
}
delete[] smArry;
}
smArry = smarry;//头指针赋值
}
template<class T>
ostream& operator<<(ostream& ostr, sparsematrix<T>& SM)
{
ostr << "rows=" << SM.Rows << endl;
ostr << "Cols=" << SM.Cols << endl;
ostr << "terms=" << SM.Terms << endl;
for (int i = 0; i < SM.Terms; i++)
{
ostr << i + 1 << ":<" << SM.smArry[i].row << "," << SM.smArry[i].col << ">="
<< SM.smArry[i].value << endl;
}
return ostr;
}
template<class T>
istream & operator>>(istream & istr, sparsematrix<T>& SM)
{
istr >> SM.Rows >> SM.Cols >> SM.Terms;
if (SM.Terms > SM.maxTerms)
{
cerr << "index overflowed" << endl;
exit(1);
}
for (int i=0; i < SM.Terms; i++)
{
cin >> SM.smArry[i].row >> SM.smArry[i].col >> SM.smArry[i].value;
}
return istr;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::add(sparsematrix<T>& b)
{
sparsematrix<T> result(cols,rows);
if (Rows != b.Rows || Cols != b.Cols)
{
cerr << "incompatable!" << endl;;
return result;
}
if (b.Terms == 0)
{
return *this;
}
if (Terms == 0)
{
return b;
}
result.Rows = Rows;
result.Cols = Cols;
result.Terms = 0;
result.maxTerms = Rows*Cols;
int i = 0, j = 0, index_a, index_b;
while (i < Terms&&j<b.Terms) //当两个矩形的非零元素都没有到极限的时候
{
//前提是smarry是按照左上到右下,从左到右排列的,也就是每个非零元素在矩阵中的索引递增
index_a = smArry[i].row*Cols + smArry[i].col;//计算每一个非零元素在矩阵中的位置
index_b = b.smArry[j].row*b.Cols + b.smArry[i].col;
if (index_a < index_b)//如果本矩阵的元素排在前面,就把本矩阵的元素插入目标矩阵
{
result.insert(smArry[i]);
i++;//下次计算下一个位置
}
else if (index_a > index_b) //如果被加矩阵的元素排在前面
{
result.insert(smArry[j]);
j++;//计算被加矩阵下一个元素的位置
}
else //如果两元素在同一个索引
{
if (smArry[i].value + b.smArry[j].value)
{
Trituple<T> tmp;
tmp = smArry[j];
tmp.value = smArray[i].value + b.smArray[j].value;
result.insert(tmp);
i++;
j++;
}
}
}
if (Terms > b.Terms) //当被加矩阵算完了,还剩下原来矩阵的数,因为原来矩阵中数据不重复,因此直接插入就行
{
for (; i < Terms; i++)
{
result.insert(smArry[i]);
i++;
}
}
else
{
for (; j<b.Terms; j++)
{
result.insert(b.smArry[j]);
j++;
}
}
return result;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::transpose()
{
int *colSize = new int[Cols+1];//每列非零元素啊的个数
int *rowstart = new int[Cols+1];//b中每一行的首个非零元素在smArry中的索引,因为是转置,因此列为行
sparsematrix<T> b(maxTerms);//转置矩阵对应的三元组
b.Rows = Cols; //b的性质
b.Cols = Rows;
b.Terms = Terms;
b.maxTerms = maxTerms;
if (Terms > 0) //如果存在非零元素
{
int i;
for (i = 0; i < Cols; i++)
{
colSize[i] = 0;//初始化
}
for (i = 0; i < Terms; i++)
{
colSize[smArry[i].col]++;//根据第i个非零数据的列属性,计算出每列
}//new必须加1,否则这里会越界
rowstart[0] = 0;
for (i = 1; i < Cols+1; i++)//
{
rowstart[i] = rowstart[i - 1] + colSize[i - 1];//后一行的第一个首个非零元素的索引是前行首个非零元素索引+非零元素个数
}
//以上的所有步骤都是为了找到a中每一个smarry在转置后的矩阵中的脚标
for (i = 0; i < Terms; i++)//遍历三元组
{
int j = rowstart[smArry[i].col]++;//对应第i给元素所在列的首个不为零的数据的索引,按照行从左到右顺序存入
//获取值之后自加1,因为一列可能有多个元素,下次再到这一列的时候,就是这一列的下一个元素,
//同样转置后的矩阵的非零元素的下一个元素的脚标也要加1
b.smArry[j] = smArry[i];
}
}
delete[] rowstart;
delete[] colSize;
return b;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::multiply(sparsematrix<T>& b)
{
sparsematrix<T> result = {
};
if (Cols != b.Rows)
{
cerr << "unable to multiply" << endl;
return result;
}
int *colSize = new int[b.Cols+1];//b矩阵每列非零元素的个数
int *colStart = new int[b.Cols + 1];//b矩阵每列第一个非零元素在b中的下标
int i;
for (i = 0; i < b.Cols; i++)
{
colSize[i] = 0;
}
for (i = 0; i <Terms; i++)
{
colSize[b.smArry[i].col]++;
}
colStart[0] = 0;
for (i = 1; i < b.Cols+1; i++)
{
colStart[i] = colStart[i - 1] + colSize[i - 1];
}
int index = 0;//非0元素的脚标
int temp[b.Cols] = {
};//暂存每个元素的运算结果
//获取第一个元素所在行
while (index < Terms)
//只要不结束,结束上一行的所有计算后,进行下一行的计算
{
int row_a = smArry[index].row;
while (index < Terms&&smArry[index].row == row_a) //数组是从左到右,从上到下排列,因此可以这么写,获得本身在rowa_a的元素
{
//第一个判断条件是怕index越界,产生乱码
for (int j = 1; j < b.Cols + 1; j++) //针对每列元素
{
for (i = colStart[j]; i < colStart[j + 1]; i++) //针对b在第col_a列的非零元素,进行乘积操作
{
int row_b = b.smArray[i].row;
if (smArry[index].col == row_b) //如果b矩阵中非零元素的行等于本身非零元素的列,也就是说两个位置元素都为非零,才能相乘
{
temp[j-1] += smArry[index].value*b.smArry[i].value;
}
}
}
index++;
}
for (i = 0; i < b.Cols; i++)
{
if (temp[i] != 0)
{
Trituple<T>& Result;
Result.row = smArray[index-1].row;//一定要减一,不然会row会全部+1
Result.col = i;
Result.value = temp[i];
result.insert(Result);
}
}
}
result.Rows = Rows;
result.Cols = b.Cols;
delete[] colSize;
delete[] colStart;
return result;
}