c++实现简单矩阵类Mat

刚学习C++,之前把 Primer 看了一遍,现在也在刷 leetcode,感觉学习编程语言光看书页刷题也是不够的,最好是能做一些实际的项目,这样要用到哪些东西时不明白再看书,就会印象深刻些,否则光看书只是走马观花,看了也就忘了。
打算自己用C++实现一个简单的矩阵 Mat 类,包括一些简单的操作就可以了,但实现起来发现也并没有那么简单,还是遇到很多问题。这个过程也还是学到了不少东西。

固定数据类型的矩阵

先从一个固定数据类型的矩阵的开始吧,以 int 为例子,那么我们需要什么呢? 首先,我们需要记录矩阵的行和列,也需要一块内存存储矩阵的数据;然后需要一些简单的构造函数和其他的成员函数,实现一些基本的操作。
废话不多说,直接上程序:
头文件 Mat.h:

#ifndef _MAT_H_
#define _MAT_H_

#include 
#include 
#include 
#include 
#include 

//implement Mat class in c++

class Mat{
    friend  std::ostream& operator<<(std::ostream &os, const Mat &m);
    friend  std::istream& operator>>(std::istream &is, Mat &m);

public:
    typedef int value_type;
    typedef std::vector<int>::size_type size_type;

    //construct
    Mat();
    Mat(size_t i, size_t j);

    //copy constructor
    Mat(const Mat& m);

    //copy assignment
    Mat& operator=(const Mat& m);

    // +=
    Mat& operator+=(const Mat& m);

    // -=
    Mat& operator-=(const Mat& m);

    //destructor
    ~Mat();

    //access element value
    int& operator()(size_t i, size_t j);
    const int& operator()(size_t i, size_t j) const;


    //get row and col number
    const size_t rows() const{ return row; }
    const size_t cols() const{ return col; }

    //resize
    void resize(size_t nr, size_t nc);

private:

    size_t row;
    size_t col;
    std::vector<std::vector<int>> data;
};

#endif

刚开始的时候,我是用 int* 来存储的数据,用int* 不方便的一个地方就是需要自己来管理内存,在实现一些 复制赋值和移动操作时都需要很小心处理,来管理内存,稍有不慎就会出问题。所以我后来改用C++标准容器 vector 了。这样就不用自己来管理内存了。

另外,像 + 、- 这样两边的数可以互换位置的操作符,重载时最好用非成员函数来重载,而且像 +=、 -= 这样的复合操作符,必须要用成员函数来重载。另外,可以用 +=、-= 来实现非成员函数的 + 、- 操作。最后要强调的是,非成员的运算符重载函数(+,-,,/)之类,返回类型应该是 const 类型,避免类似 a b = c 这样的语句通过编译。

像输出(<<)、输入(>>)运算符,需要用非成员函数来重载,一般还要声明为友元函数。不过可以通过成员函数完成输入输出操作,非成员函数来重载、调用成员函数输入输出来避免声明友元函数,后面一个例子就是这样实现的。

最后要强调的一点就是针对const对象,需要有对应的const版本的函数,具体说来,就是对一些不改变类类型的变量成员函数,尽量声明为const类型,这样const对象也能调用这些函数。在返回矩阵数据时,需要针对const 和 非const的对象,实现两个函数。

//access element value
int& operator()(size_t i, size_t j);
const int& operator()(size_t i, size_t j) const;

例如,重载 ( ) 运算符时,需要实现两个版本的函数,const 对象调用第二个版本函数,非 const 对象调用第一个版本的函数。

源文件Mat.cpp


#include 
#include 
#include 
#include "matint.h"


using std::cout;
using std::endl;
using std::istream;
using std::ostream;
using std::stringstream;

ostream& operator<<(ostream &os, const Mat&m){
    for (size_t i = 0; i < m.row; i++){
        for (size_t j = 0; j < m.col; j++){
            os << m.data[i][j] << " ";
        }
        os << std::endl;
    }
    os << std::endl;
    return os;
}

istream& operator>>(istream &is, Mat&m){
    for (size_t i = 0; i < m.row; i++){
        for (size_t j = 0; j < m.col; j++){
            is >> m.data[i][j];
        }
    }
    return is;
}

// +
const Mat operator+(const Mat& m1, const Mat& m2){
    Mat t = m1;
    t += m2;
    return t;
}


// -
const Mat operator-(const Mat& m1, const Mat& m2){
    Mat t = m1;
    t -= m2;
    return t;
}

//constructor
Mat::Mat(){
    cout << "default constructor" << endl;
    row = 0;
    col = 0;
    data.clear();
}


Mat::Mat(size_t i, size_t j){
    row = i; col = j;
    std::vector<std::vector<int>> vdata(row, std::vector<int>(col, 0));
    data = std::move(vdata);
}

//copy constructor
Mat::Mat(const Mat& m){
    cout << "copy constructor" << endl;
    row = m.row; col = m.col;
    data = m.data;
}

//copy assignment
Mat& Mat::operator=(const Mat& m){
    cout << "copy assignment" << endl;
    row = m.row; col = m.col;
    data = m.data;
    return *this;
}

//destructor
Mat::~Mat(){
    data.clear();
}

//access element value
int& Mat::operator()(size_t i, size_t j){
    assert(i >= 0 && j >= 0 && i < row && j < col);
    return data[i][j];
}

const int& Mat::operator()(size_t i, size_t j) const{
    assert(i >= 0 && j >= 0 && i < row && j < col);
    return data[i][j];
}

//resize
void Mat::resize(size_t nr, size_t nc){
    data.resize(nr);
    for (size_t i = 0; i < nr; i++){
        data[i].resize(nc);
    }
    col = nc; row = nr;
}



// +=
Mat& Mat::operator+=(const Mat& m){
    if (row == m.row && col == m.col){
        for (size_t i = 0; i < row; i++)
        {
            for (size_t j = 0; j < col; j++)
                data[i][j] += m.data[i][j];
        }
    }
    else{
        std::cerr << "mat must be the same size." << std::endl;
    }
    return *this;
}

// -=
Mat& Mat::operator-=(const Mat& m){
    if (row == m.row && col == m.col){
        for (size_t i = 0; i < row; i++)
        {
            for (size_t j = 0; j < col; j++)
                data[i][j] -= m.data[i][j];
        }
    }
    else{
        std::cerr << "mat must be the same size." << std::endl;
    }
    return *this;
}

#if 1

int main(){

    Mat mat1(3, 4);
    Mat mat2(3, 4);

    for (size_t i = 0; i < mat1.rows(); i++){
        for (size_t j = 0; j < mat1.cols(); j++){
            mat1(i, j) = 1;
            mat2(i, j) = 3;
        }
    }
    std::cout << "mat1: " << std::endl << mat1;
    std::cout << "mat2: " << std::endl << mat2;

    Mat mat3 = (mat2 + mat1);
    std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3;

    Mat mat4 = (mat3 + mat2 - mat1);
    std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4;

    stringstream ss;
    ss << mat1;
    ss >> mat4;
    std::cout << "mat4:" << std::endl << mat4;

    const Mat   mat6(mat4);
    std::cout << "const mat6:" << std::endl << mat6;
    cout << mat6(0, 0) << " " << mat6.rows() << " "<" ";

    Mat mat7 = mat2;
    std::cout << "mat7: " << std::endl << mat7;

    mat2(0, 0) = 11;
    std::cout << "mat7: " << std::endl << mat7;

    mat7.resize(2, 3);
    std::cout << "mat7.resize(2, 3): " << std::endl << mat7;

    mat7.resize(5, 6);
    std::cout << "mat7.resize(5, 6): " << std::endl << mat7;

    return 1;
}

#endif

使用Template实现通用数据类型的Mat

以上是针对 int 数据类型实现的矩阵类,那么如果我的矩阵数据类型是double怎么办?总不能再重新实现一遍吧。为了实现多种数据类型,有两种思路。第一种就是像 OpenCV 的矩阵类一样,单独用一个变量来定义使用什么样的数据类型,这样做的好处是不需要使用模板;第二种思路就是使用C++的Templeate来实现。我实现的是第二种思路。

头文件 Mat.h

#ifndef _MAT_H_
#define _MAT_H_

#include 
#include 
#include 
#include 
#include 
#include 

//implement Mat class in c++


template<typename T>
class Mat{

public:
    typedef T value_type;

    //construct
    Mat();
    Mat(size_t i, size_t j);

    copy constructor
    Mat(const Mat& m);

    copy assignment
    Mat& operator=(const Mat&m);

    // +=
    Mat& operator+=(const Mat& m);

    // -=
    Mat& operator-=(const Mat& m);

    //move constructor
    Mat( Mat&& m);

    //move assignment
    Mat& operator=( Mat&& m);

    //destructor
    ~Mat();

    //access element value
    T& operator()(size_t i, size_t j);
    const T& operator()(size_t i, size_t j) const;

    //get row and col number
    const size_t rows() const{ return vdata.size(); }
    const size_t cols() const{ 
        if (vdata.empty()) return 0;
        else return vdata[0].size(); 
    }

    //resize
    void resize(size_t nr, size_t nc);

    //print mat
    void CoutMat(std::ostream& os) const;
    void CinMat(std::istream& is);

private:

    std::vector<std::vector> vdata;
};

#endif

需要说明的是,当把数据存储从指针改为vector之后,也就不需要单独记录矩阵的行和列了,这些信息都蕴含在vector之中了。要注意的是,在使用取元素下标 [ ] 的时候,需要确保vector不为空。否则就会出现 vector out of range 的错误。

Mat.c



#include "mat.h"

using std::cout;
using std::endl;
using std::istream;
using std::ostream;
using std::stringstream;

template<typename T>
ostream& operator<<(ostream &os, const Mat &m){
    m.CoutMat(os);
    return os;
}

template<typename T>
istream& operator>>(istream &is, Mat&m){
    m.CinMat(is);
    return is;
}

// +
template<typename T>
const Mat operator+(const Mat& m1, const Mat& m2){
    Mat t(m1);
    t += m2;
    return t;
}

// -
template<typename T>
const Mat operator-(const Mat& m1, const Mat& m2){
    Mat t(m1);
    t -= m2;
    return t;
}

// print mat
template<typename T>
void Mat::CoutMat(std::ostream& os) const
{
    if (vdata.empty()) return;
    for (size_t i = 0; i < vdata.size(); i++){
        for (size_t j = 0; j < vdata[0].size(); j++){
            os << vdata[i][j] << " ";
        }
        os << std::endl;
    }
    os << std::endl;
}

template<typename T>
void Mat::CinMat(std::istream& is)
{
    if (vdata.empty()) return;
    for (size_t i = 0; i < vdata.size(); i++){
        for (size_t j = 0; j < vdata[0].size(); j++){
            is >> vdata[i][j];
        }
    }
}

//construct
template<typename T>
Mat::Mat(){
    cout << "default constructor" << endl;
    vdata.clear();
}

template<typename T>
Mat::Mat(size_t i, size_t j){
    std::vector<std::vector> tdata(i, std::vector(j, 0));
    vdata = std::move(tdata);
}

//copy constructor
template<typename T>
Mat::Mat(const Mat& m){
    cout << "copy constructor" << endl;
    vdata.assign(m.vdata.cbegin(), m.vdata.cend());
}

//copy assignment
template<typename T>
Mat& Mat::operator=(const Mat& m){
    cout << "copy assignment" << endl;
    if (this != &m){
        vdata.assign(m.vdata.cbegin(), m.vdata.cend());
    }
    return *this;
}

//move constructor
template<typename T>
Mat::Mat( Mat&& m ){
    cout << "move constructor" << endl;
    vdata = std::move(m.vdata);
}

//move assignment
template<typename T>
Mat& Mat::operator=(Mat&& m){
    cout << "move assignment" << endl;
    if (this != &m){
        vdata.clear();
        vdata = std::move(m.vdata);
    }
    return *this;
}

//destructor
template<typename T>
Mat::~Mat(){
    vdata.clear();
}

//access element value
template<typename T>
inline T& Mat::operator()(size_t i, size_t j){
    assert(!vdata.empty());
    assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size());
    return vdata[i][j];
}
template<typename T>
inline const T& Mat::operator()(size_t i, size_t j) const{
    assert(!vdata.empty());
    assert(i >= 0 && j >= 0 && i < vdata.size() && j < vdata[0].size());
    return vdata[i][j];
}

// +=
template<typename T>
Mat& Mat::operator+=(const Mat& m){
    if (vdata.empty() || m.vdata.empty()) return *this;

    const size_t row = vdata.size();
    const size_t col = vdata[0].size();
    const size_t mrow = m.vdata.size();
    const size_t mcol = m.vdata[0].size();

    if (row == mrow && col == mcol){
        for (size_t i = 0; i < row; i++)
        for (size_t j = 0; j < col; j++)
            vdata[i][j] += m.vdata[i][j];
    }
    else{
        std::cerr << "mat must be the same size." << std::endl;
    }

    return *this;
}

// -=
template<typename T>
Mat& Mat::operator-=(const Mat& m){
    if (vdata.empty() || m.vdata.empty()) return *this;

    const size_t row = vdata.size();
    const size_t col = vdata[0].size();
    const size_t mrow = m.vdata.size();
    const size_t mcol = m.vdata[0].size();

    if (row == mrow && col == mcol){
        for (size_t i = 0; i < row; i++)
        for (size_t j = 0; j < col; j++)
            vdata[i][j] -= m.vdata[i][j];
    }
    else{
        std::cerr << "mat must be the same size." << std::endl;
    }

    return *this;
}

//resize
template<typename T>
void Mat::resize(size_t nr, size_t nc){
    vdata.resize(nr);
    for (size_t i = 0; i < nr; i++){
        vdata[i].resize(nc);
    }
}

//test Mat class 
typedef double Type;
int main(){

    Mat mat1(3, 4);
    Mat mat2(3, 4);

    for (size_t i = 0; i < mat1.rows(); i++){
        for (size_t j = 0; j < mat1.cols(); j++){
            mat1(i, j) = i*mat1.cols() + j;
            mat2(i, j) = 2 * i*mat1.cols() + 2 * j;
        }
    }

    std::cout << "mat1: " << std::endl << mat1;
    std::cout << "mat2: " << std::endl << mat2;

    Mat mat3 = (mat2 + mat1);
    std::cout << "mat3 = mat2 + mat1: " << std::endl << mat3;

    Mat mat4 = (mat3 + mat2 - mat1);
    std::cout << "mat4 = mat3 + mat2 - mat1: " << std::endl << mat4;

    stringstream ss;
    ss << mat1;
    ss >> mat4;
    std::cout << "mat4:" << std::endl << mat4;

    const Mat mat6(mat4);
    std::cout << "const mat6:" << std::endl << mat6;
    cout << mat6(0, 0) << " " << mat6.rows() << " " << mat6.cols() << endl;

    Mat mat7;
    mat7 = std::move(mat1);
    std::cout << "mat1: " << std::endl << mat1;
    std::cout << "mat7: " << std::endl << mat7;

    mat7.resize(2, 3);
    std::cout << "mat7.resize(2, 3): " << std::endl << mat7;

    mat7.resize(4, 6);
    std::cout << "mat7.resize(4, 6): " << std::endl << mat7;

    Mat mat8;
    cout  << " " << mat8.rows() << " " << mat8.cols() << endl;
    //this will cause assertion error since mat8 is empty
    //cout<
    return 1;

}

需要说明的时,重载输入元算符<<时,我开始用的友元函数,但是出现连接错误,没有找到bug,然后在 stackoverflow 上看到别人推荐这种非友元的方式来实现,就采用了这种方式,避免了之前的链接错误。。。

如有错误,恳请指正!

你可能感兴趣的:(c-c++)