刚学习C++,之前把 Primer 看了一遍,现在也在刷 leetcode,感觉学习编程语言光看书页刷题也是不够的,最好是能做一些实际的项目,这样要用到哪些东西时不明白再看书,就会印象深刻些,否则光看书只是走马观花,看了也就忘了。
打算自己用C++实现一个简单的矩阵 Mat 类,包括一些简单的操作就可以了,但实现起来发现也并没有那么简单,还是遇到很多问题。这个过程也还是学到了不少东西。
先从一个固定数据类型的矩阵的开始吧,以 int 为例子,那么我们需要什么呢? 首先,我们需要记录矩阵的行和列,也需要一块内存存储矩阵的数据;然后需要一些简单的构造函数和其他的成员函数,实现一些基本的操作。
废话不多说,直接上程序:
头文件 Mat.h:
#ifndef _MAT_H_
#define _MAT_H_
#include
#include
#include
#include
#include
//implement Mat class in c++
class Mat{
friend std::ostream& operator<<(std::ostream &os, const Mat &m);
friend std::istream& operator>>(std::istream &is, Mat &m);
public:
typedef int value_type;
typedef std::vector<int>::size_type size_type;
//construct
Mat();
Mat(size_t i, size_t j);
//copy constructor
Mat(const Mat& m);
//copy assignment
Mat& operator=(const Mat& m);
// +=
Mat& operator+=(const Mat& m);
// -=
Mat& operator-=(const Mat& m);
//destructor
~Mat();
//access element value
int& operator()(size_t i, size_t j);
const int& operator()(size_t i, size_t j) const;
//get row and col number
const size_t rows() const{ return row; }
const size_t cols() const{ return col; }
//resize
void resize(size_t nr, size_t nc);
private:
size_t row;
size_t col;
std::vector<std::vector<int>> data;
};
#endif
刚开始的时候,我是用 int* 来存储的数据,用int* 不方便的一个地方就是需要自己来管理内存,在实现一些 复制赋值和移动操作时都需要很小心处理,来管理内存,稍有不慎就会出问题。所以我后来改用C++标准容器 vector 了。这样就不用自己来管理内存了。
另外,像 + 、- 这样两边的数可以互换位置的操作符,重载时最好用非成员函数来重载,而且像 +=、 -= 这样的复合操作符,必须要用成员函数来重载。另外,可以用 +=、-= 来实现非成员函数的 + 、- 操作。最后要强调的是,非成员的运算符重载函数(+,-,,/)之类,返回类型应该是 const 类型,避免类似 a b = c 这样的语句通过编译。
像输出(<<)、输入(>>)运算符,需要用非成员函数来重载,一般还要声明为友元函数。不过可以通过成员函数完成输入输出操作,非成员函数来重载、调用成员函数输入输出来避免声明友元函数,后面一个例子就是这样实现的。
最后要强调的一点就是针对const对象,需要有对应的const版本的函数,具体说来,就是对一些不改变类类型的变量成员函数,尽量声明为const类型,这样const对象也能调用这些函数。在返回矩阵数据时,需要针对const 和 非const的对象,实现两个函数。
//access element value
int& operator()(size_t i, size_t j);
const int& operator()(size_t i, size_t j) const;
例如,重载 ( ) 运算符时,需要实现两个版本的函数,const 对象调用第二个版本函数,非 const 对象调用第一个版本的函数。
源文件Mat.cpp
#include
#include
#include
#include "matint.h"
using std::cout;
using std::endl;
using std::istream;
using std::ostream;
using std::stringstream;
ostream& operator<<(ostream &os, const Mat&m){
for (size_t i = 0; i < m.row; i++){
for (size_t j = 0; j < m.col; j++){
os << m.data[i][j] << " ";
}
os << std::endl;
}
os << std::endl;
return os;
}
istream& operator>>(istream &is, Mat&m){
for (size_t i = 0; i < m.row; i++){
for (size_t j = 0; j < m.col; j++){
is >> m.data[i][j];
}
}
return is;
}
// +
const Mat operator+(const Mat& m1, const Mat& m2){
Mat t = m1;
t += m2;
return t;
}
// -
const Mat operator-(const Mat& m1, const Mat& m2){
Mat t = m1;
t -= m2;
return t;
}
//constructor
Mat::Mat(){
cout << "default constructor" << endl;
row = 0;
col = 0;
data.clear();
}
Mat::Mat(size_t i, size_t j){
row = i; col = j;
std::vector<std::vector<int>> vdata(row, std::vector<int>(col, 0));
data = std::move(vdata);
}
//copy constructor
Mat::Mat(const Mat& m){
cout << "copy constructor" << endl;
row = m.row; col = m.col;
data = m.data;
}
//copy assignment
Mat& Mat::operator=(const Mat& m){
cout << "copy assignment" << endl;
row = m.row; col = m.col;
data = m.data;
return *this;
}
//destructor
Mat::~Mat(){
data.clear();
}
//access element value
int& Mat::operator()(size_t i, size_t j){
assert(i >= 0 && j >= 0 && i < row && j < col);
return data[i][j];
}
const int& Mat::operator()(size_t i, size_t j) const{
assert(i >= 0 && j >= 0 && i < row && j < col);
return data[i][j];
}
//resize
void Mat::resize(size_t nr, size_t nc){
data.resize(nr);
for (size_t i = 0; i < nr; i++){
data[i].resize(nc);
}
col = nc; row = nr;
}
// +=
Mat& Mat::operator+=(const Mat& m){
if (row == m.row && col == m.col){
for (size_t i = 0; i < row; i++)
{
for (size_t j = 0; j < col; j++)
data[i][j] += m.data[i][j];
}
}
else{
std::cerr << "mat must be the same size." << std::endl;
}
return *this;
}
// -=
Mat& Mat::operator-=(const Mat& m){
if (row == m.row && col == m.col){
for (size_t i = 0; i < row; i++)
{
for (size_t j = 0; j < col; j++)
data[i][j] -= m.data[i][j];
}
}
else{
std::cerr << "mat must be the same size." << std::endl;
}
return *this;
}
#if 1
int main(){
Mat mat1(3, 4);
Mat mat2(3, 4);
for (size_t i = 0; i < mat1.rows(); i++){
for (size_t j = 0; j < mat1.cols(); j++){
mat1(i, j) = 1;
mat2(i, j) = 3;
}
}
std::cout << "mat1: " << std::endl << mat1;
std::cout << "mat2: " << std::endl << mat2;
Mat mat3 = (mat2 + mat1);
std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3;
Mat mat4 = (mat3 + mat2 - mat1);
std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4;
stringstream ss;
ss << mat1;
ss >> mat4;
std::cout << "mat4:" << std::endl << mat4;
const Mat mat6(mat4);
std::cout << "const mat6:" << std::endl << mat6;
cout << mat6(0, 0) << " " << mat6.rows() << " "<" ";
Mat mat7 = mat2;
std::cout << "mat7: " << std::endl << mat7;
mat2(0, 0) = 11;
std::cout << "mat7: " << std::endl << mat7;
mat7.resize(2, 3);
std::cout << "mat7.resize(2, 3): " << std::endl << mat7;
mat7.resize(5, 6);
std::cout << "mat7.resize(5, 6): " << std::endl << mat7;
return 1;
}
#endif
以上是针对 int 数据类型实现的矩阵类,那么如果我的矩阵数据类型是double怎么办?总不能再重新实现一遍吧。为了实现多种数据类型,有两种思路。第一种就是像 OpenCV 的矩阵类一样,单独用一个变量来定义使用什么样的数据类型,这样做的好处是不需要使用模板;第二种思路就是使用C++的Templeate来实现。我实现的是第二种思路。
头文件 Mat.h
#ifndef _MAT_H_
#define _MAT_H_
#include
#include
#include
#include
#include
#include
//implement Mat class in c++
template<typename T>
class Mat{
public:
typedef T value_type;
//construct
Mat();
Mat(size_t i, size_t j);
copy constructor
Mat(const Mat& m);
copy assignment
Mat& operator=(const Mat&m);
// +=
Mat& operator+=(const Mat& m);
// -=
Mat& operator-=(const Mat& m);
//move constructor
Mat( Mat&& m);
//move assignment
Mat& operator=( Mat&& m);
//destructor
~Mat();
//access element value
T& operator()(size_t i, size_t j);
const T& operator()(size_t i, size_t j) const;
//get row and col number
const size_t rows() const{ return vdata.size(); }
const size_t cols() const{
if (vdata.empty()) return 0;
else return vdata[0].size();
}
//resize
void resize(size_t nr, size_t nc);
//print mat
void CoutMat(std::ostream& os) const;
void CinMat(std::istream& is);
private:
std::vector<std::vector > vdata;
};
#endif
需要说明的是,当把数据存储从指针改为vector之后,也就不需要单独记录矩阵的行和列了,这些信息都蕴含在vector之中了。要注意的是,在使用取元素下标 [ ] 的时候,需要确保vector不为空。否则就会出现 vector out of range 的错误。
Mat.c
#include "mat.h"
using std::cout;
using std::endl;
using std::istream;
using std::ostream;
using std::stringstream;
template<typename T>
ostream& operator<<(ostream &os, const Mat &m){
m.CoutMat(os);
return os;
}
template<typename T>
istream& operator>>(istream &is, Mat&m){
m.CinMat(is);
return is;
}
// +
template<typename T>
const Mat operator+(const Mat& m1, const Mat& m2){
Mat t(m1);
t += m2;
return t;
}
// -
template<typename T>
const Mat operator-(const Mat& m1, const Mat& m2){
Mat t(m1);
t -= m2;
return t;
}
// print mat
template<typename T>
void Mat::CoutMat(std::ostream& os) const
{
if (vdata.empty()) return;
for (size_t i = 0; i < vdata.size(); i++){
for (size_t j = 0; j < vdata[0].size(); j++){
os << vdata[i][j] << " ";
}
os << std::endl;
}
os << std::endl;
}
template<typename T>
void Mat::CinMat(std::istream& is)
{
if (vdata.empty()) return;
for (size_t i = 0; i < vdata.size(); i++){
for (size_t j = 0; j < vdata[0].size(); j++){
is >> vdata[i][j];
}
}
}
//construct
template<typename T>
Mat::Mat(){
cout << "default constructor" << endl;
vdata.clear();
}
template<typename T>
Mat::Mat(size_t i, size_t j){
std::vector<std::vector > tdata(i, std::vector (j, 0));
vdata = std::move(tdata);
}
//copy constructor
template<typename T>
Mat::Mat(const Mat& m){
cout << "copy constructor" << endl;
vdata.assign(m.vdata.cbegin(), m.vdata.cend());
}
//copy assignment
template<typename T>
Mat& Mat::operator=(const Mat& m){
cout << "copy assignment" << endl;
if (this != &m){
vdata.assign(m.vdata.cbegin(), m.vdata.cend());
}
return *this;
}
//move constructor
template<typename T>
Mat::Mat( Mat&& m ){
cout << "move constructor" << endl;
vdata = std::move(m.vdata);
}
//move assignment
template<typename T>
Mat& Mat::operator=(Mat&& m){
cout << "move assignment" << endl;
if (this != &m){
vdata.clear();
vdata = std::move(m.vdata);
}
return *this;
}
//destructor
template<typename T>
Mat::~Mat(){
vdata.clear();
}
//access element value
template<typename T>
inline T& Mat::operator()(size_t i, size_t j){
assert(!vdata.empty());
assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size());
return vdata[i][j];
}
template<typename T>
inline const T& Mat::operator()(size_t i, size_t j) const{
assert(!vdata.empty());
assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size());
return vdata[i][j];
}
// +=
template<typename T>
Mat& Mat::operator+=(const Mat& m){
if (vdata.empty() || m.vdata.empty()) return *this;
const size_t row = vdata.size();
const size_t col = vdata[0].size();
const size_t mrow = m.vdata.size();
const size_t mcol = m.vdata[0].size();
if (row == mrow && col == mcol){
for (size_t i = 0; i < row; i++)
for (size_t j = 0; j < col; j++)
vdata[i][j] += m.vdata[i][j];
}
else{
std::cerr << "mat must be the same size." << std::endl;
}
return *this;
}
// -=
template<typename T>
Mat& Mat::operator-=(const Mat& m){
if (vdata.empty() || m.vdata.empty()) return *this;
const size_t row = vdata.size();
const size_t col = vdata[0].size();
const size_t mrow = m.vdata.size();
const size_t mcol = m.vdata[0].size();
if (row == mrow && col == mcol){
for (size_t i = 0; i < row; i++)
for (size_t j = 0; j < col; j++)
vdata[i][j] -= m.vdata[i][j];
}
else{
std::cerr << "mat must be the same size." << std::endl;
}
return *this;
}
//resize
template<typename T>
void Mat::resize(size_t nr, size_t nc){
vdata.resize(nr);
for (size_t i = 0; i < nr; i++){
vdata[i].resize(nc);
}
}
//test Mat class
typedef double Type;
int main(){
Mat mat1(3, 4);
Mat mat2(3, 4);
for (size_t i = 0; i < mat1.rows(); i++){
for (size_t j = 0; j < mat1.cols(); j++){
mat1(i, j) = i*mat1.cols() + j;
mat2(i, j) = 2 * i*mat1.cols() + 2 * j;
}
}
std::cout << "mat1: " << std::endl << mat1;
std::cout << "mat2: " << std::endl << mat2;
Mat mat3 = (mat2 + mat1);
std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3;
Mat mat4 = (mat3 + mat2 - mat1);
std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4;
stringstream ss;
ss << mat1;
ss >> mat4;
std::cout << "mat4:" << std::endl << mat4;
const Mat mat6(mat4);
std::cout << "const mat6:" << std::endl << mat6;
cout << mat6(0, 0) << " " << mat6.rows() << " " << mat6.cols() << endl;
Mat mat7;
mat7 = std::move(mat1);
std::cout << "mat1: " << std::endl << mat1;
std::cout << "mat7: " << std::endl << mat7;
mat7.resize(2, 3);
std::cout << "mat7.resize(2, 3): " << std::endl << mat7;
mat7.resize(4, 6);
std::cout << "mat7.resize(4, 6): " << std::endl << mat7;
Mat mat8;
cout << " " << mat8.rows() << " " << mat8.cols() << endl;
//this will cause assertion error since mat8 is empty
//cout<
return 1;
}
需要说明的时,重载输入元算符<<时,我开始用的友元函数,但是出现连接错误,没有找到bug,然后在 stackoverflow 上看到别人推荐这种非友元的方式来实现,就采用了这种方式,避免了之前的链接错误。。。
如有错误,恳请指正!