SVM数字识别
#include "stdafx.h" #include#include "opencv2/opencv.hpp" #include using namespace std; using namespace cv; #define SHOW_PROCESS 1 #define ON_STUDY 1 class NumTrainData { public: NumTrainData() { memset(data, 0, sizeof(data)); result = -1; } public: float data[64]; int result; }; vector buffer; int featureLen = 64; void swapBuffer(char* buf) { char temp; temp = *(buf); *buf = *(buf+3); *(buf+3) = temp; temp = *(buf+1); *(buf+1) = *(buf+2); *(buf+2) = temp; } void GetROI(Mat& src, Mat& dst) { int left, right, top, bottom; left = src.cols; right = 0; top = src.rows; bottom = 0; //Get valid area for(int i=0; i (i, j) > 0) { if(j right) right = j; if(i bottom) bottom = i; } } } //Point center; //center.x = (left + right) / 2; //center.y = (top + bottom) / 2; int width = right - left; int height = bottom - top; int len = (width < height) ? height : width; //Create a squre dst = Mat::zeros(len, len, CV_8UC1); //Copy valid data to squre center Rect dstRect((len - width)/2, (len - height)/2, width, height); Rect srcRect(left, top, width, height); Mat dstROI = dst(dstRect); Mat srcROI = srcRect); srcROI.copyTo(dstROI); } int ReadTrainData() { Mat src; Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat m = Mat::zeros(1, 64, CV_8UC1); Mat dst; NumTrainData rtd; const int p_num = 377; //图片对应的数字,存储在nclass.xml中 Mat label=cvCreateMat(1, p_num, CV_32SC1); FileStorage file("nclass.xml", FileStorage::READ); file["data"]>>label; file.release(); cout< (i, j); rtd.data[ i*8 + j] = ff; } } re=label.at (0,k); rtd.result=re; buffer.push_back(rtd); k++; } return 0; } void newRtStudy(vector & trainData) { int testCount = trainData.size(); Mat data = Mat::zeros(testCount, featureLen, CV_32FC1); Mat res = Mat::zeros(testCount, 1, CV_32SC1); for (int i= 0; i< testCount; i++) { NumTrainData td = trainData.at(i); memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float)); res.at (i, 0) = td.result; } /////////////START RT TRAINNING////////////////// CvRTrees forest; CvMat* var_importance = 0; forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(), CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER)); forest.save( "new_rtrees.xml" ); } int newRtPredict() { CvRTrees forest; forest.load( "new_rtrees.xml" ); const char fileName[] = "../res/t10k-p_w_picpaths.idx3-ubyte"; const char labelFileName[] = "../res/t10k-labels.idx1-ubyte"; ifstream lab_ifs(labelFileName, ios_base::binary); ifstream ifs(fileName, ios_base::binary); if( ifs.fail() == true ) return -1; if( lab_ifs.fail() == true ) return -1; char magicNum[4], ccount[4], crows[4], ccols[4]; ifs.read(magicNum, sizeof(magicNum)); ifs.read(ccount, sizeof(ccount)); ifs.read(crows, sizeof(crows)); ifs.read(ccols, sizeof(ccols)); int count, rows, cols; swapBuffer(ccount); swapBuffer(crows); swapBuffer(ccols); memcpy(&count, ccount, sizeof(count)); memcpy(&rows, crows, sizeof(rows)); memcpy(&cols, ccols, sizeof(cols)); Mat src = Mat::zeros(rows, cols, CV_8UC1); Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat img, dst; //Just skip label header lab_ifs.read(magicNum, sizeof(magicNum)); lab_ifs.read(ccount, sizeof(ccount)); char label = 0; Scalar templateColor(255, 0, 0); NumTrainData rtd; int right = 0, error = 0, total = 0; int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0; while(ifs.good()) { //Read label lab_ifs.read(&label, 1); label = label + '0'; //Read data ifs.read((char*)src.data, rows * cols); GetROI(src, dst); //Too small to watch img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3); resize(dst, img, img.size()); rtd.result = label; resize(dst, temp, temp.size()); //threshold(temp, temp, 10, 1, CV_THRESH_BINARY); for(int i = 0; i<8; i++) { for(int j = 0; j<8; j++) { m.at (0,j + i*8) = temp.at (i, j); } } if(total >= count) break; char ret = (char)forest.predict(m); if(ret == label) { right++; if(total <= 5000) right_1++; else right_2++; } else { error++; if(total <= 5000) error_1++; else error_2++; } total++; #if(SHOW_PROCESS) stringstream ss; ss << "Number " << label << ", predict " << ret; string text = ss.str(); putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); if(waitKey(0)==27) //ESC to quit break; #endif } ifs.close(); lab_ifs.close(); stringstream ss; ss << "Total " << total << ", right " << right <<", error " << error; string text = ss.str(); putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor); imshow("img", img); //waitKey(0); return 0; } void newSvmStudy(vector & trainData) { int testCount = trainData.size(); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat data = Mat::zeros(testCount, featureLen, CV_32FC1); Mat res = Mat::zeros(testCount, 1, CV_32SC1); for (int i= 0; i< testCount; i++) { NumTrainData td = trainData.at(i); memcpy(m.data, td.data, featureLen*sizeof(float)); normalize(m, m); memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float)); res.at (i, 0) = td.result; } /////////////START SVM TRAINNING////////////////// CvSVM svm = CvSVM(); CvSVMParams param; CvTermCriteria criteria; criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON); param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria); svm.train(data, res, Mat(), Mat(), param); svm.save( "SVM_DATA.xml" ); } int newSvmPredict() { CvSVM svm = CvSVM(); svm.load( "SVM_DATA.xml" ); Mat temp = Mat::zeros(8, 8, CV_8UC1); Mat m = Mat::zeros(1, featureLen, CV_32FC1); Mat dst; int k=0; const int p_num=17; char ch[10]; while (k <= p_num) { string str; sprintf(ch,"%d",k+313); str=ch; str=str+".jpg"; //Read data Mat imread(str.c_str(), 0); GetROI(src, dst); resize(dst, temp, temp.size()); //threshold(temp, temp, 10, 1, CV_THRESH_BINARY); for(int i = 0; i<8; i++) { for(int j = 0; j<8; j++) { m.at (0,j + i*8) = temp.at (i, j); } } normalize(m, m); int ret = (int)svm.predict(m); cout<
输入的是标号为0到p_num的图片和对应的数字,nclass.xml是opencv对mat类型自动存储的格式,然后手动修改内容,想找个更好的,目前还没有找到