/*
C++实现LeNet-5卷积神经网络

搞了好久好久,公式推导+网络设计就推了20多页草稿纸
花了近10天
程序近1k行,各种debug要人命,只能不断的单元测试+梯度检验
因为C++只有加减乘除,所以对这个网络模型不能有一丝丝的模糊,每一步都要理解得很透彻
挺考验能力的,很庆幸我做出来了,这个是第二版,第一版也写了1k行,写完才发现,模型错了,只能全删掉重新写
算是一次修行
网络的设计,编代码时的各种考虑,debug记录,我不想整理了,有问题的直接私信我吧
*/

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
#include <vector>
using namespace std;

//Euler's number e (base of the natural logarithm), not a "natural number"
const double E = 2.718281828459;
//step size used for numerical gradient checking
const double EPS = 3e-3;
//raw MNIST image dimensions
const int MNIST_HEIGHT = 28;
const int MNIST_WIDTH = 28;
//network input dimensions: MNIST padded to 32*32, as in LeNet-5
const int INPUT_HEIGHT = 32;
const int INPUT_WIDTH = 32;
//zero padding per side: (32 - 28) / 2
const int OFFSET = 2;
//bytes per label in the MNIST label file
const int LABEL_BYTE = 1;
//number of output classes (digits 0-9)
const int OUT_SIZE = 10;
//tiny subsets of MNIST actually loaded for training / testing
const int NUM_TRAIN = 20;
const int NUM_TEST = 20;

/*------------矩阵类-----------------------*/
typedef vector<vector<double>> Matrix;
inline bool reshapeMatrix(Matrix &mat, size_t row, size_t col)
{
    if (row <= 0 || col <= 0)
        throw "reshapeMatrix: row <= 0 || col <= 0";
    mat.resize(row);
    for (int i = 0; i < row; i++)
        mat[i].resize(col);
    return true;
}
//Resize mat to row*col and set EVERY cell to val (old contents discarded).
//Throws a C-string on a zero dimension; the old `<= 0` test on the unsigned
//parameters could never fire for negatives.
inline bool reshapeMatrix(Matrix &mat, size_t row, size_t col, double val)
{
    if (row == 0 || col == 0)
        throw "reshapeMatrix: row == 0 || col == 0";
    //assign() reshapes and overwrites in one step, replacing the original
    //per-row clear()+resize() dance
    mat.assign(row, vector<double>(col, val));
    return true;
}
//矩阵二维卷积
inline void convMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
    if (b.size() > a.size() || b[0].size() > a[0].size())
        throw "convMatrix: b is larger than a";
    reshapeMatrix(res, a.size() - b.size() + 1, a[0].size() - b[0].size() + 1, 0);
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
        {
            //遍历卷积矩阵
            for (int _i = 0; _i < b.size(); _i++)
                for (int _j = 0; _j < b[0].size(); _j++)
                    res[i][j] += a[_i + i][_j + j] * b[_i][_j];
        }
    return;
}
//In-place matrix addition: a += b. Shapes must match.
inline void plusMatrix(Matrix &a, const Matrix &b)
{
    if (a.size() != b.size() || a[0].size() != b[0].size())
        throw "plusMatrix: shape don't match";
    const size_t rows = a.size();
    const size_t cols = a[0].size();
    for (size_t r = 0; r < rows; ++r)
        for (size_t c = 0; c < cols; ++c)
            a[r][c] += b[r][c];
    return;
}
//Add the scalar val to every element of a, in place.
inline void plusMatrix(Matrix &a, double val)
{
    for (auto &row : a)
        for (auto &cell : row)
            cell += val;
    return;
}
inline void plusMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
    if (a.size() != b.size() || a[0].size() != b[0].size())
        throw "plusMatrix: shape don't match";
    reshapeMatrix(res, a.size(), a[0].size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = a[i][j] + b[i][j];
    return;
}
//矩阵减法
inline void minusMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
    if (a.size() != b.size() || a[0].size() != b[0].size())
        throw "plusMatrix: shape don't match";
    reshapeMatrix(res, a.size(), a[0].size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = a[i][j] - b[i][j];
    return;
}
//矩阵乘法
inline void multiplyMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
    if (a[0].size() != b.size())
        throw "multiplyMatrix: a.col != b.row";
    reshapeMatrix(res, a.size(), b[0].size(), 0);
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            for (int k = 0; k < a[0].size(); k++)
                res[i][j] += a[i][k] * b[k][j];
    return;
}
//Scale every element of mat by val, in place.
inline void multiplyMatrix(Matrix &mat, double val)
{
    for (auto &row : mat)
        for (auto &cell : row)
            cell *= val;
    return;
}
inline void multiplyMatrix(double val, const Matrix &mat, Matrix &res)
{
    reshapeMatrix(res, mat.size(), mat[0].size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = mat[i][j] * val;
    return;
}
//矩阵点乘
void matmulMatrix(const Matrix &a, const Matrix &b, Matrix &res)
{
    if (a.size() != b.size() || a[0].size() != b[0].size())
        throw "matmulMatrix: shape don't match";
    reshapeMatrix(res, a.size(), a[0].size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = a[i][j] * b[i][j];
    return;
}
//矩阵池化,步长=大小
inline void downSampleMatrix(const Matrix &mat, size_t height, size_t width, Matrix &res)
{
    if (mat.size() % height != 0 || mat[0].size() % width != 0)
        throw "downSampleMatrix: height/width don't match matrix";
    reshapeMatrix(res, mat.size() / height, mat[0].size() / width);
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
        {
            //求和
            int row_b = i * height;
            int row_e = (i + 1) * height;
            int col_b = j * width;
            int col_e = (j + 1) * width;
            res[i][j] = 0;
            for (int _i = row_b; _i < row_e; _i++)
                for (int _j = col_b; _j < col_e; _j++)
                    res[i][j] += mat[_i][_j];
        }
    return;
}
//矩阵转置
inline void transposeMatrix(const Matrix &mat, Matrix &res)
{
    reshapeMatrix(res, mat[0].size(), mat.size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = mat[j][i];
    return;
}
//矩阵旋转180度
inline void rot180Matrix(const Matrix &mat, Matrix &res)
{
    reshapeMatrix(res, mat.size(), mat[0].size());
    for (int i = 0; i < res.size(); i++)
        for (int j = 0; j < res[0].size(); j++)
            res[i][j] = mat[res.size() - i - 1][res[0].size() - j - 1];
    return;
}
//求和
inline double sumMatrix(const Matrix &mat)
{
    double res = 0;
    for (int i = 0; i < mat.size(); i++)
        for (int j = 0; j < mat[0].size(); j++)
            res += mat[i][j];
    return res;
}
//打印矩阵
void printMatrix(const Matrix &mat, ostream &os)
{
    os << "Matrix: " << mat.size() << '*' << mat[0].size() << endl;
    for (int i = 0; i < mat.size(); i++)
    {
        for (int j = 0; j < mat[0].size(); j++)
            os << mat[i][j] << ' ';
        os << endl;
    }
    return;
}
/*------------矩阵类-----------------------*/

/*-------------数据读入----------------------- */
struct Point
{
    Matrix image; //32*32 input plane: the 28*28 MNIST image with 2px zero padding
    Matrix label; //10*1 one-hot encoding of the digit
    Point(void) { ; }
    //image: raw row-major 28*28 pixel buffer; num: digit label 0-9
    Point(char *image, uint8_t num)
    {
        reshapeMatrix(this->image, INPUT_HEIGHT, INPUT_WIDTH);
        for (int i = 0; i < MNIST_HEIGHT; i++)
            for (int j = 0; j < MNIST_WIDTH; j++)
                //row stride of a row-major buffer is the image WIDTH; the old
                //code used MNIST_HEIGHT, which only worked because the image
                //happens to be square (28 == 28)
                this->image[OFFSET + i][OFFSET + j] = (uint8_t)image[i * MNIST_WIDTH + j];
        reshapeMatrix(this->label, OUT_SIZE, 1, 0);
        label[num][0] = 1;
    }
};
//global training / test sets, filled by readData()
vector<Point> TrainData, TestData;
//Load NUM_TRAIN training samples and NUM_TEST test samples from the four
//MNIST idx files expected in the working directory.
//NOTE(review): stream open/read failures are never checked — if a file is
//missing, every image silently reads as garbage; confirm files exist first.
inline void readData(vector<Point> &train, vector<Point> &test)
{
    char rubbish[16];
    ifstream train_images("./train-images.idx3-ubyte", ios::binary | ios::in); //image file
    ifstream train_labels("./train-labels.idx1-ubyte", ios::binary | ios::in); //label file
    train_images.read(rubbish, 16);                                            //skip the 16-byte idx3 header (magic, count, rows, cols)
    train_labels.read(rubbish, 8);                                             //skip the 8-byte idx1 header (magic, count)
    for (int i = 0; i < NUM_TRAIN; i++)                                        //read the i-th image
    {
        char image[MNIST_HEIGHT * MNIST_WIDTH]; //raw pixel buffer for one image
        uint8_t num;                            //its label
        /*raw binary reads, no conversion*/
        train_images.read(image, MNIST_HEIGHT * MNIST_WIDTH); //28*28 bytes, one per pixel (0x00-0xFF grey level), row-major
        //uint8_t is char-sized, but read() wants char*, hence the cast
        train_labels.read((char *)(&num), LABEL_BYTE); //one byte per label, value 0x00-0x09
        /*
            char->double: double pix=(uint8_t)image[i]; cast to an unsigned
            integer first, then let the compiler widen it to floating point;
            char->int: int label=num; both are plain binary, copied bitwise
        */
        train.push_back(Point(image, num));
    }

    ifstream test_images("./t10k-images.idx3-ubyte", ios::binary | ios::in);
    ifstream test_labels("./t10k-labels.idx1-ubyte", ios::binary | ios::in);
    test_images.read(rubbish, 16); //4*32bit_integer
    test_labels.read(rubbish, 8);  //2*32bit_integer
    for (int i = 0; i < NUM_TEST; i++)
    {
        char image[MNIST_HEIGHT * MNIST_WIDTH];
        uint8_t num;
        test_images.read(image, MNIST_HEIGHT * MNIST_WIDTH);
        test_labels.read((char *)(&num), LABEL_BYTE);
        test.push_back(Point(image, num));
    }
}
//打印图片
inline void printImage(const Matrix &data)
{
    for (int i = 0; i < 32; i++)
    {
        for (int j = 0; j < 32; j++)
        {
            printf("%3.2lf ", data[i][j]);
        }
        cout << '\n';
    }
}
/*-------------数据读入----------------------- */

/*------------------归一化--------------------*/
//Map raw grey levels 0..255 into roughly [-0.1, 1.175], the input scaling
//described for LeNet-5.
inline void Normalize(vector<Point> &set)
{
    for (auto &point : set)
        for (auto &row : point.image)
            for (auto &pix : row)
                pix = pix / 200.0 - 0.1;
}
/*------------------归一化--------------------*/

/* ------------------------随机化-------------------------*/
//shared PRNG for weight init, seeded from wall-clock time, uniform on [-1, 1)
default_random_engine Rand(time(NULL));
uniform_real_distribution<double> uniform_dis(-1.0, 1.0);
//Fill mat with uniform(-1,1) noise scaled by 1/sqrt(fan-in) (Xavier-style).
inline void randAssign(Matrix &mat, double last_neu_num)
{
    const double scale = sqrt(1 / last_neu_num);
    for (auto &row : mat)
        for (auto &w : row)
            w = uniform_dis(Rand) * scale;
}
//Vector overload of randAssign: same 1/sqrt(fan-in) scaling.
inline void randAssign(vector<double> &vec, double last_neu_num)
{
    const double scale = sqrt(1 / last_neu_num);
    for (auto &w : vec)
        w = uniform_dis(Rand) * scale;
}
/* ------------------------随机化-------------------------*/

/*-----------------------激发函数------------------------- */
//Hyperbolic tangent. exp(x) replaces pow(E, x): faster, and exact in the
//base. The saturation guard avoids inf/inf = NaN when exp overflows for
//|x| > ~710; tanh(20) is already 1.0 to double precision.
inline double tanh(double x)
{
    if (x > 20.0)
        return 1.0;
    if (x < -20.0)
        return -1.0;
    double a = exp(x);
    double b = exp(-x);
    return (a - b) / (a + b);
}
inline void tanh(const Matrix &a, Matrix &res)
{
    reshapeMatrix(res, a.size(), a[0].size());
    for (int i = 0; i < a.size(); i++)
        for (int j = 0; j < a[0].size(); j++)
            res[i][j] = tanh(a[i][j]);
    return;
}
inline void softmax(const Matrix &mat, Matrix &res)
{
    reshapeMatrix(res, 10, 1);
    double max_z = mat[0][0];
    for (int i = 0; i < mat.size(); i++)
        max_z = max(max_z, mat[i][0]);
    double sum = 0;
    for (int i = 0; i < mat.size(); i++)
        sum += pow(E, mat[i][0] - max_z);
    for (int i = 0; i < mat.size(); i++)
        res[i][0] = pow(E, mat[i][0] - max_z) / sum;
    return;
}
/*-----------------------激发函数------------------------- */

/*---------------------LeNet-5 network state-----------------------*/
Matrix INPUT; //input layer, 1@32*32

vector<Matrix> C1_core;     //C1 convolution kernels, 6@5*5
vector<double> C1_bias;     //C1 biases, 6 scalars
vector<Matrix> C1;          //C1 output maps, 6@28*28 (32-5+1; the original comment's "5*5" described the kernel)
vector<Matrix> der_C1_core; //gradients of the C1 kernels, 6@5*5
vector<double> der_C1_bias; //gradients of the C1 biases, 6 scalars
vector<Matrix> sus_C1;      //sensitivities (deltas) of the C1 outputs, 6@28*28

vector<double> S2_coefficient;     //trainable pooling coefficients, 6 scalars
vector<double> S2_bias;            //S2 biases, 6 scalars
vector<Matrix> S2_rec;             //S2 pre-activation (scaled sum-pool + bias), 6@14*14
vector<Matrix> S2;                 //S2 activation (tanh), 6@14*14
vector<double> der_S2_coefficient; //gradients of the S2 coefficients, 6 scalars
vector<double> der_S2_bias;        //gradients of the S2 biases, 6 scalars
vector<Matrix> sus_S2_rec;         //sensitivities of the S2 pre-activations, 6@14*14

//number of S2 maps feeding each C3 map (row counts of C3_connect)
const int C3_num[16] = {3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 6};
//sparse S2 -> C3 connection table (which S2 maps feed each C3 map)
const bool C3_connect[16][6] =
    {
        {true, true, true, false, false, false},
        {false, true, true, true, false, false},
        {false, false, true, true, true, false},
        {false, false, false, true, true, true},
        {true, false, false, false, true, true},
        {true, true, false, false, false, true},
        {true, true, true, true, false, false},
        {false, true, true, true, true, false},
        {false, false, true, true, true, true},
        {true, false, false, true, true, true},
        {true, true, false, false, true, true},
        {true, true, true, false, false, true},
        {true, true, false, true, true, false},
        {false, true, true, false, true, true},
        {true, false, true, true, false, true},
        {true, true, true, true, true, true}};
vector<vector<Matrix>> C3_core;     //C3 kernels, 16@[3, 4 or 6 per C3_num]@5*5 (only connected slots allocated)
vector<double> C3_bias;             //C3 biases, 16 scalars
vector<Matrix> C3;                  //C3 output maps, 16@10*10
vector<vector<Matrix>> der_C3_core; //gradients of the C3 kernels, same sparse layout
vector<double> der_C3_bias;         //gradients of the C3 biases, 16 scalars
vector<Matrix> sus_C3;              //sensitivities of the C3 outputs, 16@10*10

vector<double> S4_coefficient;     //trainable pooling coefficients, 16 scalars
vector<double> S4_bias;            //S4 biases, 16 scalars
vector<Matrix> S4_rec;             //S4 pre-activation, 16@5*5
vector<Matrix> S4;                 //S4 activation (tanh), 16@5*5
vector<double> der_S4_coefficient; //gradients of the S4 coefficients, 16 scalars
vector<double> der_S4_bias;        //gradients of the S4 biases, 16 scalars
vector<Matrix> sus_S4_rec;         //sensitivities of the S4 pre-activations, 16@5*5

vector<vector<Matrix>> C5_core;     //C5 kernels, 120@16@5*5
vector<double> C5_bias;             //C5 biases, 120 scalars
Matrix C5;                          //C5 output, a 120*1 column vector (no activation applied)
vector<vector<Matrix>> der_C5_core; //gradients of the C5 kernels, 120@16@5*5
vector<double> der_C5_bias;         //gradients of the C5 biases, 120 scalars
Matrix sus_C5;                      //sensitivities of the C5 outputs, 120*1

Matrix F6_weight;     //F6 weights, 84*120
Matrix F6_bias;       //F6 biases, 84*1
Matrix F6_rec;        //F6 pre-activation (weights*C5 + bias), 84*1
Matrix F6;            //F6 activation (tanh), 84*1
Matrix der_F6_weight; //gradients of the F6 weights, 84*120
Matrix der_F6_bias;   //gradients of the F6 biases, 84*1
Matrix sus_F6_rec;    //sensitivities of the F6 pre-activations, 84*1

Matrix OUTPUT_weight;     //output weights, 10*84
Matrix OUTPUT_bias;       //output biases, 10*1
Matrix OUTPUT_rec;        //output pre-activation (logits), 10*1
Matrix OUTPUT;            //softmax output (class probabilities), 10*1
Matrix der_OUTPUT_weight; //gradients of the output weights, 10*84
Matrix der_OUTPUT_bias;   //gradients of the output biases, 10*1
Matrix sus_OUTPUT_rec;    //sensitivities of the logits, 10*1
/*---------------------LeNet-5 network state-----------------------*/

//Allocate every layer and initialize the weights with uniform(-1,1) scaled
//by 1/sqrt(fan-in). Biases start at 0 (except S2_bias — see note). Gradient
//buffers are allocated here and re-zeroed per batch by clearDer().
void init(void) //structure initialization
{
    //C1: 6 kernels of 5*5, fan-in 25
    C1.resize(6);
    C1_core.resize(6);
    C1_bias.resize(6, 0);
    sus_C1.resize(6);
    der_C1_core.resize(6);
    der_C1_bias.resize(6, 0);
    for (int i = 0; i < 6; i++)
    {
        reshapeMatrix(C1_core[i], 5, 5);
        reshapeMatrix(der_C1_core[i], 5, 5);
        randAssign(C1_core[i], 25);
    }

    //S2: one trainable coefficient and bias per map; fan-in 4 = 2*2 window
    S2_coefficient.resize(6);
    randAssign(S2_coefficient, 4);
    //NOTE(review): S2 biases start at 1 while every other bias starts at 0 —
    //this looks like a typo for resize(6, 0); confirm before reuse
    S2_bias.resize(6, 1);
    S2_rec.resize(6);
    S2.resize(6);
    der_S2_coefficient.resize(6);
    der_S2_bias.resize(6, 0);
    sus_S2_rec.resize(6);

    //C3: kernels exist only where C3_connect allows; fan-in 25 * C3_num[i]
    C3_core.resize(16);
    C3_bias.resize(16, 0);
    C3.resize(16);
    der_C3_core.resize(16);
    der_C3_bias.resize(16, 0);
    sus_C3.resize(16);
    for (int i = 0; i < 16; i++)
    {
        C3_core[i].resize(6);
        der_C3_core[i].resize(6);
        reshapeMatrix(C3[i], 5, 5, 0);
        for (int j = 0; j < 6; j++)
        {
            if (C3_connect[i][j])
            {
                reshapeMatrix(C3_core[i][j], 5, 5);
                reshapeMatrix(der_C3_core[i][j], 5, 5);
                randAssign(C3_core[i][j], 25 * C3_num[i]);
            }
        }
    }

    //S4: same scheme as S2, 16 maps
    S4_coefficient.resize(16);
    randAssign(S4_coefficient, 4);
    S4_bias.resize(16, 0);
    S4_rec.resize(16);
    S4.resize(16);
    der_S4_coefficient.resize(16);
    der_S4_bias.resize(16, 0);
    sus_S4_rec.resize(16);

    //C5: 120 neurons, each with 16 kernels of 5*5; fan-in 16*25 = 400
    C5_core.resize(120);
    C5_bias.resize(120, 0);
    reshapeMatrix(C5, 120, 1);
    der_C5_core.resize(120);
    der_C5_bias.resize(120, 0);
    reshapeMatrix(sus_C5, 120, 1);
    for (int i = 0; i < 120; i++)
    {
        C5_core[i].resize(16);
        der_C5_core[i].resize(16);
        for (int k = 0; k < 16; k++)
        {
            reshapeMatrix(C5_core[i][k], 5, 5);
            reshapeMatrix(der_C5_core[i][k], 5, 5);
            randAssign(C5_core[i][k], 400);
        }
    }

    //F6 fully connected: 84*120, fan-in 120
    reshapeMatrix(F6_weight, 84, 120);
    randAssign(F6_weight, 120);
    reshapeMatrix(F6_bias, 84, 1, 0);
    reshapeMatrix(der_F6_weight, 84, 120);
    reshapeMatrix(der_F6_bias, 84, 1, 0);

    //OUTPUT fully connected: 10*84, fan-in 84
    reshapeMatrix(OUTPUT_weight, 10, 84);
    randAssign(OUTPUT_weight, 84);
    reshapeMatrix(OUTPUT_bias, 10, 1, 0);
    reshapeMatrix(der_OUTPUT_weight, 10, 84);
    reshapeMatrix(der_OUTPUT_bias, 10, 1, 0);
}

/*--------------前向传播-----------------------*/
//One forward pass for a single sample: reads point.image, leaves the layer
//activations in the globals, and finishes with class probabilities in OUTPUT.
void forwardPropagation(const Point &point)
{ //input: 1@32*32
    INPUT = point.image;
    //C1 convolution: 6 kernels over the input -> 6@28*28 (no activation here)
    for (int i = 0; i < 6; i++)
    {
        convMatrix(INPUT, C1_core[i], C1[i]);
        plusMatrix(C1[i], C1_bias[i]);
    }
    //S2 pooling: 2*2 sum-pool, scale by coefficient, add bias, tanh -> 6@14*14
    for (int i = 0; i < 6; i++)
    {
        downSampleMatrix(C1[i], 2, 2, S2_rec[i]);
        multiplyMatrix(S2_rec[i], S2_coefficient[i]);
        plusMatrix(S2_rec[i], S2_bias[i]);
        tanh(S2_rec[i], S2[i]);
    }
    //C3 convolution: each of the 16 maps sums convolutions of the S2 maps it
    //is connected to (C3_connect); bias is pre-filled by the reshape -> 16@10*10
    Matrix temp;
    for (int i = 0; i < 16; i++)
    {
        reshapeMatrix(C3[i], 10, 10, C3_bias[i]);
        for (int j = 0; j < 6; j++)
        {
            if (C3_connect[i][j])
            {
                convMatrix(S2[j], C3_core[i][j], temp);
                plusMatrix(C3[i], temp);
            }
        }
    }
    //S4 pooling: same scheme as S2 -> 16@5*5
    for (int i = 0; i < 16; i++)
    {
        downSampleMatrix(C3[i], 2, 2, S4_rec[i]);
        multiplyMatrix(S4_rec[i], S4_coefficient[i]);
        plusMatrix(S4_rec[i], S4_bias[i]);
        tanh(S4_rec[i], S4[i]);
    }
    //C5: a 5*5 kernel over each 5*5 S4 map yields a 1*1 result, summed over
    //all 16 maps plus bias -> a 120*1 vector; note C5 has NO activation
    Matrix t1, t2;
    for (int i = 0; i < 120; i++)
    {
        C5[i][0] = C5_bias[i];
        reshapeMatrix(t1, 1, 1, 0);
        for (int j = 0; j < 16; j++)
        {
            convMatrix(S4[j], C5_core[i][j], t2);
            plusMatrix(t1, t2);
        }
        C5[i][0] += t1[0][0];
    }
    //F6 fully connected: 84*120 weights, tanh activation
    multiplyMatrix(F6_weight, C5, F6_rec);
    plusMatrix(F6_rec, F6_bias);
    tanh(F6_rec, F6);
    //OUTPUT: 10*84 weights, softmax over the 10 logits
    multiplyMatrix(OUTPUT_weight, F6, OUTPUT_rec);
    plusMatrix(OUTPUT_rec, OUTPUT_bias);
    softmax(OUTPUT_rec, OUTPUT);
}
/*--------------前向传播-----------------------*/

/*--------------损失函数-----------------------*/
//Return the digit encoded by point's one-hot label (the row holding the 1;
//the Point constructor stores an exact 1.0, so == is safe here).
inline size_t getNum(const Point &point)
{
    for (int i = 0; i < OUT_SIZE; i++)
        if (point.label[i][0] == 1)
            return i;
    //labels are built by the Point constructor, so reaching this means
    //corrupted state; the old message "@@@" said nothing, and the old
    //`return -1` after the throw was unreachable dead code
    throw "getNum: label has no 1 entry";
}
//Cross-entropy loss for one sample: minus the log of the probability the
//network assigned to the true digit (reads the global OUTPUT).
double Loss(const Point &point)
{
    const double p_true = OUTPUT[getNum(point)][0];
    return -log(p_true);
}
/*--------------损失函数-----------------------*/

/*---------------反向传播----------------------*/
//Backward pass for one sample: fills the sus_* sensitivity buffers.
//Must run immediately after forwardPropagation on the same sample; the
//parameter gradients themselves are accumulated by accumulateDer().
void backPropagation(const Point &point)
{
    //softmax + cross-entropy: the logit delta collapses to (prediction - one-hot label)
    minusMatrix(OUTPUT, point.label, sus_OUTPUT_rec);

    //F6 sensitivity: OUTPUT_weight^T * delta, times tanh' = (1 - F6)(1 + F6)
    Matrix t1, t2, t3;
    transposeMatrix(OUTPUT_weight, t1);
    multiplyMatrix(t1, sus_OUTPUT_rec, t2);
    reshapeMatrix(sus_F6_rec, 84, 1);
    //element-wise product with the activation derivative
    for (int i = 0; i < 84; i++)
        sus_F6_rec[i][0] = t2[i][0] * (1.0 - F6[i][0]) * (1.0 + F6[i][0]);

    //C5 sensitivity: F6_weight^T * delta; no derivative factor because the
    //forward pass applies no activation to C5
    transposeMatrix(F6_weight, t1);
    multiplyMatrix(t1, sus_F6_rec, sus_C5);

    //S4 sensitivity: a "full" convolution of each scalar C5 delta with the
    //rotated 5*5 kernel, realized by placing the delta at the center of a
    //9*9 zero pad so the valid conv yields the 5*5 result
    for (int i = 0; i < 16; i++)
    {
        reshapeMatrix(t1, 9, 9, 0);
        reshapeMatrix(sus_S4_rec[i], 5, 5, 0);
        for (int k = 0; k < 120; k++)
        {
            t1[4][4] = sus_C5[k][0]; //only the center cell differs between k's
            rot180Matrix(C5_core[k][i], t2);
            convMatrix(t1, t2, t3);
            plusMatrix(sus_S4_rec[i], t3);
        }
        //element-wise tanh' of the S4 activation
        for (int p = 0; p < 5; p++)
            for (int q = 0; q < 5; q++)
                sus_S4_rec[i][p][q] *= (1.0 - S4[i][p][q]) * (1.0 + S4[i][p][q]);
    }

    //C3 sensitivity: sum-pool backprop copies each S4 delta into its 2*2
    //source window, then scales by the S4 pooling coefficient
    for (int i = 0; i < 16; i++)
    {
        //upSampleMatrix(sus_S4_rec[i], 2, 2, sus_C3[i]);
        reshapeMatrix(sus_C3[i], 10, 10, 0);
        for (int p = 0; p < 5; p++)
            for (int q = 0; q < 5; q++)
            {
                int rb = p * 2;
                int re = (p + 1) * 2;

                while (rb < re)
                {
                    int cb = (q)*2;
                    int ce = (q + 1) * 2;
                    while (cb < ce)
                    {
                        sus_C3[i][rb][cb] = sus_S4_rec[i][p][q];
                        cb++;
                    }
                    rb++;
                }
            }
        multiplyMatrix(sus_C3[i], S4_coefficient[i]);
    }

    //S2 sensitivity: full convolution (4-cell zero pad -> 18*18) of each
    //connected C3 delta with its rotated kernel, summed over maps, then tanh'
    for (int i = 0; i < 6; i++)
    {
        reshapeMatrix(sus_S2_rec[i], 14, 14, 0);
        for (int k = 0; k < 16; k++)
            if (C3_connect[k][i])
            {
                //zero padding around the 10*10 delta map
                reshapeMatrix(t1, 18, 18, 0);
                for (int p = 0; p < 10; p++)
                    for (int q = 0; q < 10; q++)
                        t1[p + 4][q + 4] = sus_C3[k][p][q];
                rot180Matrix(C3_core[k][i], t2);
                convMatrix(t1, t2, t3);
                plusMatrix(sus_S2_rec[i], t3);
            }
        //element-wise tanh' of the S2 activation
        for (int p = 0; p < 14; p++)
            for (int q = 0; q < 14; q++)
                sus_S2_rec[i][p][q] *= (1.0 - S2[i][p][q]) * (1.0 + S2[i][p][q]);
    }
    //C1 sensitivity: upsample the S2 deltas into 2*2 windows and scale by the
    //S2 coefficient (no tanh' factor — C1 has no activation in the forward pass)
    for (int i = 0; i < 6; i++)
    {
        //upSampleMatrix(sus_S2_rec[i], 2, 2, sus_C1[i]);
        reshapeMatrix(sus_C1[i], 28, 28);
        for (int j = 0; j < 14; j++)
            for (int k = 0; k < 14; k++)
            {
                int rb = j * 2;
                int re = (j + 1) * 2;
                while (rb < re)
                {
                    int cb = k * 2;
                    int ce = (k + 1) * 2;
                    while (cb < ce)
                    {
                        sus_C1[i][rb][cb] = sus_S2_rec[i][j][k];
                        cb++;
                    }
                    rb++;
                }
            }
        multiplyMatrix(sus_C1[i], S2_coefficient[i]);
    }
}
/*---------------反向传播----------------------*/

/*---------------导数清零----------------------*/
//Reset every accumulated gradient (der_*) to zero before a new batch.
inline void clearDer(void)
{
    //C1 kernels/biases and S2 coefficients/biases (6 maps each)
    for (int i = 0; i < 6; i++)
    {
        for (auto &row : der_C1_core[i])
            for (auto &g : row)
                g = 0;
        der_C1_bias[i] = 0;
        der_S2_coefficient[i] = 0;
        der_S2_bias[i] = 0;
    }
    //C3 kernels/biases and S4 coefficients/biases (16 maps each)
    for (int i = 0; i < 16; i++)
    {
        der_C3_bias[i] = 0;
        for (int j = 0; j < 6; j++)
            if (C3_connect[i][j]) //unconnected kernels were never allocated
                for (auto &row : der_C3_core[i][j])
                    for (auto &g : row)
                        g = 0;
        der_S4_coefficient[i] = 0;
        der_S4_bias[i] = 0;
    }
    //C5 kernels and biases
    for (int i = 0; i < 120; i++)
    {
        der_C5_bias[i] = 0;
        for (auto &kernel : der_C5_core[i])
            for (auto &row : kernel)
                for (auto &g : row)
                    g = 0;
    }
    //fully connected layers
    for (auto &row : der_F6_weight)
        for (auto &g : row)
            g = 0;
    for (auto &row : der_F6_bias)
        for (auto &g : row)
            g = 0;
    for (auto &row : der_OUTPUT_weight)
        for (auto &g : row)
            g = 0;
    for (auto &row : der_OUTPUT_bias)
        for (auto &g : row)
            g = 0;
}
/*---------------导数清零----------------------*/

/*---------------导数累计----------------------*/
//Add this sample's parameter gradients — derived from the sus_* sensitivity
//buffers and the forward activations — into the der_* accumulators.
void accumulateDer(void)
{
    Matrix t1, t2, t3;
    //OUTPUT layer: dW = delta * F6^T (outer product), db = delta
    transposeMatrix(F6, t1);
    multiplyMatrix(sus_OUTPUT_rec, t1, t2);
    plusMatrix(der_OUTPUT_weight, t2);
    plusMatrix(der_OUTPUT_bias, sus_OUTPUT_rec);
    //F6 layer: dW = delta * C5^T, db = delta
    transposeMatrix(C5, t1);
    multiplyMatrix(sus_F6_rec, t1, t2);
    plusMatrix(der_F6_weight, t2);
    plusMatrix(der_F6_bias, sus_F6_rec);
    //C5: each 5*5 kernel gradient is the S4 map convolved with the 1*1 delta
    reshapeMatrix(t1, 1, 1, 0);
    for (int i = 0; i < 120; i++)
    {
        der_C5_bias[i] += sus_C5[i][0];
        for (int j = 0; j < 16; j++)
        {
            t1[0][0] = sus_C5[i][0];
            convMatrix(S4[j], t1, t2);
            plusMatrix(der_C5_core[i][j], t2);
        }
    }
    //S4: coefficient gradient = dot(sum-pooled C3, delta) — the 5*5 (x) 5*5
    //conv yields a 1*1 dot product; bias gradient = sum of the deltas
    for (int i = 0; i < 16; i++)
    {
        downSampleMatrix(C3[i], 2, 2, t1);
        convMatrix(t1, sus_S4_rec[i], t2);
        der_S4_coefficient[i] += t2[0][0];
        der_S4_bias[i] += sumMatrix(sus_S4_rec[i]);
    }
    //C3: kernel gradient = connected S2 map convolved with the C3 delta map
    for (int i = 0; i < 16; i++)
    {
        der_C3_bias[i] += sumMatrix(sus_C3[i]);
        for (int j = 0; j < 6; j++)
        {
            if (C3_connect[i][j])
            {
                convMatrix(S2[j], sus_C3[i], t1);
                plusMatrix(der_C3_core[i][j], t1);
            }
        }
    }
    //S2: same scheme as S4
    for (int i = 0; i < 6; i++)
    {
        downSampleMatrix(C1[i], 2, 2, t1);
        convMatrix(t1, sus_S2_rec[i], t2);
        der_S2_coefficient[i] += t2[0][0];
        der_S2_bias[i] += sumMatrix(sus_S2_rec[i]);
    }
    //C1: kernel gradient = the input convolved with the C1 delta map
    for (int i = 0; i < 6; i++)
    {
        der_C1_bias[i] += sumMatrix(sus_C1[i]);
        convMatrix(INPUT, sus_C1[i], t1);
        plusMatrix(der_C1_core[i], t1);
    }
}
/*---------------导数累计----------------------*/

/*---------------梯度检查----------------------*/
//Numerical gradient check on C1_bias[5]: compares the analytic derivative
//against the central difference (Loss(b+EPS) - Loss(b-EPS)) / (2*EPS).
//The last value printed is their ratio, which should be close to 1.
void checkGradient(void)
{
    clearDer();
    forwardPropagation(TrainData[0]);
    backPropagation(TrainData[0]);
    accumulateDer();
    double analytic = der_C1_bias[5];
    cout << analytic << endl;

    double loss_plus;
    C1_bias[5] += EPS;
    forwardPropagation(TrainData[0]);
    cout << (loss_plus = Loss(TrainData[0])) << endl;

    double loss_minus;
    C1_bias[5] -= 2 * EPS;
    forwardPropagation(TrainData[0]);
    cout << (loss_minus = Loss(TrainData[0])) << endl;

    //bug fix: the original never undid the perturbation, leaving C1_bias[5]
    //off by -EPS for any training that followed the check
    C1_bias[5] += EPS;

    double numeric = (loss_plus - loss_minus) / (2 * EPS);
    cout << numeric << endl;
    cout << (analytic / numeric) << endl;
}
/*---------------梯度检查----------------------*/

/*---------------梯度下降----------------------*/
//learning rate
const double step = 0.2;
//gradient-descent iterations per batch
const int max_iter = 20;
//samples per batch
const int batch_size = 20;
//Apply one averaged gradient-descent update: every parameter receives
//param += -(step / batch_size) * accumulated_gradient. The der_* buffers
//are scaled in place and therefore consumed by this call.
void batchGradientDescent(void)
{
    const double scale = -1.0 * step / batch_size;
    //C1 kernels/biases and S2 coefficients/biases (6 maps each)
    for (int i = 0; i < 6; i++)
    {
        multiplyMatrix(der_C1_core[i], scale);
        plusMatrix(C1_core[i], der_C1_core[i]);
        der_C1_bias[i] *= scale;
        C1_bias[i] += der_C1_bias[i];
        der_S2_coefficient[i] *= scale;
        S2_coefficient[i] += der_S2_coefficient[i];
        der_S2_bias[i] *= scale;
        S2_bias[i] += der_S2_bias[i];
    }
    //C3 kernels/biases and S4 coefficients/biases (16 maps each)
    for (int i = 0; i < 16; i++)
    {
        der_C3_bias[i] *= scale;
        C3_bias[i] += der_C3_bias[i];
        for (int j = 0; j < 6; j++)
            if (C3_connect[i][j]) //unconnected kernels were never allocated
            {
                multiplyMatrix(der_C3_core[i][j], scale);
                plusMatrix(C3_core[i][j], der_C3_core[i][j]);
            }
        der_S4_coefficient[i] *= scale;
        S4_coefficient[i] += der_S4_coefficient[i];
        der_S4_bias[i] *= scale;
        S4_bias[i] += der_S4_bias[i];
    }
    //C5 kernels and biases
    for (int i = 0; i < 120; i++)
    {
        der_C5_bias[i] *= scale;
        C5_bias[i] += der_C5_bias[i];
        for (int j = 0; j < 16; j++)
        {
            multiplyMatrix(der_C5_core[i][j], scale);
            plusMatrix(C5_core[i][j], der_C5_core[i][j]);
        }
    }
    //fully connected layers
    multiplyMatrix(der_F6_weight, scale);
    plusMatrix(F6_weight, der_F6_weight);
    multiplyMatrix(der_F6_bias, scale);
    plusMatrix(F6_bias, der_F6_bias);
    multiplyMatrix(der_OUTPUT_weight, scale);
    plusMatrix(OUTPUT_weight, der_OUTPUT_weight);
    multiplyMatrix(der_OUTPUT_bias, scale);
    plusMatrix(OUTPUT_bias, der_OUTPUT_bias);
}
inline double evaluateStudy(void)
{
    int cnt = 0;
    for (int i = 0; i < NUM_TRAIN; i++)
    {
        try
        {
            forwardPropagation(TrainData[i]);
        }
        catch (char const *e)
        {
            cout << "eva" << endl
                 << e << endl;
        }
        int max_pos = 0;
        for (int i = 0; i < 10; i++)
            if (OUTPUT[i] > OUTPUT[max_pos])
                max_pos = i;
        if (getNum(TrainData[i]) == max_pos)
            cnt++;
    }
    return (double)cnt / NUM_TRAIN;
}
inline double evaluateExt(void)
{
    int cnt = 0;
    for (int i = 0; i < NUM_TEST; i++)
    {
        try
        {
            forwardPropagation(TestData[i]);
        }
        catch (char const *e)
        {
            cout << "test" << endl
                 << e << endl;
        }
        int max_pos = 0;
        for (int i = 0; i < 10; i++)
            if (OUTPUT[i] > OUTPUT[max_pos])
                max_pos = i;
        if (getNum(TestData[i]) == max_pos)
            cnt++;
    }
    return (double)cnt / NUM_TEST;
}
//Driver: load and normalize the data, initialize the network, run mini-batch
//gradient descent, then report accuracy and wall-clock time.
int main(void)
{
    clock_t start_time = clock();
    //NOTE(review): rand() is never called — weights come from the <random>
    //engine Rand above, so this srand seed is unused
    srand(time(0));
    readData(TrainData, TestData);
    Normalize(TrainData);
    Normalize(TestData);
    init();
    try
    {
        //one pass over the batches; with NUM_TRAIN == 20 and batch_size == 20
        //there is exactly one batch, trained for up to max_iter iterations
        int k = 0;
        int max_k = NUM_TRAIN / batch_size;
        while (k < max_k)
        {
            cout << "--------------"
                 << "batch: " << k << "--------------" << endl;
            for (int i = 0; i < max_iter; i++)
            {
                clearDer();
                double loss = 0;
                int max_j = (k + 1) * batch_size;
                //accumulate gradients and loss over the whole batch
                for (int j = k * batch_size; j < max_j; j++)
                {
                    forwardPropagation(TrainData[j]);
                    backPropagation(TrainData[j]);
                    accumulateDer();
                    loss += Loss(TrainData[j]);
                }
                cout << "Iter: " << i << " "
                     << "Loss: " << loss << endl;
                //early stop once the summed batch loss is small enough
                if (loss < 0.5)
                    break;
                batchGradientDescent();
            }
            k++;
        }
        clock_t end_time = clock();
        //user-facing strings kept exactly as the original (Chinese labels:
        //training accuracy, generalization accuracy, elapsed time)
        cout << "学习率: " << evaluateStudy() << endl;
        cout << "范化率:" << evaluateExt() << endl;
        cout << "耗时: " << (double)(end_time - start_time) / CLOCKS_PER_SEC << 's' << endl;
    }
    catch (char const *e)
    {
        //every helper throws C-string diagnostics; print and exit
        cout << e << endl;
    }
    return 0;
}

//你可能感兴趣的:(AI,C++,MNIST,CNN,LeNet-5,AI)