通过研究LeCun相关论文和Fengbichun的博客提供的源码,才真正的了解了CNN的架构,与我最初的理解是不同的。所以我之前提供的伪代码,有几步都是错误的,这里贴出源码,以继续研究相关结构和图像识别算法。
在下载频道提供了精度在0.985的训练参数:http://download.csdn.net/detail/aitazhixin/9862192点击打开链接
.h
#ifndef __SLQ_LE_NET_5_H__
#define __SLQ_LE_NET_5_H__
#include
#include
#include
/**
* LeNet5 Neural Network Define
**/
namespace slqDL {
/** Train Parameters **/
#define EpochLoop (100) // Train Loops
#define AccuracyRate (0.98) // Train Accuracy
#define Alpha (0.01) // Train Step Length
#define LoopError (0.001) //
#define EspCNN (1e-8) //
/** Raw Data Parameters **/
#define TrainImgNum (60000) // Train Image Number
#define TestImgNum (10000) // Test Image Number
#define RawImgRow (28) // Raw Image Row
#define RawImgCol (28) // Raw Image Col
#define imageRow (32) // Image Row
#define imageCol (32) // Image Col
/** Convolution or Pooling Parameters **/
#define convhsize (5) // Convolution factor high size
#define convwsize (5) // Convolution factor width size
#define convsize (25) // Convolution factor size
#define poolhsize (2) // Pooling factor high size
#define poolwsize (2) // Pooling factor width size
/** Feature Map Parameters **/
/** Input map Parameters **/
#define inMapHigh (imageRow) // input map high: equal to imageRow 32
#define inMapWidth (imageCol) // input map Width: equal to imageCol 32
#define inMapSize (1024) // inMapHigh * inMapWidth: 32*32
/** C1 Layer Parameters **/
#define c1MapHigh (28) //
#define c1MapWidth (28) //
#define f1MapSize (784) //
#define c1MapNum (6) //
#define c1optNum (150) // 6 * 5 * 5
#define c1MapSize (4704) //
/** S2 Layer Parameters **/
#define s2MapHigh (14) //
#define s2MapWidth (14) //
#define f2MapSize (196) //
#define s2MapNum (6) //
#define s2MapSize (1176) //
/** C3 Layer Parameters **/
#define c3MapHigh (10) //
#define c3MapWidth (10) //
#define f3MapSize (100) //
#define c3MapNum (16) //
#define c3optNum (2400) // 16 * 6 * 5 * 5
#define c3MapSize (1600) //
/** S4 Layer Parameters **/
#define s4MapHigh (5) //
#define s4MapWidth (5) //
#define f4MapSize (25) //
#define s4MapNum (16) //
#define s4MapSize (400) //
/** C5 Layer Parameters **/
#define c5MapHigh (1) //
#define c5MapWidth (1) //
#define c5MapNum (120) //
#define c5optNum (48000) // 120 * 16 * 5 * 5
#define c5MapSize (120) //
/** Output Map Parameters **/
#define outMapSize (10) //
#define outoptNum (1200) //
#define TabNum (10) //
#define ACTIVATION(x) ((std::exp((x)) - std::exp(-1*(x))) / (std::exp((x)) + std::exp(-1*(x))))
#define ACTDEVICE(x) ((1 - ((x))*((x))))
#define O (true)
#define X (false)
static const bool Twomap2ThreeTable[s2MapNum][c3MapNum] =
{
O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
};
//static const bool Fourmap2FiveTabel[s4MapNum][c5MapNum];
class slqLeNet5
{
public:
slqLeNet5() = default;
slqLeNet5(const slqLeNet5 &lenet) = default;
slqLeNet5 & operator = (const slqLeNet5 &lenet) = default;
~slqLeNet5();
void init();
void train();
private:
void initMap();
void ForwardC1();
void ForwardS2();
void ForwardC3();
void ForwardS4();
void ForwardC5();
void ForwardOut();
void BackwardOut();
void BackwardC5();
void BackwardS4();
void BackwardC3();
void BackwardS2();
void BackwardC1();
void UpgradeNetwork();
void UpdateParameters(double *delta, double *Edelta, double *para, int len);
double test();
void SaveParameters();
void RandomBias(double *randVector, int vLen);
void ProduceLabel();
void RegularMap(char *cmap, double *mapdata);
void uniform_rand(double* src, int len, double min, double max);
double uniform_rand(double min, double max);
private:
double c1map[c1MapSize];
double c1bias[c1MapNum];
double c1conv[c1optNum];
double c1bias_dt[c1MapNum];
double c1bias_Edt[c1MapNum];
double c1conv_dt[c1optNum];
double c1conv_Edt[c1optNum];
double c1delta[c1MapSize];
double s2map[s2MapSize];
double s2bias[s2MapNum];
double s2pool[s2MapNum];
double s2bias_dt[s2MapNum];
double s2bias_Edt[s2MapNum];
double s2pool_dt[s2MapNum];
double s2pool_Edt[s2MapNum];
double s2delta[s2MapSize];
double c3map[c3MapSize];
double c3bias[c3MapNum];
double c3conv[c3optNum];
double c3bias_dt[c3MapNum];
double c3bias_Edt[c3MapNum];
double c3conv_dt[c3optNum];
double c3conv_Edt[c3optNum];
double c3delta[c3MapSize];
double s4map[s4MapSize];
double s4bias[s4MapNum];
double s4pool[s4MapNum];
double s4bias_dt[s4MapNum];
double s4bias_Edt[s4MapNum];
double s4pool_dt[s4MapNum];
double s4pool_Edt[s4MapNum];
double s4delta[s4MapSize];
double c5map[c5MapSize];
double c5bias[c5MapNum];
double c5conv[c5optNum];
double c5bias_dt[c5MapNum];
double c5bias_Edt[c5MapNum];
double c5conv_dt[c5optNum];
double c5conv_Edt[c5optNum];
double c5delta[c5MapSize];
double outmap[outMapSize];
double outbias[outMapSize];
double outfullconn[outoptNum];
double outbias_dt[outMapSize];
double outbias_Edt[outMapSize];
double outfull_dt[outoptNum];
double outfull_Edt[outoptNum];
double outdelta[outMapSize];
double label[outMapSize];
char *trainImg;
char *trainLabel;
char *testImg;
char *testLabel;
double *trainData;
double *testData;
double *inmap;
char *tstLabelPtr;
char *curLabelPtr;
bool inited = false;
bool trainInit = false;
bool testInit = false;
int mIdx;
int hIdx;
int vIdx;
int chIdx;
int cvIdx;
int layermulti;
int hmulti;
LARGE_INTEGER StartCount;
LARGE_INTEGER EndCount;
LARGE_INTEGER CountFreq;
};
} // end namespace slqDL
#endif
.cpp
#include
#include
#include
#include
#include
#include "slqLeNet5.h"
using namespace std;
namespace slqDL {
slqLeNet5::~slqLeNet5()
{
if (trainImg)
{
delete trainImg;
trainImg = nullptr;
}
if (trainLabel)
{
delete trainLabel;
trainLabel = nullptr;
}
if (testImg)
{
delete testImg;
testImg = nullptr;
}
if (testLabel)
{
delete testLabel;
testLabel = nullptr;
}
if (trainData)
{
delete trainData;
trainData = nullptr;
}
if (testData)
{
delete testData;
testData = nullptr;
}
}
void slqLeNet5::init()
{
ifstream trainimgstream;
ifstream testimgstream;
ifstream trainLabelStream;
ifstream testLabelStream;
initMap();
trainimgstream.open("train-images.idx3-ubytepad", ifstream::binary | ifstream::in);
trainimgstream.read(trainImg, TrainImgNum*imageRow*imageCol);
trainimgstream.close();
testimgstream.open("t10k-images.idx3-ubytepad", ifstream::binary | ifstream::in);
testimgstream.read(testImg, TestImgNum*imageRow*imageCol);
testimgstream.close();
trainLabelStream.open("train-labels.idx1-ubytepad", ifstream::binary | ifstream::in);
trainLabelStream.read(trainLabel, TrainImgNum);
trainLabelStream.close();
testLabelStream.open("t10k-labels.idx1-ubytepad", ifstream::binary | ifstream::in);
testLabelStream.read(testLabel, TestImgNum);
testLabelStream.close();
inited = true;
}
void slqLeNet5::train()
{
if (false == inited)
return;
int eidx = 0;
int lidx = 0;
double precised = 0;
double errored = 0;
double cur_accur = 0;
double currentErr;
int Tab = 0;
int TabCount = 0;
for (eidx = 0; eidx < EpochLoop; eidx++)
{
cout << "epoch " << eidx << endl;
QueryPerformanceFrequency(&CountFreq);
QueryPerformanceCounter(&StartCount);
for (lidx = 0; lidx < TrainImgNum; lidx++)
{
curLabelPtr = trainLabel + lidx;
ProduceLabel();
if (false == trainInit)
{
RegularMap(trainImg + lidx * inMapSize, trainData + lidx * inMapSize);
}
inmap = trainData + lidx * inMapSize;
ForwardC1();
ForwardS2();
ForwardC3();
ForwardS4();
ForwardC5();
ForwardOut();
BackwardOut();
BackwardC5();
BackwardS4();
BackwardC3();
BackwardS2();
BackwardC1();
UpgradeNetwork();
}
QueryPerformanceCounter(&EndCount);
cout << " time " << (double)(EndCount.QuadPart - StartCount.QuadPart) / CountFreq.QuadPart << endl;
if (false == trainInit)
{
delete trainImg;
trainImg = nullptr;
trainInit = true;
}
cur_accur = test();
cout << "current accuracy " << cur_accur << endl;
if (cur_accur >= AccuracyRate)
{
cout << "current accuracy " << cur_accur << endl;
SaveParameters();
return;
}
}
if (eidx == EpochLoop)
{
cout << "Non-precise current accuracy " << cur_accur << endl;
SaveParameters();
}
}
void slqLeNet5::initMap()
{
srand(time(0) + rand());
const double scale = 6.0;
double min_ = -std::sqrt(scale / (25.0 + 150.0));
double max_ = std::sqrt(scale / (25.0 + 150.0));
uniform_rand(c1conv, c1optNum, min_, max_);
min_ = -std::sqrt(scale / (4.0 + 1.0));
max_ = std::sqrt(scale / (4.0 + 1.0));
uniform_rand(s2pool, s2MapNum, min_, max_);
min_ = -std::sqrt(scale / (150.0 + 400.0));
max_ = std::sqrt(scale / (150.0 + 400.0));
uniform_rand(c3conv, c3optNum, min_, max_);
min_ = -std::sqrt(scale / (4.0 + 1.0));
max_ = std::sqrt(scale / (4.0 + 1.0));
uniform_rand(s4pool, s4MapNum, min_, max_);
min_ = -std::sqrt(scale / (400.0 + 3000.0));
max_ = std::sqrt(scale / (400.0 + 3000.0));
uniform_rand(c5conv, c5optNum, min_, max_);
min_ = -std::sqrt(scale / (120.0 + 10.0));
max_ = std::sqrt(scale / (120.0 + 10.0));
uniform_rand(outfullconn, outoptNum, min_, max_);
std::fill(c1bias, c1bias + c1MapNum, 0.0);
std::fill(s2bias, s2bias + s2MapNum, 0.0);
std::fill(c3bias, c3bias + c3MapNum, 0.0);
std::fill(s4bias, s4bias + s4MapNum, 0.0);
std::fill(c5bias, c5bias + c5MapNum, 0.0);
std::fill(outbias, outbias + outMapSize, 0.0);
std::fill(c1bias_Edt, c1bias_Edt + c1MapNum, 0.0);
std::fill(c1conv_Edt, c1conv_Edt + c1optNum, 0.0);
std::fill(s2bias_Edt, s2bias_Edt + s2MapNum, 0.0);
std::fill(s2pool_Edt, s2pool_Edt + s2MapNum, 0.0);
std::fill(c3bias_Edt, c3bias_Edt + c3MapNum, 0.0);
std::fill(c3conv_Edt, c3conv_Edt + c3optNum, 0.0);
std::fill(s4bias_Edt, s4bias_Edt + s4MapNum, 0.0);
std::fill(s4pool_Edt, s4pool_Edt + s4MapNum, 0.0);
std::fill(c5bias_Edt, c5bias_Edt + c5MapNum, 0.0);
std::fill(c5conv_Edt, c5conv_Edt + c5optNum, 0.0);
std::fill(outbias_Edt, outbias_Edt + outMapSize, 0.0);
std::fill(outfull_Edt, outfull_Edt + outoptNum, 0.0);
trainImg = new char[TrainImgNum*inMapSize];
trainLabel = new char[TrainImgNum];
testImg = new char[TestImgNum*inMapSize];
testLabel = new char[TestImgNum];
trainData = new double[TrainImgNum*inMapSize];
testData = new double[TestImgNum*inMapSize];
}
void slqLeNet5::ForwardC1()
{
for (mIdx = 0; mIdx < c1MapNum; mIdx++)
{
layermulti = mIdx*f1MapSize;
for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
{
hmulti = hIdx*c1MapWidth;
for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
{
double *curmap = c1map + layermulti + hmulti + vIdx;
*curmap = 0.f;
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
*curmap += *(inmap + (hIdx + chIdx)*imageCol + vIdx + cvIdx) * (*(c1conv + mIdx*convsize + chIdx*convwsize + cvIdx));
}
}
*curmap += c1bias[mIdx];
*curmap = ACTIVATION(*curmap);
}
}
}
}
void slqLeNet5::ForwardS2()
{
for (mIdx = 0; mIdx < s2MapNum; mIdx++)
{
layermulti = mIdx * f2MapSize;
for (hIdx = 0; hIdx < s2MapHigh; hIdx++)
{
hmulti = hIdx*s2MapWidth;
for (vIdx = 0; vIdx < s2MapWidth; vIdx++)
{
double *curmap = s2map + layermulti + hmulti + vIdx;
double *c1cur = c1map + layermulti*4 + (hIdx * 2)*c1MapWidth + (vIdx * 2);
*curmap = s2bias[mIdx] + s2pool[mIdx] * (*c1cur + *(c1cur + 1) + *(c1cur + c1MapWidth) + *(c1cur + c1MapWidth + 1)) / 4.0;
*curmap = ACTIVATION(*curmap);
}
}
}
}
void slqLeNet5::ForwardC3()
{
for (mIdx = 0; mIdx < c3MapNum; mIdx++)
{
layermulti = mIdx*f3MapSize;
for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
{
hmulti = hIdx * c3MapWidth;
for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
{
double *curmap = c3map + layermulti + hmulti + vIdx;
*curmap = 0;
for (int pIdx = 0; pIdx < s2MapNum; pIdx++)
{
//if (!Twomap2ThreeTable[pIdx][mIdx])
// continue;
double *curconv = c3conv + mIdx * s2MapNum * convsize + pIdx * convsize;
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
*curmap += *(s2map + pIdx * f2MapSize + (hIdx + chIdx)*s2MapWidth + (vIdx + cvIdx)) * (*(curconv + chIdx*convwsize + cvIdx));
}
}
}
*curmap += c3bias[mIdx];
*curmap = ACTIVATION(*curmap);
}
}
}
}
void slqLeNet5::ForwardS4()
{
for (mIdx = 0; mIdx < s4MapNum; mIdx++)
{
for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
{
double *curmap = s4map + mIdx*f4MapSize + hIdx*s4MapHigh + vIdx;
double *curc3 = c3map + mIdx*f3MapSize + (hIdx * 2)*c3MapWidth + (vIdx * 2);
*curmap = s4bias[mIdx] + s4pool[mIdx] * (*curc3 + *(curc3+1) + *(curc3+c3MapWidth) + *(curc3+c3MapWidth+1)) / 4.0;
*curmap = ACTIVATION(*curmap);
}
}
}
}
void slqLeNet5::ForwardC5()
{
for (mIdx = 0; mIdx < c5MapNum; mIdx++)
{
for (hIdx = 0; hIdx < c5MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c5MapWidth; vIdx++)
{
double *curmap = c5map + mIdx*c5MapHigh*c5MapWidth + hIdx*c5MapHigh + vIdx;
*curmap = 0;
for (int pIdx = 0; pIdx < s4MapNum; pIdx++)
{
double *curconv = c5conv + mIdx * s4MapNum * convsize + pIdx * convsize;
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
*curmap += *(s4map + pIdx * f4MapSize + (hIdx + chIdx)*s4MapWidth + (vIdx + cvIdx)) * (*(curconv + chIdx*convhsize + cvIdx));
}
}
}
*curmap += c5bias[mIdx];
*curmap = ACTIVATION(*curmap);
}
}
}
}
void slqLeNet5::ForwardOut()
{
for (mIdx = 0; mIdx < outMapSize; mIdx++)
{
double *curmap = outmap + mIdx;
*curmap = 0.f;
for (hIdx = 0; hIdx < c5MapNum; hIdx++)
{
*curmap += c5map[hIdx] * outfullconn[hIdx*outMapSize+mIdx];
}
*curmap += outbias[mIdx];
*curmap = ACTIVATION(*curmap);
}
}
void slqLeNet5::BackwardOut()
{
for (mIdx = 0; mIdx < outMapSize; mIdx++)
{
outdelta[mIdx] = (outmap[mIdx] - label[mIdx]) * (ACTDEVICE(outmap[mIdx]));
outbias_dt[mIdx] = outdelta[mIdx];
}
for (hIdx = 0; hIdx < c5MapNum; hIdx++)
{
for (vIdx = 0; vIdx < outMapSize; vIdx++)
{
outfull_dt[hIdx*outMapSize + vIdx] = c5map[hIdx] * outdelta[vIdx];
}
}
}
void slqLeNet5::BackwardC5()
{
for (mIdx = 0; mIdx < c5MapNum; mIdx++)
{
double curerr = 0;
for (vIdx = 0; vIdx < outMapSize; vIdx++)
{
curerr += outdelta[vIdx] * outfullconn[mIdx*outMapSize + vIdx];
}
c5delta[mIdx] = ACTDEVICE(c5map[mIdx]) * curerr;
c5bias_dt[mIdx] = c5delta[mIdx];
for (int pIdx = 0; pIdx < s4MapNum; pIdx++)
{
double *curconv = c5conv_dt + mIdx * s4MapNum * convsize + pIdx*convsize;
for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
{
double *convdlt = curconv + hIdx*convwsize + vIdx;
*convdlt = *(s4map + pIdx * f4MapSize + hIdx * convhsize + vIdx) * c5delta[mIdx];
}
}
}
}
}
void slqLeNet5::BackwardS4()
{
for (mIdx = 0; mIdx < s4MapNum; mIdx++)
{
s4bias_dt[mIdx] = 0;
s4pool_dt[mIdx] = 0;
layermulti = mIdx*f4MapSize;
for (hIdx = 0; hIdx < s4MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < s4MapWidth; vIdx++)
{
double *curdt = s4delta + layermulti + hIdx*s4MapWidth + vIdx;
double *curmap = s4map + layermulti + hIdx*s4MapWidth + vIdx;
double *c3cur = c3map + mIdx*f3MapSize + (hIdx * 2)*c3MapWidth + (vIdx * 2);
*curdt = 0;
for (int pIdx = 0; pIdx < c5MapNum; pIdx++)
{
double *curconv = c5conv + pIdx * s4MapNum * convsize + mIdx * convsize + hIdx * convhsize + vIdx;
*curdt += *curconv * c5delta[pIdx];
}
*curdt = ACTDEVICE(*curmap) * (*curdt);
s4bias_dt[mIdx] += *curdt;
s4pool_dt[mIdx] += (*c3cur + *(c3cur + 1) + *(c3cur + c3MapWidth) + *(c3cur + c3MapWidth + 1)) / 4.0 * (*curdt);
}
}
}
}
void slqLeNet5::BackwardC3()
{
for (mIdx = 0; mIdx < c3MapNum; mIdx++)
{
c3bias_dt[mIdx] = 0;
layermulti = mIdx*f3MapSize;
for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
{
double *curdt = c3delta + layermulti + hIdx*c3MapWidth + vIdx;
double *curmap = c3map + layermulti + hIdx*c3MapWidth + vIdx;
double *s4dt = s4delta + mIdx*f4MapSize + (hIdx / 2)*s4MapWidth + (vIdx / 2);
*curdt = ACTDEVICE(*curmap)*s4pool[mIdx] * (*s4dt) / 4.0;
c3bias_dt[mIdx] += *curdt;
}
}
for (int pIdx = 0; pIdx < s2MapNum; pIdx++)
{
//if (!Twomap2ThreeTable[pIdx][mIdx])
// continue;
double *curcdt = c3conv_dt + mIdx * s2MapNum * convsize + pIdx * convsize;
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
double *curconv = curcdt + chIdx * convhsize + cvIdx;
*curconv = 0;
for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
{
*curconv += *(s2map + pIdx * f2MapSize + (chIdx + hIdx)*s2MapWidth + (cvIdx + vIdx)) * c3delta[layermulti + hIdx*c3MapWidth + vIdx];
}
}
}
}
}
}
}
void slqLeNet5::BackwardS2()
{
memset((char*)s2delta, 0x0, s2MapSize*sizeof(double));
for (mIdx = 0; mIdx < s2MapNum; mIdx++)
{
for (hIdx = 0; hIdx < c3MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c3MapWidth; vIdx++)
{
for (int pIdx = 0; pIdx < c3MapNum; pIdx++)
{
//if (!Twomap2ThreeTable[mIdx][pIdx])
// continue;
double *curconv = c3conv + pIdx * s2MapNum * convsize + mIdx * convsize;
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
double *curdt = s2delta + mIdx * f2MapSize + (hIdx + chIdx)*s2MapWidth + (vIdx + cvIdx);
*curdt += *(curconv + chIdx * convwsize + cvIdx) * (*(c3delta + pIdx * f3MapSize + hIdx*c3MapWidth + vIdx));
}
}
}
}
}
s2bias_dt[mIdx] = 0;
s2pool_dt[mIdx] = 0;
for (hIdx = 0; hIdx < s2MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < s2MapWidth; vIdx++)
{
double *curdt = s2delta + mIdx * f2MapSize + hIdx * s2MapWidth + vIdx;
double *curmap = s2map + mIdx * f2MapSize + hIdx * s2MapWidth + vIdx;
double *curc1 = c1map+mIdx*f1MapSize+(hIdx*2)*c1MapWidth+(vIdx*2);
*curdt *= ACTDEVICE(*curmap);
s2bias_dt[mIdx] += *curdt;
s2pool_dt[mIdx] += *curdt * (*curc1 + *(curc1 + 1) + *(curc1 + c1MapWidth) + *(curc1 + c1MapWidth + 1)) / 4;
}
}
}
}
void slqLeNet5::BackwardC1()
{
for (mIdx = 0; mIdx < c1MapNum; mIdx++)
{
c1bias_dt[mIdx] = 0;
layermulti = mIdx*f1MapSize;
for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
{
double *curdt = c1delta + layermulti + hIdx*c1MapWidth + vIdx;
double *curmap = c1map + layermulti + hIdx*c1MapWidth + vIdx;
*curdt = ACTDEVICE(*curmap) * s2pool[mIdx] * s2delta[layermulti/4 +(hIdx/2)*s2MapWidth+(vIdx/2)] / 4.0;
c1bias_dt[mIdx] += *curdt;
}
}
for (chIdx = 0; chIdx < convhsize; chIdx++)
{
for (cvIdx = 0; cvIdx < convwsize; cvIdx++)
{
double *cdt = c1conv_dt + mIdx*convsize + chIdx*convwsize + cvIdx;
*cdt = 0;
for (hIdx = 0; hIdx < c1MapHigh; hIdx++)
{
for (vIdx = 0; vIdx < c1MapWidth; vIdx++)
{
*cdt += inmap[(hIdx + chIdx) * inMapWidth + (vIdx + cvIdx)] * c1delta[mIdx*c1MapHigh*c1MapWidth + hIdx*c1MapWidth + vIdx];
}
}
}
}
}
}
void slqLeNet5::UpgradeNetwork()
{
UpdateParameters(c1bias_dt, c1bias_Edt, c1bias, c1MapNum);
UpdateParameters(c1conv_dt, c1conv_Edt, c1conv, c1optNum);
UpdateParameters(s2bias_dt, s2bias_Edt, s2bias, s2MapNum);
UpdateParameters(s2pool_dt, s2pool_Edt, s2pool, s2MapNum);
UpdateParameters(c3bias_dt, c3bias_Edt, c3bias, c3MapNum);
UpdateParameters(c3conv_dt, c3conv_Edt, c3conv, c3optNum);
UpdateParameters(s4bias_dt, s4bias_Edt, s2bias, s4MapNum);
UpdateParameters(s4pool_dt, s4pool_Edt, s2pool, s4MapNum);
UpdateParameters(c5bias_dt, c5bias_Edt, c5bias, c5MapNum);
UpdateParameters(c5conv_dt, c5conv_Edt, c5conv, c5optNum);
UpdateParameters(outbias_dt, outbias_Edt, outbias, outMapSize);
UpdateParameters(outfull_dt, outfull_Edt, outfullconn, outoptNum);
}
void slqLeNet5::UpdateParameters(double *delta, double *Edelta, double *para, int len)
{
for (int lIdx = 0; lIdx < len; lIdx++)
{
Edelta[lIdx] += delta[lIdx] * delta[lIdx];
para[lIdx] -= Alpha * delta[lIdx] / (std::sqrt(Edelta[lIdx]) + EspCNN);
}
}
double slqLeNet5::test()
{
double preci = 0.f;
double maxpred = 0.f;
double maxEnd = -10.f;
int tIdx;
int oIdx;
int dIdx;
for (tIdx = 0; tIdx < TestImgNum; tIdx++)
{
maxEnd = -10.f;
tstLabelPtr = testLabel + tIdx;
if (false == testInit)
RegularMap(testImg + tIdx*inMapSize, testData + tIdx*inMapSize);
inmap = testData + tIdx*inMapSize;
ForwardC1();
ForwardS2();
ForwardC3();
ForwardS4();
ForwardC5();
ForwardOut();
maxpred = -10;
for (oIdx = 0; oIdx < outMapSize; oIdx++)
{
maxpred = maxpred > outmap[oIdx] ? maxpred : outmap[oIdx];
dIdx = maxpred > outmap[oIdx] ? dIdx : oIdx;
}
if (dIdx == (int)(*tstLabelPtr))
{
preci += 1.0;
}
}
if (false == testInit)
{
delete testImg;
testImg = nullptr;
testInit = true;
}
return (preci / TestImgNum);
}
void slqLeNet5::RandomBias(double *randVector, int vLen)
{
int tIdx = 0;
for (; tIdx < vLen; tIdx++)
{
randVector[tIdx] = 0;
}
}
void slqLeNet5::SaveParameters()
{
fstream fileStream;
fileStream.open("parameters_block", ifstream::out | ifstream::binary);
fileStream << c1conv << c1bias << s2pool << s2bias << \
c3conv << c3bias << s4pool << s4bias << c5conv << c5bias << outfullconn << outbias;
fileStream.close();
}
void slqLeNet5::ProduceLabel()
{
int tIdx;
for (tIdx = 0; tIdx < outMapSize; tIdx++)
{
if (tIdx == *curLabelPtr)
label[tIdx] = 0.8f;
else
label[tIdx] = -0.8f;
}
}
void slqLeNet5::RegularMap(char *cmap, double *rmap)
{
for (hIdx = 0; hIdx < imageRow; hIdx++)
{
for (vIdx = 0; vIdx < imageCol; vIdx++)
{
int tmp = (int)(*(cmap + hIdx*imageCol + vIdx));
tmp = tmp < 0 ? (tmp + 256) : tmp;
*(rmap + hIdx*imageCol + vIdx) = tmp / 255.0 * 2 - 1;
}
}
}
void slqLeNet5::uniform_rand(double* src, int len, double min, double max)
{
for (int i = 0; i < len; i++) {
src[i] = uniform_rand(min, max);
}
}
double slqLeNet5::uniform_rand(double min, double max)
{
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution dst(min, max);
return dst(gen);
}
}
对于经过大数据训练的CNN结构可以实现识别0~9这十个数字,表示很惊奇。
上面的保存参数到文件的函数整错了,修改为下面:
void slqLeNet5::SaveParameters()
{
ofstream fileStream;
fileStream.open("parameters_block", ofstream::out | ofstream::binary);
fileStream.write((char*)c1conv, sizeof(double)*c1optNum);
fileStream.write((char*)c1bias, sizeof(double)*c1MapNum);
fileStream.write((char*)s2pool, sizeof(double)*s2MapNum);
fileStream.write((char*)s2bias, sizeof(double)*s2MapNum);
fileStream.write((char*)c3conv, sizeof(double)*c3optNum);
fileStream.write((char*)c3bias, sizeof(double)*c3MapNum);
fileStream.write((char*)s4pool, sizeof(double)*s4MapNum);
fileStream.write((char*)s4bias, sizeof(double)*s4MapNum);
fileStream.write((char*)c5conv, sizeof(double)*c5optNum);
fileStream.write((char*)c5bias, sizeof(double)*c5MapNum);
fileStream.write((char*)outfullconn, sizeof(double)*outoptNum);
fileStream.write((char*)outbias, sizeof(double)*outMapSize);
fileStream.close();
}
void slqLeNet5::ReadParameters()
{
ifstream fileStream;
fileStream.open("parameters_block", ifstream::in | ifstream::binary);
fileStream.read((char*)c1conv, sizeof(double)*c1optNum);
fileStream.read((char*)c1bias, sizeof(double)*c1MapNum);
fileStream.read((char*)s2pool, sizeof(double)*s2MapNum);
fileStream.read((char*)s2bias, sizeof(double)*s2MapNum);
fileStream.read((char*)c3conv, sizeof(double)*c3optNum);
fileStream.read((char*)c3bias, sizeof(double)*c3MapNum);
fileStream.read((char*)s4pool, sizeof(double)*s4MapNum);
fileStream.read((char*)s4bias, sizeof(double)*s4MapNum);
fileStream.read((char*)c5conv, sizeof(double)*c5optNum);
fileStream.read((char*)c5bias, sizeof(double)*c5MapNum);
fileStream.read((char*)outfullconn, sizeof(double)*outoptNum);
fileStream.read((char*)outbias, sizeof(double)*outMapSize);
fileStream.close();
}
预测
void slqLeNet5::predict()
{
HANDLE hfile;
LPCTSTR lpFileName = L".\\*.png";
WIN32_FIND_DATA pNextInfo;
double imgArray[inMapSize];
double *imgAPtr;
char *imgBPtr;
char imgName[20];
char ch;
double maxV = -10;
int maxIdx = -1;
ReadParameters();
while (true)
{
hfile = FindFirstFile(lpFileName, &pNextInfo);
if (INVALID_HANDLE_VALUE == hfile)
{
cout << "Error Handle to get First File" << endl;
break;
}
while (FindNextFile(hfile, &pNextInfo))
{
cv::Mat image;
for (hIdx = 0; hIdx < (wcslen(pNextInfo.cFileName) + 2); hIdx++)
{
imgName[hIdx] = pNextInfo.cFileName[hIdx];
}
imgName[hIdx] = '\0';
cout << "image " << imgName << endl;
image = cv::imread(imgName, CV_8UC1);
if (image.empty())
{
cout << "Error image" << endl;
return;
}
std::fill(imgArray, imgArray + inMapSize, -1);
imgAPtr = imgArray;
imgBPtr = (char*)image.data;
double topV = -256;
double botV = 256;
for (hIdx = 0; hIdx < RawImgRow; hIdx++)
{
for (vIdx = 0; vIdx < RawImgCol; vIdx++)
{
int tmp = (int)(*(imgBPtr + hIdx*RawImgCol + vIdx));
tmp = tmp < 0 ? (tmp + 256) : tmp;
topV = topV > tmp ? topV : tmp;
botV = botV < tmp ? botV : tmp;
}
}
for (hIdx = 0; hIdx < RawImgRow; hIdx++)
{
for (vIdx = 0; vIdx < RawImgCol; vIdx++)
{
int tmp = (int)(*(imgBPtr + hIdx*RawImgCol + vIdx));
tmp = tmp < 0 ? (tmp + 256) : tmp;
*(imgAPtr + (hIdx + 2)*imageCol + (vIdx + 2)) = (tmp - botV) / (topV - botV) * 2.0 - 1.0;
}
}
inmap = imgArray;
ForwardC1();
ForwardS2();
ForwardC3();
ForwardS4();
ForwardC5();
ForwardOut();
maxV = -10;
maxIdx = -1;
for (hIdx = 0; hIdx < outMapSize; hIdx++)
{
maxV = maxV > outmap[hIdx] ? maxV : outmap[hIdx];
maxIdx = maxV > outmap[hIdx] ? maxIdx : hIdx;
}
cout << "predict " << maxIdx << endl;
ch = getchar();
if ('!' == ch)
return;
}
}
}