使用OpenCV训练手写数字识别分类器
另一个车牌识别开源代码:http://www.dexmac.com/index.php/software/114-plategatewayqt
1,下载训练数据和测试数据文件(http://yann.lecun.com/exdb/mnist/),这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个.
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端.MNIST数据格式如下:
标签文件(如 train-labels.idx1-ubyte)格式:
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
图像文件(如 train-images.idx3-ubyte)格式:
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
3,确定字符特征方式:先截取字符的有效区域并居中放入正方形,再缩放为8×8网格,以这64个网格的灰度值作为特征向量(这是最简单的特征方式;代码中二值化的步骤已被注释掉)
4,创建SVM,训练并读取,结果如下
1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
10000个训练样本,测试数据正确率95.45%
60000个训练样本,测试数据正确率97.67%
5,编写手写输入的GUI程序,并进行验证,效果还可以接受。
以下为主要代码,以供参考
(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)
#include "stdafx.h"

#include <cfloat>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

// Set to 1 to build the per-sample preview code (debug aid).
#define SHOW_PROCESS 0
// Set to 1 to train a model; 0 to evaluate a previously saved model.
#define ON_STUDY 0

// One training sample: a 64-element feature vector (the 8x8 grid of gray
// values, row-major) and the label it belongs to. Labels are stored as the
// ASCII code of the digit ('0'..'9'); -1 means "not set".
class NumTrainData
{
public:
    NumTrainData()
    {
        memset(data, 0, sizeof(data));
        result = -1;
    }

public:
    float data[64];
    int result;
};

// Training samples collected by ReadTrainData().
vector<NumTrainData> buffer;
// Number of features per sample; must match the size of NumTrainData::data.
int featureLen = 64;

// Magic numbers that open the MNIST label and image files.
static const int kMnistLabelMagic = 2049;
static const int kMnistImageMagic = 2051;
// Side length of the grid each digit is shrunk to before classification.
static const int kGridSize = 8;

// Reverses the byte order of the 4-byte value stored at buf, in place.
// No longer used by the readers below (they decode big-endian values
// portably) but kept so that other code relying on it keeps working.
void swapBuffer(char* buf)
{
    char temp;
    temp = *(buf);
    *buf = *(buf + 3);
    *(buf + 3) = temp;

    temp = *(buf + 1);
    *(buf + 1) = *(buf + 2);
    *(buf + 2) = temp;
}

// Extracts the bounding box of all non-zero pixels of src and copies it,
// centred, into a new zero-filled square image dst whose side equals the
// longer edge of the bounding box.
//
// src is read with at<uchar>, i.e. it is expected to be a single-channel
// 8-bit image. If src contains no non-zero pixel, dst becomes an all-zero
// square (at least 1x1) instead of failing on a negative size.
void GetROI(Mat& src, Mat& dst)
{
    int left = src.cols;
    int right = -1;
    int top = src.rows;
    int bottom = -1;

    // Find the extent of the foreground pixels.
    for (int i = 0; i < src.rows; i++)
    {
        for (int j = 0; j < src.cols; j++)
        {
            if (src.at<uchar>(i, j) > 0)
            {
                if (j < left) left = j;
                if (j > right) right = j;
                if (i < top) top = i;
                if (i > bottom) bottom = i;
            }
        }
    }

    if (right < 0)
    {
        // Blank image: there is no bounding box to extract.
        int side = (src.rows > src.cols) ? src.rows : src.cols;
        if (side < 1) side = 1;
        dst = Mat::zeros(side, side, CV_8UC1);
        return;
    }

    // Both bounds are inclusive, hence the +1 (the last foreground column and
    // row used to be cut off).
    const int width = right - left + 1;
    const int height = bottom - top + 1;
    const int len = (width < height) ? height : width;

    // Create a square and copy the valid area into its centre.
    dst = Mat::zeros(len, len, CV_8UC1);
    Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
    Rect srcRect(left, top, width, height);
    Mat dstROI = dst(dstRect);
    Mat srcROI = src(srcRect);
    srcROI.copyTo(dstROI);
}

// Reads a 32-bit big-endian (MSB first) integer from in.
// Returns false if fewer than four bytes could be read.
static bool readBigEndianInt32(istream& in, int& value)
{
    unsigned char bytes[4];
    in.read(reinterpret_cast<char*>(bytes), sizeof(bytes));
    if (!in)
        return false;

    // Assemble from bytes so the result does not depend on host byte order.
    value = static_cast<int>((static_cast<unsigned int>(bytes[0]) << 24) |
                             (static_cast<unsigned int>(bytes[1]) << 16) |
                             (static_cast<unsigned int>(bytes[2]) << 8) |
                             static_cast<unsigned int>(bytes[3]));
    return true;
}

// Opens an MNIST image file and its label file, validates both headers and
// leaves both streams positioned at the first sample.
// On success count holds the number of usable samples (the smaller of the
// two item counts) and rows/cols the image dimensions.
static bool openMnistPair(const char* imageFileName, const char* labelFileName,
                          ifstream& images, ifstream& labels,
                          int& count, int& rows, int& cols)
{
    images.open(imageFileName, ios_base::in | ios_base::binary);
    labels.open(labelFileName, ios_base::in | ios_base::binary);
    if (!images || !labels)
        return false;

    int imageMagic = 0;
    int labelMagic = 0;
    int labelCount = 0;
    if (!readBigEndianInt32(images, imageMagic) ||
        !readBigEndianInt32(images, count) ||
        !readBigEndianInt32(images, rows) ||
        !readBigEndianInt32(images, cols) ||
        !readBigEndianInt32(labels, labelMagic) ||
        !readBigEndianInt32(labels, labelCount))
    {
        return false;
    }

    if (imageMagic != kMnistImageMagic || labelMagic != kMnistLabelMagic)
        return false;

    // Never read more samples than there are labels for.
    if (labelCount < count)
        count = labelCount;

    return count >= 0 && rows > 0 && cols > 0;
}

// Reads the next label and image. The label is converted to the ASCII digit
// ('0'..'9'), which is the class value used for training and prediction.
// Returns false when either stream runs out of data.
static bool readSample(ifstream& images, ifstream& labels, Mat& src, char& label)
{
    labels.read(&label, 1);
    images.read(reinterpret_cast<char*>(src.data),
                static_cast<streamsize>(src.rows) * src.cols);
    if (!images || !labels)
        return false;

    label = static_cast<char>(label + '0');
    return true;
}

// Turns a digit image into its feature vector: crop to the digit (GetROI),
// shrink to an 8x8 grid and store the 64 gray values row by row.
// roi and grid are scratch images; roi holds the cropped digit afterwards.
static void extractFeatures(Mat& src, Mat& roi, Mat& grid, float* features)
{
    GetROI(src, roi);
    resize(roi, grid, Size(kGridSize, kGridSize));
    //threshold(grid, grid, 10, 1, CV_THRESH_BINARY);

    for (int i = 0; i < kGridSize; i++)
    {
        for (int j = 0; j < kGridSize; j++)
        {
            features[i * kGridSize + j] = grid.at<uchar>(i, j);
        }
    }
}

// Loads training samples from the MNIST training files into the global
// `buffer`.
//
// maxCount: maximum number of samples to load; a value <= 0 loads every
//           sample in the file.
// Returns 0 on success, -1 if the files cannot be opened, have an invalid
// header, or end before the expected number of samples was read (samples
// read up to that point remain in `buffer`).
int ReadTrainData(int maxCount)
{
    const char fileName[] = "../res/train-images.idx3-ubyte";
    const char labelFileName[] = "../res/train-labels.idx1-ubyte";

    ifstream ifs;
    ifstream lab_ifs;
    int count = 0;
    int rows = 0;
    int cols = 0;
    if (!openMnistPair(fileName, labelFileName, ifs, lab_ifs, count, rows, cols))
        return -1;

    int limit = count;
    if (maxCount > 0 && maxCount < limit)
        limit = maxCount;
    buffer.reserve(buffer.size() + static_cast<size_t>(limit));

    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat dst;
    Mat temp;
    NumTrainData rtd;
    char label = 0;

    for (int total = 1; total <= limit; total++)
    {
        if (!readSample(ifs, lab_ifs, src, label))
            return -1; // file is shorter than its header claims

        // Progress output; '\n' instead of endl avoids a flush per sample.
        cout << total << '\n';

        rtd.result = label;
        extractFeatures(src, dst, temp, rtd.data);

#if(SHOW_PROCESS)
        // The digit is too small to look at, so enlarge it and label it.
        Scalar templateColor(255, 0, 255);
        Mat img;
        resize(dst, img, Size(dst.cols * 10, dst.rows * 10));
        stringstream ss;
        ss << "Number " << label;
        string text = ss.str();
        putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
        //imshow("img", img);
        //if(waitKey(0)==27) //ESC to quit
        //    break;
#endif

        buffer.push_back(rtd);
    }

    cout.flush();
    return 0;
}

// Trains a random-trees classifier on trainData and saves it to
// "new_rtrees.xml". Does nothing (apart from a message on stderr) when
// trainData is empty.
void newRtStudy(vector<NumTrainData>& trainData)
{
    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
    {
        cerr << "newRtStudy: no training data" << endl;
        return;
    }

    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

    for (int i = 0; i < sampleCount; i++)
    {
        const NumTrainData& td = trainData[i];
        memcpy(data.ptr<float>(i), td.data, featureLen * sizeof(float));
        res.at<int>(i, 0) = td.result; // res is CV_32SC1, so access it as int
    }

    /////////////START RT TRAINING//////////////////
    CvRTrees forest;
    forest.train(data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
                 CvRTParams(10, 10, 0, false, 15, 0, true, 4, 100, 0.01f, CV_TERMCRIT_ITER));
    forest.save("new_rtrees.xml");
}

// Runs `model` over the MNIST test files, counts correct and wrong
// predictions, prints the totals and shows them in a window until a key is
// pressed.
//
// normalizeFeatures: apply cv::normalize to each feature vector before
//                    predicting (must match how the model was trained).
// Returns 0 on success, -1 if the test files cannot be opened or have an
// invalid header.
template <typename Model>
static int evaluateOnTestSet(Model& model, bool normalizeFeatures)
{
    const char fileName[] = "../res/t10k-images.idx3-ubyte";
    const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

    ifstream ifs;
    ifstream lab_ifs;
    int count = 0;
    int rows = 0;
    int cols = 0;
    if (!openMnistPair(fileName, labelFileName, ifs, lab_ifs, count, rows, cols))
        return -1;

    Mat src = Mat::zeros(rows, cols, CV_8UC1);
    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat dst;
    Mat temp;
    char label = 0;
    Scalar templateColor(255, 0, 0);

    int right = 0;
    int error = 0;
    int total = 0;

    // Stop at the sample count from the header or at the first failed read,
    // whichever comes first.
    while (total < count && readSample(ifs, lab_ifs, src, label))
    {
        extractFeatures(src, dst, temp, m.ptr<float>(0));
        if (normalizeFeatures)
            normalize(m, m);

        const char ret = static_cast<char>(model.predict(m));
        if (ret == label)
            right++;
        else
            error++;
        total++;

#if(SHOW_PROCESS)
        // The digit is too small to look at, so enlarge it and label it.
        Mat preview;
        resize(dst, preview, Size(dst.cols * 30, dst.rows * 30));
        stringstream pss;
        pss << "Number " << label << ", predict " << ret;
        string ptext = pss.str();
        putText(preview, ptext, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
        imshow("img", preview);
        if (waitKey(0) == 27) //ESC to quit
            break;
#endif
    }

    stringstream ss;
    ss << "Total " << total << ", right " << right << ", error " << error;
    string text = ss.str();
    cout << text << endl;

    // Draw the summary over the last digit (enlarged), or over a blank canvas
    // if no sample was processed.
    Mat img;
    if (!dst.empty())
        resize(dst, img, Size(dst.cols * 30, dst.rows * 30));
    else
        img = Mat::zeros(100, 600, CV_8UC1);
    putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
    imshow("img", img);
    waitKey(0);

    return 0;
}

// Loads the random-trees model from "new_rtrees.xml" and evaluates it on the
// MNIST test set. Returns 0 on success, -1 if the test files are unusable.
int newRtPredict()
{
    CvRTrees forest;
    forest.load("new_rtrees.xml");
    return evaluateOnTestSet(forest, false);
}

// Trains an SVM (C_SVC, RBF kernel) on the normalized feature vectors of
// trainData and saves it to "SVM_DATA.xml". Does nothing (apart from a
// message on stderr) when trainData is empty.
void newSvmStudy(vector<NumTrainData>& trainData)
{
    const int sampleCount = static_cast<int>(trainData.size());
    if (sampleCount == 0)
    {
        cerr << "newSvmStudy: no training data" << endl;
        return;
    }

    Mat m = Mat::zeros(1, featureLen, CV_32FC1);
    Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
    Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

    for (int i = 0; i < sampleCount; i++)
    {
        const NumTrainData& td = trainData[i];
        memcpy(m.ptr<float>(0), td.data, featureLen * sizeof(float));
        normalize(m, m);
        memcpy(data.ptr<float>(i), m.ptr<float>(0), featureLen * sizeof(float));
        res.at<int>(i, 0) = td.result; // res is CV_32SC1, so access it as int
    }

    /////////////START SVM TRAINING//////////////////
    CvSVM svm;
    CvSVMParams param;
    CvTermCriteria criteria;

    criteria = cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
    param = CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);

    svm.train(data, res, Mat(), Mat(), param);
    svm.save("SVM_DATA.xml");
}

// Loads the SVM from "SVM_DATA.xml" and evaluates it on the MNIST test set.
// Feature vectors are normalized, as they are during training.
// Returns 0 on success, -1 if the test files are unusable.
int newSvmPredict()
{
    CvSVM svm;
    svm.load("SVM_DATA.xml");
    return evaluateOnTestSet(svm, true);
}

// Entry point: trains a model when ON_STUDY is 1, otherwise evaluates the
// saved model. Exits with a non-zero status if the data files are unusable.
int main(int argc, char* argv[])
{
    (void)argc;
    (void)argv;

#if(ON_STUDY)
    int maxCount = 60000;
    if (ReadTrainData(maxCount) != 0)
    {
        cerr << "Failed to read the MNIST training data" << endl;
        return 1;
    }

    //newRtStudy(buffer);
    newSvmStudy(buffer);
    return 0;
#else
    //newRtPredict();
    if (newSvmPredict() != 0)
    {
        cerr << "Failed to read the MNIST test data" << endl;
        return 1;
    }
    return 0;
#endif
}