C++模板实战8:矩阵乘法

     矩阵乘法采用迭代器实现,矩阵采用行优先方式存储,其关键操作是“行×列”,该操作分成三部分:行由一个迭代器完成移动,列有一个列迭代器完成移动,乘法采用transform完成其中需要一个累加操作有累加迭代器完成

1 矩阵乘法中涉及到行×列,若矩阵采用行优先方式存储,那么行的移动相对简单,列的移动相对复杂一点,针对列移动采用一个迭代器实现,如下:

// 文件名:skip_iterator.hpp
#pragma once
#include <iterator>

template<typename T>
class skip_iterator :
    public std::iterator<std::random_access_iterator_tag, T>
{
    T *pos;                     // 所指元素地址,pos是当前列的某个位置上的元素
    size_t step;                // 步长,其实就是列数,列上每移动一步相当于移动的元素等于列数
public:
    // 类型定义
    typedef std::iterator<std::random_access_iterator_tag, T> base_type;
    typedef typename base_type::difference_type difference_type;
    typedef typename base_type::reference reference;

    // 构造函数
    skip_iterator(T *pos, size_t step) : pos(pos), step(step) {}
    skip_iterator(const skip_iterator &i) : pos(i.pos), step(step) {}

    difference_type
    operator - (skip_iterator r) {return (pos - r.pos) / step;}

    skip_iterator
    operator + (typename base_type::difference_type n)
    {return skip_iterator(pos + n * step);}

    skip_iterator
    operator ++() {pos += step; return *this;}//从某一特定的列上一个元素移动到该列的下一个元素则要移动的元素数目等于列数step

    bool operator == (skip_iterator const &r) {return pos == r.pos;}
    bool operator != (skip_iterator const &r) {return !(*this == r);}

    // 去引用
    reference operator *() {return *pos;}
};
   

2  矩阵类模板,矩阵采用行优先方式存储,需要特殊设计的就是列的移动,如下:

// 文件名:matrix.hpp
#pragma once
#include "skip_iterator.hpp"
#include <algorithm>

template<typename T>
class matrix
{
public:
    // 嵌套类型定义
    typedef T value_type;
    typedef T* iterator; // 迭代全部数据
    typedef T* row_iterator; // 迭代一行数据
    typedef skip_iterator<T> col_iterator; // 迭代一列数据,列迭代器采用特殊设计的skip_iterator
    // 只读迭代器类型定义
    typedef const T* const_iterator;
    typedef const T* const_row_iterator;
    typedef skip_iterator<const T> const_col_iterator;

private:
    T *data; // 数据指针
    size_t n_row;   // 行数
    size_t n_col;   // 列数
public:
    // 构造与析构
    matrix(size_t n_row, size_t n_col) :
        data(new T[n_row * n_col]), n_row(n_row), n_col(n_col) {}

    // 复制构造函数
    matrix(matrix const &m) :
        data(new T[m.n_row * m.n_col]),
        n_row(m.n_row), n_col(m.n_col) {
        std::copy(m.begin(), m.end(), begin());
    }

    template<typename Iterator>
    matrix(size_t n_row, size_t n_col, Iterator i) :
        data(new T[n_row * n_col]), n_row(n_row), n_col(n_col) 
    {
        Iterator j = i;
        std::advance(j, n_row * n_col);
        std::copy(i, j, begin());
        // 此处更适合用C++11中的copy_n
    }
    ~matrix() {delete[] data;} //析构函数中释放data的空间

    
    iterator begin() {return data;}
    iterator end() {return data + n_row * n_col;}
    row_iterator row_begin(size_t n) {return data + n * n_col;}//第n行首地址
    row_iterator row_end(size_t n) {return row_begin(n) + n_col;}//第n行尾地址
    col_iterator col_begin(size_t n) {return col_iterator(data + n, n_col);}//第n列首位置,n_col为step
    col_iterator col_end(size_t n) {return col_begin(n) + n_row;}

    const_iterator begin() const {return data;}//只读版本
    const_iterator end() const {return data + n_row * n_col;}
    const_row_iterator row_begin(size_t n) const {return data + n * n_col;}
    const_row_iterator row_end(size_t n) const {return row_begin(n) + n_col;}
    const_col_iterator col_begin(size_t n) const {
        return const_col_iterator(data + n, n_col);
    }
    const_col_iterator col_end(size_t n) const {return col_begin(n) + n_row;}

    size_t num_row() const {return n_row;}
    size_t num_col() const {return n_col;}

    T& operator() (size_t i, size_t j) {return data[i * n_col + j];}
    T const& operator() (size_t i, size_t j) const {return data[i * n_col + j];}

    // 赋值操作符也需要特别处理
    matrix&
    operator=(matrix const &m) {
        if (&m == this) return *this;
        if (n_row * n_col < m.n_row * m.n_col) {
            delete[] data;
            data = new T[m.n_row * m.n_col];
        }
        n_row = m.n_row;
        n_col = m.n_col;
        std::copy(m.begin(), m.end(), begin());
    }
};

3 累计迭代器:矩阵的行×列相当于两个序列逐个元素相乘再累加,执行累加操作采用一个迭代器实现,如下:

// 文件名:accumulate_iterator.hpp
#pragma once
#include <iterator>
template<typename T, typename BinFunc>
class accumulate_iterator :
    public std::iterator<std::output_iterator_tag, T>
{
    T &ref_x;                   // 累计所用变量引用
    BinFunc bin_func;           // 累计所用函数。ref_x = bin_func(ref_x, v)
public:
    accumulate_iterator(T &ref_x, BinFunc bin_func) :
        ref_x(ref_x), bin_func(bin_func) {}

    // 去引用操作返回自身
    accumulate_iterator operator*() {return *this;}

    // 赋值操作实现累计
    template<typename T0>
    T0 const & operator=(T0 const &v) {ref_x = bin_func(ref_x, v);}

    accumulate_iterator& operator++() {return *this;}
};

// 生成accumulate_iterator的助手函数
template<typename T, typename BinFunc>
accumulate_iterator<T, BinFunc>
accumulater(T &ref_x, BinFunc bin_func)
{
    return accumulate_iterator<T, BinFunc>(ref_x, bin_func);
}

4 矩阵乘法的实现,如下:

#include "matrix.hpp"
#include "accumulate_iterator.hpp"
#include <algorithm>
#include <stdexcept>

template<typename T>
matrix<T> operator * (matrix<T> const &m0, matrix<T> const &m1)
    throw (std::runtime_error)
{
    // 矩阵尺寸不符合时,无法相乘。抛出一个运行期异常
    if (m0.num_col() != m1.num_row())
        throw std::runtime_error("Bad matrix size for multiplication.");

    matrix<T> m(m0.num_row(), m1.num_col());

    typename matrix<T>::iterator pos = m.begin();
    
    for (size_t i = 0; i < m.num_row(); ++i) {
        for (size_t j = 0; j < m.num_col(); ++j) {
            *pos = 0;
            std::transform(m0.row_begin(i), m0.row_end(i),m1.col_begin(j),accumulater(*pos, std::plus<T>()),std::multiplies<T>());//这里注意transform的实现#1#
            ++pos;
        }
    }

    return m;
}
#1#处说明:

template <class InputIterator, class OutputIterator, class UnaryOperator>
  OutputIterator transform (InputIterator first1, InputIterator last1,InputIterator first2,OutputIterator result, BinaryOperator binary_op)
{
  while (first1 != last1) {
    *result=binary_op(*first1,*first2++);//注意前面#1#处result是累加迭代器,其赋值操作实质是累加。first2是个列迭代器
    ++result;//累加迭代器自增是返回自身 
    ++first1;
  }
  return result;
}





你可能感兴趣的:(C++模板实战8矩阵乘法)