使用OPENCV训练手写数字识别分类器

使用OPENCV训练手写数字识别分类器

另一个车牌识别开源代码:http://www.dexmac.com/index.php/software/114-plategatewayqt

1. 下载训练数据和测试数据文件(http://yann.lecun.com/exdb/mnist/)。这里用的是MNIST手写数字图片库,其中训练集包含60000个样本,测试集包含10000个样本。
2. 编写训练数据和测试数据文件的读取函数,注意文件头中的32位整数为大端字节序(MSB first)。MNIST数据格式如下:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label

The labels values are 0 to 9.

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).


3. 特征提取采用最简单的方式:先裁剪出数字所在区域并居中放入正方形,再缩放到8×8网格,以这64个像素的灰度值作为特征。

使用OPENCV训练手写数字识别分类器_第1张图片

4. 创建SVM,训练后在测试集上验证,结果如下:
 - 1000个训练样本,测试集正确率80.21%(并没有体现出SVM在小样本下准确率高的特点)
 - 10000个训练样本,测试集正确率95.45%
 - 60000个训练样本,测试集正确率97.67%

5. 编写手写输入的GUI程序进行验证,效果尚可接受。

使用OPENCV训练手写数字识别分类器_第2张图片

 

以下为主要代码,仅供参考:

(类似地也实现了随机树分类器,比较发现在样本数相同的情况下,SVM的准确率略高)

#include "stdafx.h"

#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <vector>

#include "opencv2/opencv.hpp"

using namespace std;
using namespace cv;

// Set to 1 to compile the #if(SHOW_PROCESS) debug sections.
// In every reader they draw the label (and, in the predict functions, the
// prediction) on an enlarged copy of each sample. The predict functions also
// show that image and wait for a key per sample; ESC stops the loop.
#define SHOW_PROCESS 0
// Selects the path compiled into main().
// 1: read the training set, then train and save the classifier.
// 0: load the saved classifier and evaluate it on the test set.
#define ON_STUDY 0

// One sample: 64 float features taken from the 8x8 down-scaled digit, plus
// its class label. A freshly constructed sample has every feature set to 0
// and label -1, meaning "not assigned yet".
class NumTrainData
{
public:
	NumTrainData() : result(-1)
	{
		for (int k = 0; k < 64; ++k)
			data[k] = 0.0f;
	}

public:
	float data[64];   // 8x8 feature grid stored row by row (index i*8 + j)
	int result;       // class label; ReadTrainData stores the digit as '0'..'9'
};

vector<NumTrainData> buffer;
int featureLen = 64;

// Reverse, in place, the order of the 4 bytes starting at buf. For example,
// {0x00, 0x00, 0xEA, 0x60} becomes {0x60, 0xEA, 0x00, 0x00}.
// The readers call this on the big-endian integers of the MNIST headers
// before memcpy-ing them into an int, which assumes a little-endian host.
void swapBuffer(char* buf)
{
	for (int lo = 0, hi = 3; lo < hi; ++lo, --hi)
	{
		const char saved = buf[lo];
		buf[lo] = buf[hi];
		buf[hi] = saved;
	}
}

// Crop the bounding box of all foreground (> 0) pixels of src and centre it
// in a square, zero-filled image.
//
// src : input image. It is read as uchar, so it is expected to be CV_8UC1.
// dst : output. Normally a len x len CV_8UC1 image, where len is the longer
//       side of the foreground bounding box. If src has no foreground pixel
//       there is nothing to crop, and dst becomes a copy of src.
void GetROI(Mat& src, Mat& dst)
{
	//Inclusive bounds of the foreground; start "inverted" so that any pixel
	//found tightens them.
	int left = src.cols;
	int right = -1;
	int top = src.rows;
	int bottom = -1;

	//Get valid area
	for(int i = 0; i < src.rows; i++)
	{
		const uchar* row = src.ptr<uchar>(i);
		for(int j = 0; j < src.cols; j++)
		{
			if(row[j] > 0)
			{
				if(j < left) left = j;
				if(j > right) right = j;
				if(i < top) top = i;
				if(i > bottom) bottom = i;
			}
		}
	}

	//Blank image: the bounds are still inverted, and a Rect built from them
	//would have a negative size (OpenCV assertion failure).
	if(right < left || bottom < top)
	{
		src.copyTo(dst);
		return;
	}

	//Bounds are inclusive, so +1 keeps the last column/row of the digit.
	int width = right - left + 1;
	int height = bottom - top + 1;
	int len = (width < height) ? height : width;

	//Create a square
	dst = Mat::zeros(len, len, CV_8UC1);

	//Copy valid data to the square's centre
	Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
	Rect srcRect(left, top, width, height);
	Mat dstROI = dst(dstRect);
	Mat srcROI = src(srcRect);
	srcROI.copyTo(dstROI);
}

int ReadTrainData(int maxCount)
{
	//Open image and label file
	const char fileName[] = "../res/train-images.idx3-ubyte";
	const char labelFileName[] = "../res/train-labels.idx1-ubyte";

	ifstream lab_ifs(labelFileName, ios_base::binary);
	ifstream ifs(fileName, ios_base::binary);

	if( ifs.fail() == true )
		return -1;

	if( lab_ifs.fail() == true )
		return -1;

	//Read train data number and image rows / cols
	char magicNum[4], ccount[4], crows[4], ccols[4];
	ifs.read(magicNum, sizeof(magicNum));
	ifs.read(ccount, sizeof(ccount));
	ifs.read(crows, sizeof(crows));
	ifs.read(ccols, sizeof(ccols));

	int count, rows, cols;
	swapBuffer(ccount);
	swapBuffer(crows);
	swapBuffer(ccols);

	memcpy(&count, ccount, sizeof(count));
	memcpy(&rows, crows, sizeof(rows));
	memcpy(&cols, ccols, sizeof(cols));

	//Just skip label header
	lab_ifs.read(magicNum, sizeof(magicNum));
	lab_ifs.read(ccount, sizeof(ccount));

	//Create source and show image matrix
	Mat src = Mat::zeros(rows, cols, CV_8UC1);
	Mat temp = Mat::zeros(8, 8, CV_8UC1);
	Mat img, dst;

	char label = 0;
	Scalar templateColor(255, 0, 255 );

	NumTrainData rtd;

	//int loop = 1000;
	int total = 0;

	while(!ifs.eof())
	{
		if(total >= count)
			break;
		
		total++;
		cout << total << endl;
		
		//Read label
		lab_ifs.read(&label, 1);
		label = label + '0';

		//Read source data
		ifs.read((char*)src.data, rows * cols);
		GetROI(src, dst);

#if(SHOW_PROCESS)
		//Too small to watch
		img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
		resize(dst, img, img.size());

		stringstream ss;
		ss << "Number " << label;
		string text = ss.str();
		putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

		//imshow("img", img);
#endif

		rtd.result = label;
		resize(dst, temp, temp.size());
		//threshold(temp, temp, 10, 1, CV_THRESH_BINARY);

		for(int i = 0; i<8; i++)
		{
			for(int j = 0; j<8; j++)
			{
					rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
			}
		}

		buffer.push_back(rtd);

		//if(waitKey(0)==27) //ESC to quit
		//	break;

		maxCount--;
		
		if(maxCount == 0)
			break;
	}

	ifs.close();
	lab_ifs.close();

	return 0;
}

// Train a random-trees classifier on trainData and save it to
// "new_rtrees.xml".
//
// trainData : samples as produced by ReadTrainData(). Each supplies
//             featureLen float features and an integer class label.
// Does nothing when trainData is empty, because there is nothing to train on.
void newRtStudy(vector<NumTrainData>& trainData)
{
	const int sampleCount = (int)trainData.size();
	if( sampleCount == 0 )
		return;

	Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
	Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

	for (int i = 0; i < sampleCount; i++)
	{
		const NumTrainData& td = trainData[i];
		memcpy(data.ptr<float>(i), td.data, featureLen*sizeof(float));
		//res is CV_32SC1, so access it as int.
		res.at<int>(i, 0) = td.result;
	}

	//Mark every feature as ordered (numerical) and the response (last
	//element) as categorical. This ensures the forest is trained as a
	//classifier rather than a regressor.
	Mat varType(featureLen + 1, 1, CV_8U, Scalar::all(CV_VAR_ORDERED));
	varType.at<uchar>(featureLen, 0) = CV_VAR_CATEGORICAL;

	/////////////START RT TRAINNING//////////////////
	CvRTrees forest;
	forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), varType, Mat(),
		CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
	forest.save( "new_rtrees.xml" );
}


int newRtPredict()
{
    CvRTrees forest;
	forest.load( "new_rtrees.xml" );

	const char fileName[] = "../res/t10k-images.idx3-ubyte";
	const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

	ifstream lab_ifs(labelFileName, ios_base::binary);
	ifstream ifs(fileName, ios_base::binary);

	if( ifs.fail() == true )
		return -1;

	if( lab_ifs.fail() == true )
		return -1;

	char magicNum[4], ccount[4], crows[4], ccols[4];
	ifs.read(magicNum, sizeof(magicNum));
	ifs.read(ccount, sizeof(ccount));
	ifs.read(crows, sizeof(crows));
	ifs.read(ccols, sizeof(ccols));

	int count, rows, cols;
	swapBuffer(ccount);
	swapBuffer(crows);
	swapBuffer(ccols);

	memcpy(&count, ccount, sizeof(count));
	memcpy(&rows, crows, sizeof(rows));
	memcpy(&cols, ccols, sizeof(cols));

	Mat src = Mat::zeros(rows, cols, CV_8UC1);
	Mat temp = Mat::zeros(8, 8, CV_8UC1);
	Mat m = Mat::zeros(1, featureLen, CV_32FC1);
	Mat img, dst;

	//Just skip label header
	lab_ifs.read(magicNum, sizeof(magicNum));
	lab_ifs.read(ccount, sizeof(ccount));

	char label = 0;
	Scalar templateColor(255, 0, 0);

	NumTrainData rtd;

	int right = 0, error = 0, total = 0;
	int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
	while(ifs.good())
	{
		//Read label
		lab_ifs.read(&label, 1);
		label = label + '0';

		//Read data
		ifs.read((char*)src.data, rows * cols);
		GetROI(src, dst);

		//Too small to watch
		img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
		resize(dst, img, img.size());

		rtd.result = label;
		resize(dst, temp, temp.size());
		//threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
		for(int i = 0; i<8; i++)
		{
			for(int j = 0; j<8; j++)
			{
					m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
			}
		}

		if(total >= count)
			break;

		char ret = (char)forest.predict(m); 

		if(ret == label)
		{
			right++;
			if(total <= 5000)
				right_1++;
			else
				right_2++;
		}
		else
		{
			error++;
			if(total <= 5000)
				error_1++;
			else
				error_2++;
		}

		total++;

#if(SHOW_PROCESS)
		stringstream ss;
		ss << "Number " << label << ", predict " << ret;
		string text = ss.str();
		putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

		imshow("img", img);
		if(waitKey(0)==27) //ESC to quit
			break;
#endif

	}

	ifs.close();
	lab_ifs.close();

	stringstream ss;
	ss << "Total " << total << ", right " << right <<", error " << error;
	string text = ss.str();
	putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
	imshow("img", img);
	waitKey(0);

	return 0;
}

// Train an RBF-kernel C-SVC on trainData and save it to "SVM_DATA.xml".
//
// Each feature vector is normalised with cv::normalize (default: unit L2
// norm) before training; newSvmPredict() applies the same normalisation.
// trainData : samples as produced by ReadTrainData().
// Does nothing when trainData is empty, because there is nothing to train on.
void newSvmStudy(vector<NumTrainData>& trainData)
{
	const int sampleCount = (int)trainData.size();
	if( sampleCount == 0 )
		return;

	Mat data = Mat::zeros(sampleCount, featureLen, CV_32FC1);
	Mat res = Mat::zeros(sampleCount, 1, CV_32SC1);

	for (int i = 0; i < sampleCount; i++)
	{
		const NumTrainData& td = trainData[i];
		//Wrap the sample's features (no copy; normalize only reads them) and
		//write the normalised vector straight into row i.
		Mat sample(1, featureLen, CV_32FC1, const_cast<float*>(td.data));
		Mat row = data.row(i);
		normalize(sample, row);
		//res is CV_32SC1, so access it as int.
		res.at<int>(i, 0) = td.result;
	}

	/////////////START SVM TRAINNING//////////////////
	//Argument order: svm_type, kernel_type, degree, gamma, coef0, C, nu, p,
	//class_weights, term_crit. With C_SVC + RBF only gamma (8) and C (10)
	//take effect.
	CvSVMParams param(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL,
		cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));

	CvSVM svm;
	svm.train(data, res, Mat(), Mat(), param);
	svm.save( "SVM_DATA.xml" );
}


int newSvmPredict()
{
	CvSVM svm = CvSVM(); 
	svm.load( "SVM_DATA.xml" );

	const char fileName[] = "../res/t10k-images.idx3-ubyte";
	const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";

	ifstream lab_ifs(labelFileName, ios_base::binary);
	ifstream ifs(fileName, ios_base::binary);

	if( ifs.fail() == true )
		return -1;

	if( lab_ifs.fail() == true )
		return -1;

	char magicNum[4], ccount[4], crows[4], ccols[4];
	ifs.read(magicNum, sizeof(magicNum));
	ifs.read(ccount, sizeof(ccount));
	ifs.read(crows, sizeof(crows));
	ifs.read(ccols, sizeof(ccols));

	int count, rows, cols;
	swapBuffer(ccount);
	swapBuffer(crows);
	swapBuffer(ccols);

	memcpy(&count, ccount, sizeof(count));
	memcpy(&rows, crows, sizeof(rows));
	memcpy(&cols, ccols, sizeof(cols));

	Mat src = Mat::zeros(rows, cols, CV_8UC1);
	Mat temp = Mat::zeros(8, 8, CV_8UC1);
	Mat m = Mat::zeros(1, featureLen, CV_32FC1);
	Mat img, dst;

	//Just skip label header
	lab_ifs.read(magicNum, sizeof(magicNum));
	lab_ifs.read(ccount, sizeof(ccount));

	char label = 0;
	Scalar templateColor(255, 0, 0);

	NumTrainData rtd;

	int right = 0, error = 0, total = 0;
	int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
	while(ifs.good())
	{
		//Read label
		lab_ifs.read(&label, 1);
		label = label + '0';

		//Read data
		ifs.read((char*)src.data, rows * cols);
		GetROI(src, dst);

		//Too small to watch
		img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
		resize(dst, img, img.size());

		rtd.result = label;
		resize(dst, temp, temp.size());
		//threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
		for(int i = 0; i<8; i++)
		{
			for(int j = 0; j<8; j++)
			{
					m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
			}
		}

		if(total >= count)
			break;

		normalize(m, m);
		char ret = (char)svm.predict(m); 

		if(ret == label)
		{
			right++;
			if(total <= 5000)
				right_1++;
			else
				right_2++;
		}
		else
		{
			error++;
			if(total <= 5000)
				error_1++;
			else
				error_2++;
		}

		total++;

#if(SHOW_PROCESS)
		stringstream ss;
		ss << "Number " << label << ", predict " << ret;
		string text = ss.str();
		putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);

		imshow("img", img);
		if(waitKey(0)==27) //ESC to quit
			break;
#endif

	}

	ifs.close();
	lab_ifs.close();

	stringstream ss;
	ss << "Total " << total << ", right " << right <<", error " << error;
	string text = ss.str();
	putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
	imshow("img", img);
	waitKey(0);

	return 0;
}

// Program entry point.
// ON_STUDY != 0 : load up to 60000 MNIST training samples, then train and
//                 save the SVM. The random-trees trainer is kept as a
//                 commented-out alternative.
// ON_STUDY == 0 : evaluate the saved SVM on the MNIST test set.
// Returns 0 on success and 1 if the data or the model could not be loaded.
int main( int argc, char *argv[] )
{
#if(ON_STUDY)
	int maxCount = 60000;
	//Training on an empty set fails inside OpenCV, so stop early when the
	//data could not be read.
	if( ReadTrainData(maxCount) != 0 || buffer.empty() )
		return 1;

	//newRtStudy(buffer);
	newSvmStudy(buffer);
	return 0;
#else
	//newRtPredict();
	return (newSvmPredict() == 0) ? 0 : 1;
#endif
}


 

你可能感兴趣的:(image,测试,Integer,float,byte,DST)