OpenCV 用 SVM 进行多类分类

        最近遇到一个多分类的问题,在网上查了些有关 SVM 的资料。这篇日志的内容来自:http://wenku.baidu.com/view/81c3e210f18583d0496459f0.html 。自己写代码实现了一下,感觉很好用,作为一个学习的例子,放在自己的博客里,供以后查阅使用。

        

#include "stdafx.h"
#include "cv.h"
#include "highgui.h"
#include "ml.h"
#include <TIME.H>
#include <CTYPE.H>
#include <IOSTREAM>

using namespace std;

int main(int argc, char* argv[])
{
	int size = 400; // height and widht of image
	const int s = 1000; // number of data
	int i, j,sv_num;
	IplImage* img;

	CvSVM svm = CvSVM();
	CvSVMParams param;
	CvTermCriteria criteria; // 停止迭代标准
	CvRNG rng = cvRNG(time(NULL));
	CvPoint pts[s]; // 定义1000个点
	float data[s*2]; // 点的坐标
	int res[s]; // 点的类别

	CvMat data_mat, res_mat;
	CvScalar rcolor;

	const float* support;

	// 图像区域的初始化
	img = cvCreateImage(cvSize(size,size),IPL_DEPTH_8U,3);
	cvZero(img);

	// 学习数据的生成
	for (i=0; i<s;++i)
	{
		pts[i].x = cvRandInt(&rng)%size;
		pts[i].y = cvRandInt(&rng)%size;

		if (pts[i].y>50*cos(pts[i].x*CV_PI/100)+200)
		{
			cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(255,0,0));
			cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(255,0,0));
			res[i]=1;
		}
		else
		{
			if (pts[i].x>200)
			{
				cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(0,255,0));
				cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(0,255,0));
				res[i]=2;
			}
			else
			{
				cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),CV_RGB(0,0,255));
				cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),CV_RGB(0,0,255));
				res[i]=3;
			}
		}
	}

	// 学习数据的现实
	cvNamedWindow("SVM",CV_WINDOW_AUTOSIZE);
	cvShowImage("SVM",img);
	cvWaitKey(0);

	// 学习参数的生成
	for (i=0;i<s;++i)
	{
		data[i*2] = float(pts[i].x)/size;
		data[i*2+1] = float(pts[i].y)/size;
	}

	cvInitMatHeader(&data_mat,s,2,CV_32FC1,data);
	cvInitMatHeader(&res_mat,s,1,CV_32SC1,res);
	criteria = cvTermCriteria(CV_TERMCRIT_EPS,1000,FLT_EPSILON);
	param = CvSVMParams(CvSVM::C_SVC,CvSVM::RBF,10.0,8.0,1.0,10.0,0.5,0.1,NULL,criteria);

	// SVM type:CvSVM::C_SVC  Kernel:CvSVM::RBF degree:10.0  gamma:8.0 coef0:1.0

	svm.train(&data_mat,&res_mat,NULL,NULL,param);

	// 学习结果绘图
	for (i=0;i<size;i++)
	{
		for (j=0;j<size;j++)
		{
			CvMat m;
			float ret = 0.0;
			float a[] = {float(j)/size,float(i)/size};
			cvInitMatHeader(&m,1,2,CV_32FC1,a);
			ret = svm.predict(&m);

			switch((int)ret)
			{
				case 1:
					rcolor = CV_RGB(100,0,0);
					break;
				case 2:
					rcolor = CV_RGB(0,100,0);
					break;
				case 3:
					rcolor = CV_RGB(0,0,100);
					break;
			}
			cvSet2D(img,i,j,rcolor);
		}
	}


	// 为了显示学习结果,通过对输入图像区域的所有像素(特征向量)进行分类,然后对输入的像素用所属颜色等级的颜色绘图
	for(i=0;i<s;++i)
	{
		CvScalar rcolor;
		switch(res[i])
		{
			case 1:
				rcolor = CV_RGB(255,0,0);
				break;
			case 2:
				rcolor = CV_RGB(0,255,0);
				break;
			case 3:
				rcolor = CV_RGB(0,0,255);
				break;
		}
		cvLine(img,cvPoint(pts[i].x-2,pts[i].y-2),cvPoint(pts[i].x+2,pts[i].y+2),rcolor);
		cvLine(img,cvPoint(pts[i].x+2,pts[i].y-2),cvPoint(pts[i].x-2,pts[i].y+2),rcolor);			
	}

	// 支持向量的绘制
	sv_num = svm.get_support_vector_count();
	for (i=0; i<sv_num;++i)
	{
		support = svm.get_support_vector(i);
		cvCircle(img,cvPoint((int)(support[0]*size),(int)(support[i]*size)),5,CV_RGB(200,200,200));
	}

	cvNamedWindow("SVM",CV_WINDOW_AUTOSIZE);
	cvShowImage("SVM",img);
	cvWaitKey(0);
	cvDestroyWindow("SVM");
	cvReleaseImage(&img);

	return 0;
}


 

 

你可能感兴趣的:(SVM)