LeNet5源码实现和训练参数

通过研究 LeCun 的相关论文和 Fengbichun 博客中提供的源码,我才真正理解了 CNN 的架构,它与我最初的理解并不相同。因此我之前提供的伪代码中有几步是错误的。这里贴出源码,以便继续研究相关结构和图像识别算法。

在下载频道提供了精度为 0.985 的训练参数:http://download.csdn.net/detail/aitazhixin/9862192

.h

#ifndef __SLQ_LE_NET_5_H__
#define __SLQ_LE_NET_5_H__

#include 
#include 
#include 
/**
* LeNet5 Neural Network Define
**/

namespace slqDL {
/** Train Parameters **/
#define EpochLoop           (100)              // Maximum number of training epochs
#define AccuracyRate        (0.98)             // Stop training once test accuracy reaches this value
#define Alpha               (0.01)             // AdaGrad learning rate (step length)
#define LoopError           (0.001)            // NOTE(review): not referenced in the visible code
#define EspCNN              (1e-8)             // Added to the AdaGrad denominator to avoid division by zero

/** Raw Data Parameters **/
#define TrainImgNum          (60000)            // Train Image Number
#define TestImgNum           (10000)            // Test Image Number
#define RawImgRow            (28)               // Raw (unpadded) Image Row
#define RawImgCol            (28)               // Raw (unpadded) Image Col
#define imageRow             (32)               // Padded network input Row
#define imageCol             (32)               // Padded network input Col

/** Convolution or Pooling Parameters **/
#define convhsize            (5)                // Convolution kernel height
#define convwsize            (5)                // Convolution kernel width
#define convsize             (25)               // Convolution kernel size: convhsize * convwsize
#define poolhsize            (2)                // Pooling window height
#define poolwsize            (2)                // Pooling window width

/** Feature Map Parameters **/
/** Input map Parameters **/
#define inMapHigh            (imageRow)         // input map height: equal to imageRow 32
#define inMapWidth           (imageCol)         // input map width: equal to imageCol 32
#define inMapSize            (1024)             // inMapHigh * inMapWidth: 32*32

/** C1 Layer Parameters: 6 maps of 28x28 (32 - 5 + 1) **/
#define c1MapHigh            (28)               // map height
#define c1MapWidth           (28)               // map width
#define f1MapSize            (784)              // cells per map: 28 * 28
#define c1MapNum             (6)                // number of maps
#define c1optNum             (150)              // kernel weights: 6 * 5 * 5
#define c1MapSize            (4704)             // cells in the layer: 6 * 784

/** S2 Layer Parameters: 6 maps of 14x14 (2x2 average pooling of C1) **/
#define s2MapHigh            (14)               // map height
#define s2MapWidth           (14)               // map width
#define f2MapSize            (196)              // cells per map: 14 * 14
#define s2MapNum             (6)                // number of maps
#define s2MapSize            (1176)             // cells in the layer: 6 * 196

/** C3 Layer Parameters: 16 maps of 10x10 (14 - 5 + 1) **/
#define c3MapHigh            (10)               // map height
#define c3MapWidth           (10)               // map width
#define f3MapSize            (100)              // cells per map: 10 * 10
#define c3MapNum             (16)               // number of maps
#define c3optNum             (2400)             // kernel weights: 16 * 6 * 5 * 5
#define c3MapSize            (1600)             // cells in the layer: 16 * 100

/** S4 Layer Parameters: 16 maps of 5x5 (2x2 average pooling of C3) **/
#define s4MapHigh            (5)                // map height
#define s4MapWidth           (5)                // map width
#define f4MapSize            (25)               // cells per map: 5 * 5
#define s4MapNum             (16)               // number of maps
#define s4MapSize            (400)              // cells in the layer: 16 * 25

/** C5 Layer Parameters: 120 maps of 1x1 (5 - 5 + 1) **/
#define c5MapHigh            (1)                // map height
#define c5MapWidth           (1)                // map width
#define c5MapNum             (120)              // number of maps
#define c5optNum             (48000)            // kernel weights: 120 * 16 * 5 * 5
#define c5MapSize            (120)              // cells in the layer: 120 * 1 * 1

/** Output Map Parameters **/
#define outMapSize           (10)               // one output per digit class
#define outoptNum            (1200)             // full-connection weights: 120 * 10

#define TabNum               (10)               // NOTE(review): not referenced in the visible code

/** Activation function (tanh) and its derivative **/
// std::tanh equals (e^x - e^-x) / (e^x + e^-x) mathematically. The explicit
// exp form overflows to inf/inf = NaN once |x| exceeds ~709; that NaN then
// poisons every later layer and the AdaGrad accumulators. std::tanh also
// evaluates its argument only once.
#define ACTIVATION(x)       (std::tanh((x)))
// Derivative of tanh written in terms of its OUTPUT y = tanh(x): 1 - y^2.
// Callers pass the already-activated value, not the pre-activation sum.
#define ACTDEVICE(x)        ((1 - ((x))*((x))))

// Shorthand for the connection table below.
// NOTE(review): single-letter global macros leak into every file including
// this header; consider #undef-ing them after the table.
#define O                   (true)
#define X                   (false)

	/**
	 * LeNet-5 S2 -> C3 connection table: Twomap2ThreeTable[s2 map][c3 map] is true
	 * when that S2 map feeds that C3 map. The convolution code visible in the
	 * .cpp does not consult it (its checks are commented out), so C3 is used
	 * fully connected.
	 */
	static const bool Twomap2ThreeTable[s2MapNum][c3MapNum] =
	{
		O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
		O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
		O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
		X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
		X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
		X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
	};

	//static const bool Fourmap2FiveTabel[s4MapNum][c5MapNum];

/**
 * LeNet-5 style convolutional network:
 *   input 32x32 -> C1 (6@28x28) -> S2 (6@14x14) -> C3 (16@10x10)
 *   -> S4 (16@5x5) -> C5 (120) -> output (10), tanh activations,
 * trained one sample at a time with AdaGrad parameter updates.
 *
 * Typical use: init() (allocate buffers, initialize weights, load the padded
 * MNIST files), then train(). predict() classifies *.png files using the
 * parameters stored in "parameters_block".
 *
 * NOTE(review): the defaulted copy operations copy the owning raw pointers
 * (trainImg, trainLabel, testImg, testLabel, trainData, testData). Two copies
 * of an initialized object would therefore delete the same buffers in their
 * destructors. Copying should be disabled, or the buffers moved to std::vector.
 */
class slqLeNet5
{
public:
    slqLeNet5() = default;
    slqLeNet5(const slqLeNet5 &lenet) = default;
    slqLeNet5 & operator = (const slqLeNet5 &lenet) = default;
    ~slqLeNet5();

    void init();      // allocate buffers, initialize weights, load the dataset files
    void train();     // train until AccuracyRate is reached or EpochLoop epochs ran
    void predict();   // classify .png files with the saved parameters

private:
    void initMap();
    void ForwardC1();
    void ForwardS2();
    void ForwardC3();
    void ForwardS4();
    void ForwardC5();
    void ForwardOut();

    void BackwardOut();
    void BackwardC5();
    void BackwardS4();
    void BackwardC3();
    void BackwardS2();
    void BackwardC1();

    void UpgradeNetwork();
	void UpdateParameters(double *delta, double *Edelta, double *para, int len);

    double test();

    void SaveParameters();
    void ReadParameters();
    void RandomBias(double *randVector, int vLen);
	void ProduceLabel();
	void RegularMap(char *cmap, double *mapdata);

	void uniform_rand(double* src, int len, double min, double max);
	double uniform_rand(double min, double max);


private:
	// Per layer: activations (*map), trainable parameters (*bias, *conv/*pool/
	// *fullconn), current-sample gradients (*_dt), AdaGrad accumulated squared
	// gradients (*_Edt) and back-propagated error terms (*delta).

	double c1map[c1MapSize];
	double c1bias[c1MapNum];
	double c1conv[c1optNum];
	double c1bias_dt[c1MapNum];
	double c1bias_Edt[c1MapNum];
	double c1conv_dt[c1optNum];
	double c1conv_Edt[c1optNum];
	double c1delta[c1MapSize];

	double s2map[s2MapSize];
	double s2bias[s2MapNum];
	double s2pool[s2MapNum];
	double s2bias_dt[s2MapNum];
	double s2bias_Edt[s2MapNum];
	double s2pool_dt[s2MapNum];
	double s2pool_Edt[s2MapNum];
	double s2delta[s2MapSize];

	double c3map[c3MapSize];
	double c3bias[c3MapNum];
	double c3conv[c3optNum];
	double c3bias_dt[c3MapNum];
	double c3bias_Edt[c3MapNum];
	double c3conv_dt[c3optNum];
	double c3conv_Edt[c3optNum];
	double c3delta[c3MapSize];

	double s4map[s4MapSize];
	double s4bias[s4MapNum];
	double s4pool[s4MapNum];
	double s4bias_dt[s4MapNum];
	double s4bias_Edt[s4MapNum];
	double s4pool_dt[s4MapNum];
	double s4pool_Edt[s4MapNum];
	double s4delta[s4MapSize];

	double c5map[c5MapSize];
	double c5bias[c5MapNum];
	double c5conv[c5optNum];
	double c5bias_dt[c5MapNum];
	double c5bias_Edt[c5MapNum];
	double c5conv_dt[c5optNum];
	double c5conv_Edt[c5optNum];
	double c5delta[c5MapSize];

	double outmap[outMapSize];
	double outbias[outMapSize];
	double outfullconn[outoptNum];
	double outbias_dt[outMapSize];
	double outbias_Edt[outMapSize];
	double outfull_dt[outoptNum];
	double outfull_Edt[outoptNum];
	double outdelta[outMapSize];

	double label[outMapSize];    // target vector for the current training sample

	// Owning buffers (allocated with new[] in initMap()). They start as nullptr
	// so the destructor is safe even if init() was never called.
    char *trainImg = nullptr;    // raw training image bytes
    char *trainLabel = nullptr;  // training labels, one byte per image
    char *testImg = nullptr;     // raw test image bytes
    char *testLabel = nullptr;   // test labels, one byte per image

	double *trainData = nullptr; // training images normalized to [-1, 1]
	double *testData = nullptr;  // test images normalized to [-1, 1]

	// Non-owning cursors into the buffers above (or into a local image in predict()).
	double *inmap = nullptr;     // current 32x32 network input
    char *tstLabelPtr = nullptr; // label of the current test image
    char *curLabelPtr = nullptr; // label of the current training image
    bool inited = false;         // init() completed successfully
	bool trainInit = false;      // trainData has been filled from trainImg
	bool testInit = false;       // testData has been filled from testImg

	// Loop indices shared by the member functions, which makes them non-reentrant.
	int mIdx;
	int hIdx;
	int vIdx;
	int chIdx;
	int cvIdx;

	int layermulti;
	int hmulti;

	// Per-epoch timing (Win32 performance counter).
	LARGE_INTEGER StartCount;
	LARGE_INTEGER EndCount;
	LARGE_INTEGER CountFreq;
};
} // end namespace slqDL

#endif




.cpp

#include 
#include 
#include 
#include 
#include 
#include "slqLeNet5.h"

using namespace std;

namespace slqDL {
    /**
     * Releases the dataset buffers allocated by initMap().
     *
     * The buffers were allocated with new[], so they must be released with
     * delete[]; plain delete on an array is undefined behavior. delete[] on
     * nullptr is a no-op, so buffers already released by train()/test() (which
     * reset their pointers to nullptr) are skipped automatically.
     *
     * NOTE(review): the pointers must be nullptr or valid here. That holds only
     * if the class default-initializes them, or init() has run.
     */
    slqLeNet5::~slqLeNet5()
    {
        delete[] trainImg;
        delete[] trainLabel;
        delete[] testImg;
        delete[] testLabel;

        delete[] trainData;
        delete[] testData;
    }

    void slqLeNet5::init()
    {
        ifstream trainimgstream;
        ifstream testimgstream;

        ifstream trainLabelStream;
        ifstream testLabelStream;

        initMap();

        trainimgstream.open("train-images.idx3-ubytepad", ifstream::binary | ifstream::in);
        trainimgstream.read(trainImg, TrainImgNum*imageRow*imageCol);
        trainimgstream.close();

        testimgstream.open("t10k-images.idx3-ubytepad", ifstream::binary | ifstream::in);
        testimgstream.read(testImg, TestImgNum*imageRow*imageCol);
        testimgstream.close();

        trainLabelStream.open("train-labels.idx1-ubytepad", ifstream::binary | ifstream::in);
        trainLabelStream.read(trainLabel, TrainImgNum);
        trainLabelStream.close();

        testLabelStream.open("t10k-labels.idx1-ubytepad", ifstream::binary | ifstream::in);
        testLabelStream.read(testLabel, TestImgNum);
        testLabelStream.close();

		inited = true;

    }

    void slqLeNet5::train()
    {
        if (false == inited)
            return;

        int eidx = 0;
        int lidx = 0;
        double precised = 0;
        double errored = 0;
        double cur_accur = 0;
		double currentErr;

		int Tab = 0;
		int TabCount = 0;
        for (eidx = 0; eidx < EpochLoop; eidx++)
        {

            cout << "epoch " << eidx << endl;
			QueryPerformanceFrequency(&CountFreq);
			QueryPerformanceCounter(&StartCount);
            for (lidx = 0; lidx < TrainImgNum; lidx++)
			{
				curLabelPtr = trainLabel + lidx;

				ProduceLabel();

				if (false == trainInit)
				{
					RegularMap(trainImg + lidx * inMapSize, trainData + lidx * inMapSize);
				}
                inmap = trainData + lidx * inMapSize;

				ForwardC1();
				ForwardS2();
				ForwardC3();
				ForwardS4();
				ForwardC5();
				ForwardOut();

				BackwardOut();
				BackwardC5();
				BackwardS4();
				BackwardC3();
				BackwardS2();
				BackwardC1();

				UpgradeNetwork();
                
            }

			QueryPerformanceCounter(&EndCount);
			cout << " time " << (double)(EndCount.QuadPart - StartCount.QuadPart) / CountFreq.QuadPart << endl;
			if (false == trainInit)
			{
				delete trainImg;
				trainImg = nullptr;
				trainInit = true;
			}



            cur_accur = test();

			cout << "current accuracy " << cur_accur << endl;

            if (cur_accur >= AccuracyRate)
            {
                cout << "current accuracy " << cur_accur << endl;
                SaveParameters();
                return;
            }
            
        }

        if (eidx == EpochLoop)
        {
            cout << "Non-precise current accuracy " << cur_accur << endl;
            SaveParameters();
        }
    }

    void slqLeNet5::initMap()
	{
		srand(time(0) + rand());
		const double scale = 6.0;

		double min_ = -std::sqrt(scale / (25.0 + 150.0));
		double max_ = std::sqrt(scale / (25.0 + 150.0));
		uniform_rand(c1conv, c1optNum, min_, max_);

		min_ = -std::sqrt(scale / (4.0 + 1.0));
		max_ = std::sqrt(scale / (4.0 + 1.0));
		uniform_rand(s2pool, s2MapNum, min_, max_);

		min_ = -std::sqrt(scale / (150.0 + 400.0));
		max_ = std::sqrt(scale / (150.0 + 400.0));
		uniform_rand(c3conv, c3optNum, min_, max_);

		min_ = -std::sqrt(scale / (4.0 + 1.0));
		max_ = std::sqrt(scale / (4.0 + 1.0));
		uniform_rand(s4pool, s4MapNum, min_, max_);

		min_ = -std::sqrt(scale / (400.0 + 3000.0));
		max_ = std::sqrt(scale / (400.0 + 3000.0));
		uniform_rand(c5conv, c5optNum, min_, max_);

		min_ = -std::sqrt(scale / (120.0 + 10.0));
		max_ = std::sqrt(scale / (120.0 + 10.0));
		uniform_rand(outfullconn, outoptNum, min_, max_);


		std::fill(c1bias, c1bias + c1MapNum, 0.0);
		std::fill(s2bias, s2bias + s2MapNum, 0.0);
		std::fill(c3bias, c3bias + c3MapNum, 0.0);
		std::fill(s4bias, s4bias + s4MapNum, 0.0);
		std::fill(c5bias, c5bias + c5MapNum, 0.0);
		std::fill(outbias, outbias + outMapSize, 0.0);

		std::fill(c1bias_Edt, c1bias_Edt + c1MapNum, 0.0);
		std::fill(c1conv_Edt, c1conv_Edt + c1optNum, 0.0);
		std::fill(s2bias_Edt, s2bias_Edt + s2MapNum, 0.0);
		std::fill(s2pool_Edt, s2pool_Edt + s2MapNum, 0.0);
		std::fill(c3bias_Edt, c3bias_Edt + c3MapNum, 0.0);
		std::fill(c3conv_Edt, c3conv_Edt + c3optNum, 0.0);
		std::fill(s4bias_Edt, s4bias_Edt + s4MapNum, 0.0);
		std::fill(s4pool_Edt, s4pool_Edt + s4MapNum, 0.0);
		std::fill(c5bias_Edt, c5bias_Edt + c5MapNum, 0.0);
		std::fill(c5conv_Edt, c5conv_Edt + c5optNum, 0.0);
		std::fill(outbias_Edt, outbias_Edt + outMapSize, 0.0);
		std::fill(outfull_Edt, outfull_Edt + outoptNum, 0.0);

        trainImg = new char[TrainImgNum*inMapSize];
        trainLabel = new char[TrainImgNum];
        testImg = new char[TestImgNum*inMapSize];
        testLabel = new char[TestImgNum];

		trainData = new double[TrainImgNum*inMapSize];
		testData = new double[TestImgNum*inMapSize];
    }

    void slqLeNet5::ForwardC1()
    {

        for (mIdx = 0; mIdx < c1MapNum; mIdx++)
        {
			layermulti = mIdx*f1MapSize;
            for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
            {
				hmulti = hIdx*c1MapWidth;
                for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
                {
                    double *curmap = c1map + layermulti + hmulti + vIdx;
                    *curmap = 0.f;

                    for (chIdx = 0; chIdx < convhsize; chIdx++)
                    {
                        for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
                        {
                            *curmap += *(inmap + (hIdx + chIdx)*imageCol + vIdx + cvIdx) * (*(c1conv + mIdx*convsize + chIdx*convwsize + cvIdx));
                        }
                    }

                    *curmap += c1bias[mIdx];
                    *curmap = ACTIVATION(*curmap);
                }
            }
        }
    }

    void slqLeNet5::ForwardS2()
    {

        for (mIdx = 0; mIdx < s2MapNum; mIdx++)
        {
			layermulti = mIdx * f2MapSize;
            for (hIdx = 0; hIdx < s2MapHigh; hIdx++)
            {
				hmulti = hIdx*s2MapWidth;
                for (vIdx = 0; vIdx < s2MapWidth; vIdx++)
                {
                    double *curmap = s2map + layermulti + hmulti + vIdx;
                    double *c1cur = c1map + layermulti*4 + (hIdx * 2)*c1MapWidth + (vIdx * 2);
                    *curmap = s2bias[mIdx] + s2pool[mIdx] * (*c1cur + *(c1cur + 1) + *(c1cur + c1MapWidth) + *(c1cur + c1MapWidth + 1)) / 4.0;
                    *curmap = ACTIVATION(*curmap);
                }
            }
        }
    }

    void slqLeNet5::ForwardC3()
    {

        for (mIdx = 0; mIdx < c3MapNum; mIdx++)
        {
			layermulti = mIdx*f3MapSize;
            for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
            {
				hmulti = hIdx * c3MapWidth;
                for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
                {
                    double *curmap = c3map + layermulti + hmulti + vIdx;
                    *curmap = 0;

					for (int pIdx = 0; pIdx < s2MapNum; pIdx++)
					{
						//if (!Twomap2ThreeTable[pIdx][mIdx])
						//	continue;

						double *curconv = c3conv + mIdx * s2MapNum * convsize + pIdx * convsize;
						for (chIdx = 0; chIdx < convhsize; chIdx++)
						{
							for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
							{
								*curmap += *(s2map + pIdx * f2MapSize + (hIdx + chIdx)*s2MapWidth + (vIdx + cvIdx)) * (*(curconv + chIdx*convwsize + cvIdx));
							}
						}
					}


                    *curmap += c3bias[mIdx];
                    *curmap = ACTIVATION(*curmap);
                }
            }
        }
    }

    void slqLeNet5::ForwardS4()
    {

        for (mIdx = 0; mIdx < s4MapNum; mIdx++)
        {
            for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
                {
                    double *curmap = s4map + mIdx*f4MapSize + hIdx*s4MapHigh + vIdx;
                    double *curc3 = c3map + mIdx*f3MapSize + (hIdx * 2)*c3MapWidth + (vIdx * 2);

                    *curmap = s4bias[mIdx] + s4pool[mIdx] * (*curc3 + *(curc3+1) + *(curc3+c3MapWidth) + *(curc3+c3MapWidth+1)) / 4.0;

                    *curmap = ACTIVATION(*curmap);
                }
            }
        }
    }

    void slqLeNet5::ForwardC5()
    {

        for (mIdx = 0; mIdx < c5MapNum; mIdx++)
        {
            for (hIdx = 0; hIdx < c5MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < c5MapWidth; vIdx++)
                {
                    double *curmap = c5map + mIdx*c5MapHigh*c5MapWidth + hIdx*c5MapHigh + vIdx;
                    *curmap = 0;

					for (int pIdx = 0; pIdx < s4MapNum; pIdx++)
					{
						double *curconv = c5conv + mIdx * s4MapNum * convsize + pIdx * convsize;
						for (chIdx = 0; chIdx < convhsize; chIdx++)
						{
							for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
							{
								*curmap += *(s4map + pIdx * f4MapSize + (hIdx + chIdx)*s4MapWidth + (vIdx + cvIdx)) * (*(curconv + chIdx*convhsize + cvIdx));
							}
						}
					}


                    *curmap += c5bias[mIdx];

                    *curmap = ACTIVATION(*curmap);
                }
            }
        }
    }

    void slqLeNet5::ForwardOut()
    {

        for (mIdx = 0; mIdx < outMapSize; mIdx++)
        {
            double *curmap = outmap + mIdx;
            *curmap = 0.f;

            for (hIdx = 0; hIdx < c5MapNum; hIdx++)
            {
                *curmap += c5map[hIdx] * outfullconn[hIdx*outMapSize+mIdx];
            }

            *curmap += outbias[mIdx];
            *curmap = ACTIVATION(*curmap);
        }
    }

    /**
     * Output-layer backward pass for a squared-error loss with tanh outputs.
     * Computes delta_o = (y_o - label_o) * (1 - y_o^2), the bias gradient
     * (equal to delta_o), and the full-connection weight gradients
     * c5map[i] * delta_o.
     */
    void slqLeNet5::BackwardOut()
    {
        for (int o = 0; o < outMapSize; ++o)
        {
            const double y = outmap[o];
            outdelta[o] = (y - label[o]) * (ACTDEVICE(y));
            outbias_dt[o] = outdelta[o];
        }

        for (int i = 0; i < c5MapNum; ++i)
        {
            double *gradRow = outfull_dt + i * outMapSize;
            for (int o = 0; o < outMapSize; ++o)
                gradRow[o] = c5map[i] * outdelta[o];
        }
    }

    void slqLeNet5::BackwardC5()
    {

        for (mIdx = 0; mIdx < c5MapNum; mIdx++)
        {
            double curerr = 0;
            for (vIdx = 0; vIdx < outMapSize; vIdx++)
            {
                curerr += outdelta[vIdx] * outfullconn[mIdx*outMapSize + vIdx];
            }

            c5delta[mIdx] = ACTDEVICE(c5map[mIdx]) * curerr;

            c5bias_dt[mIdx] = c5delta[mIdx];

			for (int pIdx = 0; pIdx < s4MapNum; pIdx++)
			{
				double *curconv = c5conv_dt + mIdx * s4MapNum * convsize + pIdx*convsize;
				for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
				{
					for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
					{
						double *convdlt = curconv + hIdx*convwsize + vIdx;
						*convdlt = *(s4map + pIdx * f4MapSize + hIdx * convhsize + vIdx) * c5delta[mIdx];
					}
				}
			}
        }
    }

    void slqLeNet5::BackwardS4()
    {   

        for (mIdx = 0; mIdx < s4MapNum; mIdx++)
        {
			s4bias_dt[mIdx] = 0;
			s4pool_dt[mIdx] = 0;

			layermulti = mIdx*f4MapSize;
            for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
                {
                    double *curdt = s4delta + layermulti + hIdx*s4MapWidth + vIdx;
                    double *curmap = s4map + layermulti + hIdx*s4MapWidth + vIdx;
                    double *c3cur = c3map + mIdx*f3MapSize + (hIdx * 2)*c3MapWidth + (vIdx * 2);

					*curdt = 0;

					for (int pIdx = 0; pIdx < c5MapNum; pIdx++)
					{
						double *curconv = c5conv + pIdx * s4MapNum * convsize + mIdx * convsize + hIdx * convhsize + vIdx;
						*curdt += *curconv * c5delta[pIdx];
					}
                    *curdt = ACTDEVICE(*curmap) * (*curdt);
					s4bias_dt[mIdx] += *curdt;

					s4pool_dt[mIdx] += (*c3cur + *(c3cur + 1) + *(c3cur + c3MapWidth) + *(c3cur + c3MapWidth + 1)) / 4.0 * (*curdt);
                }
            }
        }

    }

    void slqLeNet5::BackwardC3()
    {

        for (mIdx = 0; mIdx < c3MapNum; mIdx++)
        {
			c3bias_dt[mIdx] = 0;
			layermulti = mIdx*f3MapSize;
            for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
                {
                    double *curdt = c3delta + layermulti + hIdx*c3MapWidth + vIdx;
                    double *curmap = c3map + layermulti + hIdx*c3MapWidth + vIdx;
                    double *s4dt = s4delta + mIdx*f4MapSize + (hIdx / 2)*s4MapWidth + (vIdx / 2);
                    *curdt = ACTDEVICE(*curmap)*s4pool[mIdx] * (*s4dt) / 4.0;

					c3bias_dt[mIdx] += *curdt;
                }
            }


			for (int pIdx = 0; pIdx < s2MapNum; pIdx++)
			{
				//if (!Twomap2ThreeTable[pIdx][mIdx])
				//	continue;

				double *curcdt = c3conv_dt + mIdx * s2MapNum * convsize + pIdx * convsize;
				for (chIdx = 0; chIdx < convhsize; chIdx++)
				{
					for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
					{
						double *curconv = curcdt + chIdx * convhsize + cvIdx;
						*curconv = 0;

						for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
						{
							for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
							{
								*curconv += *(s2map + pIdx * f2MapSize + (chIdx + hIdx)*s2MapWidth + (cvIdx + vIdx)) * c3delta[layermulti + hIdx*c3MapWidth + vIdx];
							}
						}
					}
				}
			}

        }
    }

    /**
     * S2 backward pass.
     *  1. Scatters every C3 delta back through its 5x5 kernel onto the S2
     *     window it was computed from. All S2->C3 connections are used, since
     *     the table check is commented out. This accumulates into s2delta,
     *     which is zeroed first.
     *  2. Multiplies each S2 delta by tanh' of the S2 activation. Accumulates
     *     the S2 bias gradient (sum of deltas) and the pooling-weight gradient
     *     (2x2 C1 window mean times delta).
     */
    void slqLeNet5::BackwardS2()
    {

		// s2delta is accumulated with += below, so it must start at zero.
		memset((char*)s2delta, 0x0, s2MapSize*sizeof(double));
        for (mIdx = 0; mIdx < s2MapNum; mIdx++)
        {
            // Loop over C3 output positions; each one touches a 5x5 S2 window.
            for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
                {

					for (int pIdx = 0; pIdx < c3MapNum; pIdx++)
					{
						//if (!Twomap2ThreeTable[mIdx][pIdx])
						//	continue;

						// Kernel connecting S2 map mIdx to C3 map pIdx (layout c3conv[pIdx][mIdx][ch][cv]).
						double *curconv = c3conv + pIdx * s2MapNum * convsize + mIdx * convsize;
						for (chIdx = 0; chIdx < convhsize; chIdx++)
						{
							for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
							{
								double *curdt = s2delta + mIdx * f2MapSize + (hIdx + chIdx)*s2MapWidth + (vIdx + cvIdx);
								*curdt += *(curconv + chIdx * convwsize + cvIdx) * (*(c3delta + pIdx * f3MapSize + hIdx*c3MapWidth + vIdx));
							}
						}
					}
                }
            }
            s2bias_dt[mIdx] = 0;
            s2pool_dt[mIdx] = 0;
            for (hIdx = 0; hIdx < s2MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < s2MapWidth; vIdx++)
                {
					double *curdt = s2delta + mIdx * f2MapSize + hIdx * s2MapWidth + vIdx;
					double *curmap = s2map + mIdx * f2MapSize + hIdx * s2MapWidth + vIdx;
                    // Top-left cell of the 2x2 C1 window this S2 cell pooled.
                    double *curc1 = c1map+mIdx*f1MapSize+(hIdx*2)*c1MapWidth+(vIdx*2);

					// Apply tanh' (in terms of the S2 output) to the accumulated error.
					*curdt *= ACTDEVICE(*curmap);
					s2bias_dt[mIdx] += *curdt;
                    // The operands are doubles, so "/ 4" is floating-point division.
                    s2pool_dt[mIdx] += *curdt * (*curc1 + *(curc1 + 1) + *(curc1 + c1MapWidth) + *(curc1 + c1MapWidth + 1)) / 4;
                }
            }
        }
    }

    void slqLeNet5::BackwardC1()
    {

        for (mIdx = 0; mIdx < c1MapNum; mIdx++)
        {
			c1bias_dt[mIdx] = 0;

			layermulti = mIdx*f1MapSize;
            for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
            {
                for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
                {
                    double *curdt = c1delta + layermulti + hIdx*c1MapWidth + vIdx;
                    double *curmap = c1map + layermulti + hIdx*c1MapWidth + vIdx;

                    *curdt = ACTDEVICE(*curmap) * s2pool[mIdx] * s2delta[layermulti/4 +(hIdx/2)*s2MapWidth+(vIdx/2)] / 4.0;

					c1bias_dt[mIdx] += *curdt;
                }
            }


            for (chIdx = 0; chIdx < convhsize; chIdx++)
            {
                for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
                {
                    double *cdt = c1conv_dt + mIdx*convsize + chIdx*convwsize + cvIdx;
                    *cdt = 0;

                    for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
                    {
                        for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
                        {
                            *cdt += inmap[(hIdx + chIdx) * inMapWidth + (vIdx + cvIdx)] * c1delta[mIdx*c1MapHigh*c1MapWidth + hIdx*c1MapWidth + vIdx];
                        }
                    }
                }
            }
        }
    }

    void slqLeNet5::UpgradeNetwork()
    {
		UpdateParameters(c1bias_dt, c1bias_Edt, c1bias, c1MapNum);
		UpdateParameters(c1conv_dt, c1conv_Edt, c1conv, c1optNum);

		UpdateParameters(s2bias_dt, s2bias_Edt, s2bias, s2MapNum);
		UpdateParameters(s2pool_dt, s2pool_Edt, s2pool, s2MapNum);

		UpdateParameters(c3bias_dt, c3bias_Edt, c3bias, c3MapNum);
		UpdateParameters(c3conv_dt, c3conv_Edt, c3conv, c3optNum);

		UpdateParameters(s4bias_dt, s4bias_Edt, s2bias, s4MapNum);
		UpdateParameters(s4pool_dt, s4pool_Edt, s2pool, s4MapNum);

		UpdateParameters(c5bias_dt, c5bias_Edt, c5bias, c5MapNum);
		UpdateParameters(c5conv_dt, c5conv_Edt, c5conv, c5optNum);

		UpdateParameters(outbias_dt, outbias_Edt, outbias, outMapSize);
		UpdateParameters(outfull_dt, outfull_Edt, outfullconn, outoptNum);

    }

	/**
	 * AdaGrad step over `len` parameters.
	 * Edelta accumulates the squared gradients. Each parameter moves against
	 * its gradient by Alpha / (sqrt(accumulated) + EspCNN).
	 *
	 * @param delta  current gradients
	 * @param Edelta running sums of squared gradients (updated in place)
	 * @param para   parameters (updated in place)
	 * @param len    number of elements in each array
	 */
	void slqLeNet5::UpdateParameters(double *delta, double *Edelta, double *para, int len)
	{
		for (int i = 0; i < len; ++i)
		{
			const double grad = delta[i];
			Edelta[i] += grad * grad;
			const double scale = std::sqrt(Edelta[i]) + EspCNN;
			para[i] -= Alpha * grad / scale;
		}
	}

    /**
     * Classifies every test image and returns the fraction predicted correctly.
     *
     * On the first call the raw test bytes are normalized into testData and
     * then released. The prediction is the arg-max of the outputs; on ties the
     * last maximum wins.
     */
    double slqLeNet5::test()
    {
        double correct = 0.0;

        for (int tIdx = 0; tIdx < TestImgNum; tIdx++)
        {
            tstLabelPtr = testLabel + tIdx;

            if (!testInit)
                RegularMap(testImg + tIdx*inMapSize, testData + tIdx*inMapSize);
            inmap = testData + tIdx*inMapSize;

            ForwardC1();
            ForwardS2();
            ForwardC3();
            ForwardS4();
            ForwardC5();
            ForwardOut();

            int predicted = 0;
            for (int oIdx = 1; oIdx < outMapSize; oIdx++)
            {
                if (outmap[oIdx] >= outmap[predicted])
                    predicted = oIdx;
            }

            if (predicted == static_cast<int>(*tstLabelPtr))
                correct += 1.0;
        }

        // The raw bytes are no longer needed once cached; the buffer came from new[].
        if (!testInit)
        {
            delete[] testImg;
            testImg = nullptr;
            testInit = true;
        }

        return correct / TestImgNum;
    }

    /**
     * Despite its name, sets all vLen elements of randVector to zero.
     * Does nothing for vLen <= 0.
     * NOTE(review): not called anywhere in the visible code.
     */
    void slqLeNet5::RandomBias(double *randVector, int vLen)
    {
        int remaining = vLen;
        while (remaining-- > 0)
            randVector[remaining] = 0;
    }


    void slqLeNet5::SaveParameters()
    {
        fstream fileStream;

        fileStream.open("parameters_block", ifstream::out | ifstream::binary);

        fileStream << c1conv << c1bias << s2pool << s2bias << \
            c3conv << c3bias << s4pool << s4bias << c5conv << c5bias << outfullconn << outbias;

        fileStream.close();
    }

	/**
	 * Builds the training target for the current sample (*curLabelPtr).
	 * The true class gets 0.8 and every other class gets -0.8. These values
	 * lie inside the tanh output range.
	 */
	void slqLeNet5::ProduceLabel()
	{
		const int truth = *curLabelPtr;
		for (int cls = 0; cls < outMapSize; ++cls)
			label[cls] = (cls == truth) ? 0.8f : -0.8f;
	}

	/**
	 * Converts one imageRow x imageCol image of raw bytes into doubles.
	 * Each byte is read as unsigned (0..255) and mapped linearly to [-1, 1].
	 *
	 * @param cmap raw image bytes (imageRow * imageCol of them)
	 * @param rmap output buffer of the same element count
	 */
	void slqLeNet5::RegularMap(char *cmap, double *rmap)
	{
		const int pixelCount = imageRow * imageCol;
		for (int i = 0; i < pixelCount; ++i)
		{
			const int pixel = static_cast<unsigned char>(cmap[i]);
			rmap[i] = pixel / 255.0 * 2 - 1;
		}
	}

	/**
	 * Fills src[0 .. len) with values from U(min, max), drawn one at a time
	 * from the scalar uniform_rand overload.
	 */
	void slqLeNet5::uniform_rand(double* src, int len, double min, double max)
	{
		int filled = 0;
		while (filled < len)
		{
			src[filled] = uniform_rand(min, max);
			++filled;
		}
	}

	/**
	 * Returns one sample from U(min, max).
	 *
	 * Uses one Mersenne Twister engine per thread, seeded once from
	 * std::random_device. Re-creating std::random_device and std::mt19937 on
	 * every call has two costs. It re-initializes the whole engine state for
	 * every single weight. It also makes every call return the same value on
	 * implementations whose random_device is deterministic.
	 */
	double slqLeNet5::uniform_rand(double min, double max)
	{
		static thread_local std::mt19937 gen{ std::random_device{}() };
		std::uniform_real_distribution<double> dst(min, max);
		return dst(gen);
	}

}

这里参数的随机初始化函数采用了 fengbichun 的方法,因为我之前用 rand 初始化参数的效果不理想,算法无法收敛。

经过大量数据训练后,这个 CNN 结构能够识别 0~9 这十个数字,这让我很惊奇。


上面将参数保存到文件的函数写错了(对 double 数组使用 << 输出的是数组的地址,而不是参数数据),修改如下:

    /**
     * Writes all trainable parameters to "parameters_block" as raw binary
     * doubles (native byte order, no header). The chunk order below is the
     * on-disk format; ReadParameters() must read in the same order.
     */
    void slqLeNet5::SaveParameters()
    {
        struct Chunk
        {
            double *data;
            int count;
        };
        const Chunk chunks[] = {
            { c1conv, c1optNum },
            { c1bias, c1MapNum },
            { s2pool, s2MapNum },
            { s2bias, s2MapNum },
            { c3conv, c3optNum },
            { c3bias, c3MapNum },
            { s4pool, s4MapNum },
            { s4bias, s4MapNum },
            { c5conv, c5optNum },
            { c5bias, c5MapNum },
            { outfullconn, outoptNum },
            { outbias, outMapSize },
        };

        ofstream fileStream;
        fileStream.open("parameters_block", ofstream::out | ofstream::binary);

        for (const Chunk &chunk : chunks)
            fileStream.write((char*)chunk.data, sizeof(double) * chunk.count);

        fileStream.close();
    }

	void slqLeNet5::ReadParameters()
	{
		ifstream fileStream;

		fileStream.open("parameters_block", ifstream::in | ifstream::binary);

		fileStream.read((char*)c1conv, sizeof(double)*c1optNum);
		fileStream.read((char*)c1bias, sizeof(double)*c1MapNum);
		fileStream.read((char*)s2pool, sizeof(double)*s2MapNum);
		fileStream.read((char*)s2bias, sizeof(double)*s2MapNum);
		fileStream.read((char*)c3conv, sizeof(double)*c3optNum);
		fileStream.read((char*)c3bias, sizeof(double)*c3MapNum);
		fileStream.read((char*)s4pool, sizeof(double)*s4MapNum);
		fileStream.read((char*)s4bias, sizeof(double)*s4MapNum);
		fileStream.read((char*)c5conv, sizeof(double)*c5optNum);
		fileStream.read((char*)c5bias, sizeof(double)*c5MapNum);
		fileStream.read((char*)outfullconn, sizeof(double)*outoptNum);
		fileStream.read((char*)outbias, sizeof(double)*outMapSize);

		fileStream.close();
	}



预测

	void slqLeNet5::predict()
	{
		HANDLE hfile;
		LPCTSTR lpFileName = L".\\*.png";
		WIN32_FIND_DATA pNextInfo;

		double imgArray[inMapSize];
		double *imgAPtr;
		char *imgBPtr;
		char imgName[20];
		char ch;

		double maxV = -10;
		int maxIdx = -1;

		
		ReadParameters();

		while (true)
		{

			hfile = FindFirstFile(lpFileName, &pNextInfo);
			if (INVALID_HANDLE_VALUE == hfile)
			{
				cout << "Error Handle to get First File" << endl;
				break;
			}

			while (FindNextFile(hfile, &pNextInfo))
			{
				cv::Mat image;
				for (hIdx = 0; hIdx < (wcslen(pNextInfo.cFileName) + 2); hIdx++)
				{
					imgName[hIdx] = pNextInfo.cFileName[hIdx];
				}
				imgName[hIdx] = '\0';
				cout << "image " << imgName << endl;
				image = cv::imread(imgName, CV_8UC1);
				if (image.empty())
				{
					cout << "Error image" << endl;
					return;
				}

				std::fill(imgArray, imgArray + inMapSize, -1);

				imgAPtr = imgArray;
				imgBPtr = (char*)image.data;

				double topV = -256;
				double botV = 256;

				for (hIdx = 0; hIdx < RawImgRow; hIdx++)
				{
					for (vIdx = 0; vIdx < RawImgCol; vIdx++)
					{
						int tmp = (int)(*(imgBPtr + hIdx*RawImgCol + vIdx));

						tmp = tmp < 0 ? (tmp + 256) : tmp;
						topV = topV > tmp ? topV : tmp;
						botV = botV < tmp ? botV : tmp;
					}
				}


				for (hIdx = 0; hIdx < RawImgRow; hIdx++)
				{
					for (vIdx = 0; vIdx < RawImgCol; vIdx++)
					{
						int tmp = (int)(*(imgBPtr + hIdx*RawImgCol + vIdx));

						tmp = tmp < 0 ? (tmp + 256) : tmp;
						*(imgAPtr + (hIdx + 2)*imageCol + (vIdx + 2)) = (tmp - botV) / (topV - botV) * 2.0 - 1.0;
					}
				}

				inmap = imgArray;
				ForwardC1();
				ForwardS2();
				ForwardC3();
				ForwardS4();
				ForwardC5();
				ForwardOut();

				maxV = -10;
				maxIdx = -1;
				for (hIdx = 0; hIdx < outMapSize; hIdx++)
				{
					maxV = maxV > outmap[hIdx] ? maxV : outmap[hIdx];
					maxIdx = maxV > outmap[hIdx] ? maxIdx : hIdx;
				}

				cout << "predict " << maxIdx << endl;

				ch = getchar();

				if ('!' == ch)
					return;

			}

		}

	}



你可能感兴趣的:(算法,AI)