使用OPENCV训练手写数字识别分类器
1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数
4,创建SVM,训练并读取,结果如下
1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
10000个训练样本,测试数据正确率95.45%
60000个训练样本,测试数据正确率97.67%
5,编写手写输入的GUI程序,并进行验证,效果还可以接受。
以下为主要代码,以供参考
(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)
#include "stdafx.h" #include #include "opencv2/opencv.hpp" #include using namespace std; using namespace cv; #define SHOW_PROCESS 0 #define ON_STUDY 0 class NumTrainData { public: NumTrainData() { memset(data, 0, sizeof(data)); result = -1; } public: float data[64]; int result; }; vector buffer; int featureLen = 64; void swapBuffer(char* buf) { char temp; temp = *(buf); *buf = *(buf+3); *(buf+3) = temp; temp = *(buf+1); *(buf+1) = *(buf+2); *(buf+2) = temp; } void GetROI(Mat& src, Mat& dst) { int left, right, top, bottom; left = src.cols; right = 0; top = src.rows; bottom = 0; //Get valid area for(int i=0; i(i, j) > 0) { if(jright) right = j; if(ibottom) bottom = i; } } } //Point center; //center.x = (left + right) / 2; //center.y = (top + bottom) / 2; int width = right - left; int height = bottom - top; int len = (width < height) ? height : width; //Create a squre dst = Mat::zeros(len, len, CV_8UC1); //Copy valid data to squre center Rect dstRect((len - width)/2, (len - height)/2, width, height); Rect srcRect(left, top, width, height); Mat dstROI = dst(dstRect); Mat srcROI = src(srcRect); srcROI.copyTo(dstROI); } int ReadTrainData(int maxCount) { //Open image and label file const char fileName[] = "../res/train-images.idx3-ubyte"; const char labelFileName[] = "../res/train-labels.idx1-ubyte"; ifstream lab_ifs(labelFileName, ios_base::binary); ifstream ifs(fileName, ios_base::binary); if( ifs.fail() == true ) return -1; if( lab_ifs.fail() == true ) return -1; //Read train data number and image rows / cols char magicNum[4], ccount[4], crows[4], ccols[4]; ifs.read(magicNum, sizeof(magicNum)); ifs.read(ccount, sizeof(ccount)); ifs.read(crows, sizeof(crows)); ifs.read(ccols, sizeof(ccols)); int count, rows, cols; swapBuffer(ccount); swapBuffer(crows); swapBuffer(ccols); memcpy(&count, ccount, sizeof(count)); memcpy(&rows, crows, sizeof(rows)); memcpy(&cols, ccols, sizeof(cols)); //Just skip label header lab_ifs.read(magicNum, sizeof(magicNum)); lab_ifs.read(ccount, sizeof(ccount)); //Create source and show image matrix Mat src = Mat::zeros(rows, cols, CV_8UC1); Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat img, dst; char label = 0; Scalar templateColor(255, 0, 255 ); NumTrainData rtd; //int loop = 1000; int total = 0; while(!ifs.eof()) { if(total >= count) break; total++; cout << total << endl; //Read label lab_ifs.read(&label, 1); label = label + '0'; //Read source data ifs.read((char*)src.data, rows * cols); GetROI(src, dst); #if(SHOW_PROCESS) //Too small to watch img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1); resize(dst, img, img.size()); stringstream ss; ss << "Number " << label; string text = ss.str(); putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); //imshow("img", img); #endif rtd.result = label; resize(dst, temp, temp.size()); //threshold(temp, temp, 10, 1, CV_THRESH_BINARY); for(int i = 0; i<8; i++) { for(int j = 0; j<8; j++) { rtd.data[ i*8 + j] = temp.at(i, j); } } buffer.push_back(rtd); //if(waitKey(0)==27) //ESC to quit // break; maxCount--; if(maxCount == 0) break; } ifs.close(); lab_ifs.close(); return 0; } void newRtStudy(vector& trainData) { int testCount = trainData.size(); Mat data = Mat::zeros(testCount, featureLen, CV_32FC1); Mat res = Mat::zeros(testCount, 1, CV_32SC1); for (int i= 0; i< testCount; i++) { NumTrainData td = trainData.at(i); memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float)); res.at(i, 0) = td.result; } /////////////START RT TRAINNING////////////////// CvRTrees forest; CvMat* var_importance = 0; forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(), CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER)); forest.save( "new_rtrees.xml" ); } int newRtPredict() { CvRTrees forest; forest.load( "new_rtrees.xml" ); const char fileName[] = "../res/t10k-images.idx3-ubyte"; const char labelFileName[] = "../res/t10k-labels.idx1-ubyte"; ifstream lab_ifs(labelFileName, ios_base::binary); ifstream ifs(fileName, ios_base::binary); if( ifs.fail() == true ) return -1; if( lab_ifs.fail() == true ) return -1; char magicNum[4], ccount[4], crows[4], ccols[4]; ifs.read(magicNum, sizeof(magicNum)); ifs.read(ccount, sizeof(ccount)); ifs.read(crows, sizeof(crows)); ifs.read(ccols, sizeof(ccols)); int count, rows, cols; swapBuffer(ccount); swapBuffer(crows); swapBuffer(ccols); memcpy(&count, ccount, sizeof(count)); memcpy(&rows, crows, sizeof(rows)); memcpy(&cols, ccols, sizeof(cols)); Mat src = Mat::zeros(rows, cols, CV_8UC1); Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat img, dst; //Just skip label header lab_ifs.read(magicNum, sizeof(magicNum)); lab_ifs.read(ccount, sizeof(ccount)); char label = 0; Scalar templateColor(255, 0, 0); NumTrainData rtd; int right = 0, error = 0, total = 0; int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0; while(ifs.good()) { //Read label lab_ifs.read(&label, 1); label = label + '0'; //Read data ifs.read((char*)src.data, rows * cols); GetROI(src, dst); //Too small to watch img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3); resize(dst, img, img.size()); rtd.result = label; resize(dst, temp, temp.size()); //threshold(temp, temp, 10, 1, CV_THRESH_BINARY); for(int i = 0; i<8; i++) { for(int j = 0; j<8; j++) { m.at(0,j + i*8) = temp.at(i, j); } } if(total >= count) break; char ret = (char)forest.predict(m); if(ret == label) { right++; if(total <= 5000) right_1++; else right_2++; } else { error++; if(total <= 5000) error_1++; else error_2++; } total++; #if(SHOW_PROCESS) stringstream ss; ss << "Number " << label << ", predict " << ret; string text = ss.str(); putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); if(waitKey(0)==27) //ESC to quit break; #endif } ifs.close(); lab_ifs.close(); stringstream ss; ss << "Total " << total << ", right " << right <<", error " << error; string text = ss.str(); putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); waitKey(0); return 0; } void newSvmStudy(vector& trainData) { int testCount = trainData.size(); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat data = Mat::zeros(testCount, featureLen, CV_32FC1); Mat res = Mat::zeros(testCount, 1, CV_32SC1); for (int i= 0; i< testCount; i++) { NumTrainData td = trainData.at(i); memcpy(m.data, td.data, featureLen*sizeof(float)); normalize(m, m); memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float)); res.at(i, 0) = td.result; } /////////////START SVM TRAINNING////////////////// CvSVM svm = CvSVM(); CvSVMParams param; CvTermCriteria criteria; criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON); param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria); svm.train(data, res, Mat(), Mat(), param); svm.save( "SVM_DATA.xml" ); } int newSvmPredict() { CvSVM svm = CvSVM(); svm.load( "SVM_DATA.xml" ); const char fileName[] = "../res/t10k-images.idx3-ubyte"; const char labelFileName[] = "../res/t10k-labels.idx1-ubyte"; ifstream lab_ifs(labelFileName, ios_base::binary); ifstream ifs(fileName, ios_base::binary); if( ifs.fail() == true ) return -1; if( lab_ifs.fail() == true ) return -1; char magicNum[4], ccount[4], crows[4], ccols[4]; ifs.read(magicNum, sizeof(magicNum)); ifs.read(ccount, sizeof(ccount)); ifs.read(crows, sizeof(crows)); ifs.read(ccols, sizeof(ccols)); int count, rows, cols; swapBuffer(ccount); swapBuffer(crows); swapBuffer(ccols); memcpy(&count, ccount, sizeof(count)); memcpy(&rows, crows, sizeof(rows)); memcpy(&cols, ccols, sizeof(cols)); Mat src = Mat::zeros(rows, cols, CV_8UC1); Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat img, dst; //Just skip label header lab_ifs.read(magicNum, sizeof(magicNum)); lab_ifs.read(ccount, sizeof(ccount)); char label = 0; Scalar templateColor(255, 0, 0); NumTrainData rtd; int right = 0, error = 0, total = 0; int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0; while(ifs.good()) { //Read label lab_ifs.read(&label, 1); label = label + '0'; //Read data ifs.read((char*)src.data, rows * cols); GetROI(src, dst); //Too small to watch img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3); resize(dst, img, img.size()); rtd.result = label; resize(dst, temp, temp.size()); //threshold(temp, temp, 10, 1, CV_THRESH_BINARY); for(int i = 0; i<8; i++) { for(int j = 0; j<8; j++) { m.at(0,j + i*8) = temp.at(i, j); } } if(total >= count) break; normalize(m, m); char ret = (char)svm.predict(m); if(ret == label) { right++; if(total <= 5000) right_1++; else right_2++; } else { error++; if(total <= 5000) error_1++; else error_2++; } total++; #if(SHOW_PROCESS) stringstream ss; ss << "Number " << label << ", predict " << ret; string text = ss.str(); putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); if(waitKey(0)==27) //ESC to quit break; #endif } ifs.close(); lab_ifs.close(); stringstream ss; ss << "Total " << total << ", right " << right <<", error " << error; string text = ss.str(); putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); waitKey(0); return 0; } int main( int argc, char *argv[] ) { #if(ON_STUDY) int maxCount = 60000; ReadTrainData(maxCount); //newRtStudy(buffer); newSvmStudy(buffer); #else //newRtPredict(); newSvmPredict(); #endif return 0; }