三、基于SVM算法实现手写数字识别
作为一个工科生,而非数学专业的学生,我们研究一个算法,是要将它用于实际领域的。下面给出基于OpenCV3.0的SVM算法手写数字识别程序源码(参考http://blog.csdn.net/firefight/article/details/6452188)程序略有改动。
本部分将基于OpenCV实现简单的数字识别,待识别图像如下图所示,通过以下几个步骤实现图像中的数字的自动识别。
1.使用OpenCV训练手写数字识别分类器;
2.图像预处理及图像分割;
3.应用分类器进行识别。
3.1使用OpenCV训练手写数字识别分类器
所谓学习分类器就是根据训练样本,选取模型训练产生数字分类器,这里采用上文提到的SVM算法。
训练集使用MNIST,这个MNIST数据库是一个手写数字的数据库,它提供了六万的训练集和一万的测试集。它的图片是被规范处理过的,是一张被放在中间部位的28px*28px的灰度图。总共包含4个文件,每一个文件头部几个字节都记录着这些图片的信息,然后才是储存的图片信息,关于文件信息的具体描述可以参考下面这个网站:https://www.jianshu.com/p/4195577585e6
下面是利用OpenCV 3.2.0的SVM相关API学习MNIST样本库产生样本函数的主要代码:(值得注意的是MNIST库中的图像是黑底白字的)
svm.h头文件
#pragma once
#include
#include
#include
#include
#include
#include
#include
using namespace std;
using namespace cv;
class NumTrainData
{
public:
NumTrainData()
{
memset(data, 0, sizeof(data));//Sets buffers to a specified character. Init the data
result = -1;
}
public:
float data[64];
int result;
};
extern vector buffer;
int ReadTrainData(int maxCount);
void newSvmStudy(vector& trainData);
char JpgPredict(Mat src);
svm.cpp文件
#include "svm.h"
#include "opencv2/opencv.hpp"
using namespace cv;
using namespace std;
using namespace cv::ml;
#define SHOW_PROCESS 0
#define ON_STUDY 0
int featureLen = 64;
void swapBuffer(char *buf)//0123->3210
{
char temp;
temp = *(buf);
*buf = *(buf + 3);
*(buf + 3) = temp;
temp = *(buf + 1);
*(buf + 1) = *(buf + 2);
*(buf + 2) = temp;
}
//获取ROI区域
void GetROI(Mat& src, Mat& dst)
{
int left, right, top, bottom;
left = src.cols;
right = 0;
top = src.rows;
bottom = 0;//右下角为原点
//Get valid area
for (int i = 0; i < src.rows; i++)
{
for (int j = 0; j < src.cols; j++)
{
if (src.at(i, j) > 0)//获取src中i,j点的像素值,为灰度图像,值为0-255
{
if (j < left) left = j;
if (j > right) right = j;
if (i < top) top = i;
if (i > bottom) bottom = i;
}
}
}//将原点置于含有像素点的方框的左上角
//Point center;
//center.x=(left+right)/2;
//center.y=(top+bottom)/2;
int width = right - left;
int height = bottom - top;
int len = (width < height) ? height : width;
//create a squre
dst = Mat::zeros(len, len, CV_8UC1);
//Copy valid data to squre center
Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
Rect srcRect(left, top, width, height);
Mat dstROI = dst(dstRect);
Mat srcROI = src(srcRect);
srcROI.copyTo(dstROI);
}
int ReadTrainData(int maxCount)
{
//Open image and label file
const char fileName[] = "res//train-images.idx3-ubyte";//图像信息,以二进制方式存储 28*28
const char LabelFileName[] = "res//train-labels.idx1-ubyte";//标签信息,以二进制方式存储
//ofstream是从内存到硬盘,ifstream是从硬盘到内存,读取标准样本库
ifstream lab_ifs(LabelFileName, ios_base::binary);
ifstream ifs(fileName, ios_base::binary);
if (ifs.fail() == true)//读取文件失败
return -1;
if (lab_ifs.fail() == true)//读取文件失败
return -1;
//Read train data number and image rows/clos
char magicNum[4], ccount[4], crows[4], ccols[4];
ifs.read(magicNum, sizeof(magicNum));//Read block of data
ifs.read(ccount, sizeof(ccount));
ifs.read(crows, sizeof(crows));
ifs.read(ccols, sizeof(ccols));
int count, rows, cols;
swapBuffer(ccount);//Copies bytes between buffers.
swapBuffer(crows);
swapBuffer(ccols);
memcpy(&count, ccount, sizeof(count));//Copies bytes between buffers.
memcpy(&rows, crows, sizeof(rows));
memcpy(&cols, ccols, sizeof(cols));
//Just skip label header
lab_ifs.read(magicNum, sizeof(magicNum));
lab_ifs.read(ccount, sizeof(ccount));
//Create source and show image matrix
Mat src = Mat::zeros(rows, cols, CV_8UC1);//28*28 piex single channel image
Mat temp = Mat::zeros(8, 8, CV_8UC1);
Mat img, dst;
char label = 0;
Scalar templateColor(255, 0, 255);
NumTrainData rtd;
//int loop=1000;
int total = 0;
while (!ifs.eof())//Indicates if the end of a stream has been reached.
{
if (total >= count)//total train data number
break;
total++;
//cout << total << endl;
//Read label
lab_ifs.read(&label, 1);//读取标签,1个字节
label = label + '0';//转换为ASCII码中的罗马数字
//Read source data
ifs.read((char*)src.data, rows*cols);//读取训练图像数据;每个像素被转成了0-255,0代表着白色,255代表着黑色。
GetROI(src, dst);
#if(SHOW_PROCESS)
//Too small to watch
img = Mat::zeros(dst.rows * 10, dst.cols * 10, CV_8UC1);
resize(dst, img, img.size());
stringstream ss;
ss << "Number" << label;
string text = ss.str();
putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, template);
#endif
rtd.result = label;
resize(dst, temp, temp.size());//将dst缩放成一个8*8的temp矩阵
//tehreshold(temp,temp,10,1,CT_THRESH_BINARY);
for (int i = 0; i < 8; i++)
{
for (int j = 0; j < 8; j++)
{
rtd.data[i * 8 + j] = temp.at(i, j);
}
}
buffer.push_back(rtd);
//if(waitKey(0)==27)//ESC to quit
//break;
maxCount--;
if (maxCount == 0)
{
//cout << "maxcount=" << maxCount << endl;
system("pause");
break;
}
}
//buffer中存储了maxcount个8*8的矩阵和它所具有的标签
ifs.close();
lab_ifs.close();
return 0;
}
void newSvmStudy(vector& trainData)
{
int testCount = trainData.size();//60000
Mat m = Mat::zeros(1, featureLen, CV_32FC1);
Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
Mat res = Mat::zeros(testCount, 1, CV_32SC1);
for (int i = 0; i < testCount; i++)
{
NumTrainData td = trainData.at(i);
memcpy(m.data, td.data, featureLen * sizeof(float));
normalize(m, m);
memcpy(data.data + i*featureLen * sizeof(float), m.data, featureLen * sizeof(float));
res.at(i, 0) = td.result;
//res.at(i, 0) = td.result;//存储标签
}
////////////////////START RT TRAINNING///////////////
//设置SVM参数
Ptr svm = SVM::create();
svm->setType(SVM::C_SVC);//用于多类分类
svm->setKernel(SVM::RBF);//采用高斯核函数
svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
svm->setDegree(10.0);//高斯核的参数设置
svm->setGamma(8.0);
svm->setCoef0(1.0);
svm->setC(10.0);
svm->setNu(0.5);
svm->setP(0.1);
//训练
Ptr tData = TrainData::create(data, ROW_SAMPLE, res);
svm->train(tData);
svm->save("res\\SVM_DATA.xml");
}
//预测数据
char JpgPredict(Mat src)
{
Ptr svm = Algorithm::load("res\\SVM_DATA.xml");
svm->load("res\\SVM_DATA.xml");
threshold(src, src, 230, 250, CV_THRESH_BINARY);
Mat temp = Mat::zeros(8, 8, CV_8UC1);
Mat m = Mat::zeros(1, featureLen, CV_32FC1);
Mat element = getStructuringElement(MORPH_RECT, Size(2, 2));
dilate(src, src, element);
imshow("1", src);
waitKey(30);
resize(src, temp, temp.size());
for (int i = 0; i < 8; i++)
{
for (int j = 0; j < 8; j++)
{
m.at(0, j + i * 8) = temp.at(i, j);
}
}
normalize(m, m);// 该函数归一化输入数组使它的范数或者数值范围在一定的范围内。
char ret = (char)svm->predict(m);//如果值为true而且是一个2类问题则返回判决函数值,否则返回类标签
return ret;
}
3.2 图像预处理及图像分割
前面通过学习产生了分类器,但我们输入图像中的数字并不能直接作为测试输入。图像中的数字笔画有时并不规整,还可能相互重叠。因为本文例子为了简化用的是屏幕截图,所以位置形变校正,色彩亮度校正等等都省去了,但仍需要一些简单处理。下面先对输入图像进行一把简单的预处理,主要目的是将图像转成二值图,这样便于我们下一步分割和识别。这样做还有个好处,就是把其余的噪声也顺带去掉了。
接下来,就可以对图像进行分割了。由于我们的分类器只能对数字一个一个地识别,所以首先要把每个数字分割出来。基本思想是先用findContours()函数把基本轮廓找出来,然后通过简单验证以确认是否为数字的轮廓。对于那些通过验证的轮廓,接下去会用boundingRect()找出它们的包围盒。
Process.h文件
#pragma once
#include "svm.h"
#include "opencv2/opencv.hpp"
class Coordinate //坐标类
{
public:
double x, y; //轮廓位置
int order; //轮廓向量contours中的第几个
bool operator<(Coordinate &m) //运算符重载,在sort()排序函数中使用
{
if (x < m.x)
return true;
else
return false;
}
};
void ImageProcess(Mat &srcImage);
void ImageFindRectangle(Mat &srcImage);
Process.cpp文件
#include "Process.h"
using namespace cv;
using namespace std;
Coordinate con[100] = { 0 }; //存放分割好的矩阵的中心坐标
vector> contours;//定义一个存放边缘矩阵的容器
vector hierarchy; //定义一个存放树节点的前后关系的容器
Rect rect[100]; //定义一个存放分割好图像的矩阵,注意数据溢出关系
int i = 0;//全局变量
void ImageFindRectangle(Mat &srcImage)
{
//使用contours迭代器遍历每一个轮廓,找到并画出包围这个轮廓的最小矩阵
vector>::iterator It;
for (It = contours.begin(); It < contours.end(); It++)
{
//画出可包围数字的最小矩形
Point2f vertex[4];
rect[i] = boundingRect(*It); //计算轮廓的垂直边界最小矩形,矩形是与图像上下边界平行的
//矩形左上角的点
vertex[0] = rect[i].tl();
//矩形左下角的点
vertex[1].x = (float)rect[i].tl().x, vertex[1].y = (float)rect[i].br().y;
//矩形右下角的点
vertex[2] = rect[i].br();
//矩形右上方的点
vertex[3].x = (float)rect[i].br().x, vertex[3].y = (float)rect[i].tl().y;
for (int j = 0; j < 4; j++)
line(srcImage, vertex[j], vertex[(j + 1) % 4], Scalar(0, 0, 255), 1);
con[i].x = (vertex[0].x + vertex[1].x + vertex[2].x + vertex[3].x) / 4.0;
//根据中心点判断图图像的位置
con[i].y = (vertex[0].y + vertex[1].y + vertex[2].y + vertex[3].y) / 4.0;
con[i].order = i;
i++;
}
sort(con, con + i); //将con按升序排列
}
void ImageProcess(Mat &srcImage)
{
Mat Image = Mat::zeros(srcImage.size(), CV_8U);
Mat grayImage = Mat::zeros(srcImage.size(), CV_8U);
//图像预处理
cvtColor(srcImage, srcImage, COLOR_BGR2GRAY); //转化为灰度图像
threshold(srcImage, srcImage, 230, 255, CV_THRESH_BINARY);//阈值化
//寻找图像边缘
findContours(srcImage, contours, hierarchy, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);//寻找图像边缘;函数用法参数见笔记
Mat dstImage = Mat::zeros(Image.size(), CV_8U);
drawContours(dstImage, contours, -1, Scalar(255, 0, 255));//在dstImage图像中画出边缘
//进行分割
ImageFindRectangle(dstImage);
//存储分割矩阵
Mat num[11];
for (int j = 0; j < i; j++)
{
int k;
k = con[j].order;
srcImage(rect[k]).copyTo(num[j]);
}
cout << "i=" << i << endl;
vector res;
for (int j = 0; j < i; j++)
{
res.push_back(JpgPredict(num[j]));
//cout << JpgPredict(num[j]) << endl;
}
cout << "Predicted number is:";
for (const auto&number : res)
{
cout <
3.3 应用分类器进行识别
Main.cpp函数
#include "svm.h"
#include "Process.h"
#include
#include
#include
using namespace cv;
using namespace std;
vector buffer;
#define ON_STUDY 0
#define ON_PROCESS 1
int main(void)
{
#if ON_STUDY
int maxCount = 30000;
ReadTrainData(maxCount);
newSvmStudy(buffer);
#endif
#if ON_PROCESS
Mat img = imread("Sample3.jpg");
ImageProcess(img);
waitKey(0);
#endif
return 0;
}
识别结果如下:
结果检测,SVM算法可以较好的识别手写数字,但是在编写代码的过程中发现一个问题,那就是这个算法对“1”数字的识别精度非常差,可能10张图中只能正确识别一次,不知道有没有大神能够给出一些建议?
上一篇:基于OpenCV的 SVM算法实现数字识别(三)---SMO求解