矩阵模板类

测了下int型没问题,之后再完善,还有个卷积运算没看懂懒得做,其他应该都做好了,做的比较快,比较粗糙,也没人测,但大致功能应该都行。

#pragma once
#include 
#include 
#include 
#include "..\utils\FileOperation.h"
#include "..\utils\Utils.h"
#include 
using namespace std;


#define  MinSize 1                             //控制行列最小值
#define  MaxSize 500                           //控制行列最大值


template 
class CMat;

template
class CMat
{

	template 
	friend CMat operator*(CMat&,CMat&); //矩阵乘法

public:

//---------------构造
	CMat() = delete;                         //删除默认构造函数
	CMat(int r, int c, T value);             //指定行数r、列数c构造初始值均为value的矩阵
	CMat(int r, int c, T* arr);              //给定一维数组arr,构建r行、c列矩阵
	CMat(int r, int c);                      //创建构建r行、c列矩阵,矩阵值从键盘中输入

//---------------析构
	virtual ~CMat();                            //析构函数

//---------------拷贝构造(deep copy)
	CMat(CMat&);


//---------------方法
	void printMat() const;                      //按矩阵行列格式打印矩阵
	void saveMat(const char *filename);         //保存矩阵数据到文件
	void loadMat(const char *filename);         //从文件中读取数据到矩阵
	int  getRows() const;                       //读取矩阵的行数
	int  getCols() const;                       //读取矩阵的列数
	bool isSquare() const;                      //判断是否为方阵
	void reSize(int p, int q);                  //将矩阵形状变换为p行、q列

	CMat fetchRow(int row);                  //取给定行所有元素 row col从0开始取
	CMat fetchCol(int col);                  //取给定列所有元素 row col从0开始取
	void swapRows(int a, int b);                //交换矩阵第a行和第b行元素的值 row col从0开始取
	void swapCols(int a, int b);                //交换矩阵第a列和第b列元素的值row col从0开始取
	static CMat createEyeMat(int n);         //创建一个行数列数均为n的单位矩阵


	void detMat();                             //求矩阵行列式
	CMatconvMat(CMata, CMatb);        //矩阵卷积运算

//---------------重载						    
	T operator()(int row, int col);            //按给定的行和列取元素 row col从0开始取
	CMat&operator+(CMat&);               //矩阵加法
	CMat&operator-(CMat&);               //矩阵减法
	CMat&operator=(CMat&);               //矩阵深度复制
	bool operator==(CMat&);                 //判断矩阵相等
	bool operator!=(CMat&);                 //判断矩阵不相等

private:
	int row;                                   //行
	int col;                                   //列
	T*  data;                                  //指向数据的一维指针
};


template
void CMat::printMat() const
{
		cout<<"当前实例矩阵为:\n";
	cout << "-------------------\n";
	for(int i = 0; i < row; i++)
	{
		for (int j = 0; j < col; j++)
		{
			cout<<*(data + i*col + j);
			if (j < col - 1)
			{
				cout << "    ";
			}
			else if(j==col-1)
			{
				cout <<"\n";
			}
		}
	}
	cout << "-------------------\n";
}

template
void CMat::saveMat(const char *filename)
{
	//文件格式: 1   21  21
	//          12  23  32     
	//          12  23  11 
	//          23  12  33

	string write_str = "";

	for (int i = 0; i < row; i++)
	{
		for (int j = 0; j < col; j++)
		{
			//string cur_str=atio

			if (write_str == "")
			{
				write_str += to_string(*(data + i*col + j));
			}
			else
			{

				if (j == col - 1)
				{	//每行的最后一个加换行符
					write_str += " ";
					write_str += to_string(*(data + i*col + j));
					write_str += "\n";
				}
				else if (j == 0)
				{
					//每行的第一个什么都不加
					write_str += to_string(*(data + i*col + j));
				}
				else
				{
					//中间段 加空格即可
					write_str += " ";
					write_str += to_string(*(data + i*col + j));
				}
			}
		}
	}

	//删除结尾的\n
	write_str.pop_back();


	//暂时不考虑中文编码问题-有需求再加
	//if (!FileOperation::IsFileExist(filename))/不考虑不存在
	//{
	//	return;
	//}

	//删除旧的
	bool delete_ret = FileOperation::GetFileOperation()->DeleteFile(filename);

	//如果没有 创建一个 然后直接写入
	bool create_ret = FileOperation::GetFileOperation()->CreateFile(write_str, filename);
	//失败情况不考虑
}

template
void CMat::loadMat(const char *filename)
{
		//cout<<"开始读取"<< endl;
		string read_str;
		if (!FileOperation::GetFileOperation()->IsFileExist(filename)) //不考虑不存在
		{
			cout << "读取的文件路径不存在:" << filename << endl;
			return;
		}
		else
		{
			bool read_ret = FileOperation::GetFileOperation()->ReadStringFromFile(filename, read_str);
		}

		//将字符串分割获取每一行
		vectorall_row_vec;
		SplitString(read_str, all_row_vec, "\n");

		if (all_row_vec.size() <= 0)
		{
			cout << "读取文件失败" << filename << endl;
		}

		this->row = static_cast(all_row_vec.size());
		cout << "文件矩阵行数" << to_string(row) << endl;


		//先取第一行列数 来构造data
		vectorfirst_row_vec;
		SplitString(all_row_vec[0], first_row_vec, " ");
		if (first_row_vec.size() <= 0)
		{
			cout << "读取文件失败" << filename << endl;
			return;
		}
		//去第一行分割的列数 赋值给当前类
		this->col = static_cast(first_row_vec.size());
		cout << "文件矩阵列数" << to_string(col) << endl;

		//构造data
		data = new T[col*row];


		//以下语句赋值有崩溃风险 必须保证文件存取一定要正确
		for (int i = 0; i < all_row_vec.size(); i++)
		{
			vectorone_row_vec;
			SplitString(all_row_vec[i], one_row_vec, " ");

			for (int j = 0; j < col; j++)
			{
				istringstream iss(one_row_vec[j]);
				T num;
				iss >> num;
				*(data + i*col + j) = num;
			}
		}
		//cout << "结束读取,读取成功" << endl;
}

template
int CMat::getRows() const
{
	cout << "当前矩阵行数 " << to_string(row) << endl;
	return row;
}

template
int CMat::getCols() const
{
	cout << "当前矩阵列数 " << to_string(col) << endl;
	return col;
}

template
bool CMat::isSquare() const
{
	if (col == row)
	{
		cout << "当前矩阵是方阵" << endl;
		return true;
	}
	else
	{
		cout << "当前矩阵不是方阵" << endl;
		return false;
	}
}

template
void CMat::reSize(int p, int q)
{
	//p*q必须和当前矩阵row*col一样
	if (p*q != row*col)
	{
		//不符合条件取消
		return;
	}

	//符合条件,进行变换
	//直接将row和col用p q赋值即可
	this->row = p;
	this->col = q;
}

template
T CMat::operator()(int row, int col)
{
	if (row < 0 || row >= this->row || col < 0 || col >= this->col)
	{
		cout << "重载()输入超出范围" << endl;
	}
	T cur_value = *(data + row*(this->col) + col);
	cout << "该矩阵第" << to_string(row) << "行,第" << to_string(col) << "列值为:" << to_string(cur_value) << endl;
	return cur_value;
}

template
CMat& CMat::operator+(CMat&mat)
{
		//判断是不是同型的矩阵
		int target_row = mat.getRows();
		int target_col = mat.getCols();

		if (target_row != this->row || target_row != this->col)
		{
			cout << "不同型矩阵不能相加" << endl;
		}

		//数据部分不判断了

		T*target_data = mat.data;

		//相加
		for (int i = 0; i < row; i++)
		{
			for (int j = 0; j < col; j++)
			{
				T target_value = *(target_data + i*row + j);
				T cur_value = *(this->data + i*row + j);
				*(this->data + i*row + j) = cur_value + target_value;
			}
		}
		return *this;
}

template
CMat& CMat::operator-(CMat&mat)
{	//判断是不是同型的矩阵
	int target_row = mat.getRows();
	int target_col = mat.getCols();

	if (target_row != this->row || target_row != this->col)
	{
		cout <<"不同型矩阵不能相减"<data + i*row + j);
			*(this->data + i*row + j) = cur_value - target_value;
		}
	}
	return *this;
}

template
CMat CMat::convMat(CMata, CMatb)
{

}

template
void CMat::detMat()
{

}

template
CMat CMat::createEyeMat(int n)
{
	T*data = new T[n*n];
	for (int i = 0; i < n*n; i++)
	{
		*(data + i) = 1;
	}

	static CMatresult_mat(n, n, data);

	return result_mat;
}

template
void CMat::swapCols(int a, int b)
{
	//判断
	if(a <0||a>= col||b<0||b>=col)
	{
		//参数不符合条件
		return;
	}

	//new一个暂存的数据 存下b列的数据 有row项
	T*temp_data = new T[row];
	for(int i= 0;i
void CMat::swapRows(int a, int b)
{
	//判断
	if(a<0||a>=row||b<0||b>= row)
	{
		//参数不符合条件
		return;
	}

	//new一个暂存的数据 存下b行的数据 有col项
	T*temp_data = new T[col];

	for(int j=0; j
CMat CMat::fetchCol(int col)
{
	//给定行,则矩阵为n行 1列  参数col从0开始取
	int mat_row = this->row;
	int mat_col = 1;

	T*data = new T[mat_row*mat_col];

	for (int i= 0; i< mat_row; i++)
	{
		*(data + i) = *(this->data + row*i + col);
	}
	CMatresult_mat(mat_row, mat_col, data);
	return result_mat;
}

template
CMat CMat::fetchRow(int row)
{
	//给定行,则矩阵为1行 n列  参数row从0开始取
	int mat_row =1;
	int mat_col= this->col;

	T*data = new T[mat_row*mat_col];


	for (int j = 0; j < mat_col; j++)
	{
		*(data + j) = *(this->data + row*col + j);
	}
	CMatresult_mat(mat_row, mat_col,data);
	return result_mat;
}

template
bool CMat::operator!=(CMat&)
{
	//判断是不是同一个对象
	if (this == &mat)
	{
		return false;
	}

	//判断数据是否相等
	T *target_data = mat.data;
	bool is_data_equal = true;
	for (int i = 0; i < row; i++)
	{
		for (int j = 0; j < col; j++)
		{
			T target_value = *(target_data + i*row + j);
			T cur_value = *(this->data + i*row + j);
			if (target_value != cur_value)
			{
				is_data_equal = false;
			}
		}
	}
	if (mat.row==this->row&&mat.col==this->col&&is_data_equal)
	{
		return false;
	}
	return true;
}

template
bool CMat::operator==(CMat&mat)
{
	//判断是不是同一个对象
	if (this == &mat)
	{
		return true;
	}

	//判断数据是否相等
	T *target_data = mat.data;
	bool is_data_equal=true;
	for (int i = 0; i < row; i++)
	{
		for (int j = 0; j < col; j++)
		{
			T target_value = *(target_data + i*row + j);
			T cur_value = *(this->data + i*row + j);
			if (target_value != cur_value)
			{
				is_data_equal = false;
			}
		}
	}


	if (mat.row != this->row || mat.col != this->col || !is_data_equal)
	{
		return false;
	}
	return true;
}

template
CMat& CMat::operator=(CMat&mat)
{
	//判断是不是同一个对象
	if (this== &mat)
	{		
		return *this;
	}

	if (this->data != NULL)
	{
		delete [] this->data;
	}
	this->row = mat.row;
	this->col = mat.col;
	this->data = new T(row*col);
	//深拷贝
	memcpy(this->data, mat.data,row*col * sizeof(T)); //后面sizeof(T)不要忘记
	return *this;
}


template
CMat operator*(CMat&mat1, CMat&mat2)
{
	U*temp_data = new U[1];
	*temp_data = 1;
	CMatresult_mat(1, 1, temp_data);

	//第一个矩阵的列数col和第二个矩阵的行数row相同才可相乘
	//判断是不是同一个对象
	if (mat1 == mat2)
	{
		return  result_mat;
	}
	if (mat1.col != mat2.row || mat1.row < 0 || mat1.col < 0 || mat2.row < 0 || mat2.col < 0)
	{
		//不符合要求
		return result_mat;
	}

	//通过判断 开始矩阵相乘
	int result_row = mat1.row;
	int result_col = mat2.col;



	//清除之前的内存空间,重新分配空间 !!这里不需要清除空间 析构会调用 会删除一次报错
	//delete[]temp_data;
	//temp_data = NULL;
	
	U *new_data = new U[result_row*result_col];
	
   
	//a第i行和b第j列相乘,相加
	int num = mat1.col;//取一个基准数值 应等于a列 或者b行
	for (int i = 0; i < result_row; i++)
	{
		for (int j = 0; j < result_col; j++)
		{
			//从mat1的第i行取num列数的值  从mat2的第j列取num行数的值 相乘再相加  
			U result_value = 0;
			for (int k = 0; k < num; k++)
			{
				U value = (*(mat1.data + i*mat1.col + k))*(*(mat2.data + j + k*mat2.col));
				result_value += value;
			}
			*(new_data + i*result_col + j) = result_value;
		}
	}
	CMatnew_mat(result_row, result_col, new_data);
	return new_mat;
}

template
CMat::CMat(CMat&mat)
{
	if (mat.rowMaxSize || mat.colMaxSize|| mat.data==NULL)
	{
		//不符合要求
		assert(0);
		return;
	}

	this->row = mat.row;
	this->col = mat.col;

	//申请空间
	data = new T[row*col];
	//深拷贝
	memcpy(this->data, mat.data, row*col*sizeof(T));
}

template
CMat::CMat(int r, int c, T value)
{
	if (rMaxSize || cMaxSize)
	{
		//行列不符合要求
		assert(0);
		return;
	}
	//构造
	data = new T[r*c];

	//赋值
	for (int cur_r = 0; cur_r < r; r++)
	{
		for (int cur_c = 0; cur_c < c; c++)
		{
			*(data + cur_r*cur_c - 1) = value;
		}
	}
}

template
CMat::CMat(int r, int c, T* arr)
{
	if (rMaxSize || cMaxSize || arr == NULL)
	{
		assert(0);
		return;
	}
	//这里要再加个判断 判断 arr是否符合r*c的size

	this->row = r;
	this->col = c;
	this->data = arr;

}

template
CMat::CMat(int r, int c)
{
	if (rMaxSize || cMaxSize)
	{
		//行列不符合要求或者数据指向为空
		cout << "Struct failed---OutOfRange" << endl;
		return;
	}

	//开始
	this->data = new T[r*c];
	this->row = r;
	this->col = c;

	int size = r*c;
	cout << "需要输入:" << to_string(size) << "个数字" << "\n";

	for (int i = 0; i < size; i++)
	{
		cout << "编号:" << to_string(i) << " " << "数值:";
		T input;
		cin >>input;
		//赋值
		*(data + i) = input;
	}
}

template 
CMat::~CMat()
{
	if (this->data != NULL)
	{
		delete[]data;
	}
}


 

你可能感兴趣的:(c++)