学习OpenCV——SVM

OpenCV开发SVM算法是基于LibSVM软件包开发的,LibSVM是台湾大学林智仁(Lin Chih-Jen)等开发设计的一个简单、易于使用和快速有效的SVM模式识别与回归的软件包。用OpenCV使用SVM算法的大概流程是

1)设置训练样本集

需要两组数据,一组是数据的类别,一组是数据的向量信息。

2)设置SVM参数

利用CvSVMParams类实现类内的成员变量svm_type表示SVM类型:

CvSVM::C_SVC  C-SVC

CvSVM::NU_SVC v-SVC

CvSVM::ONE_CLASS 一类SVM

CvSVM::EPS_SVR e-SVR

CvSVM::NU_SVR v-SVR

成员变量kernel_type表示核函数的类型:

CvSVM::LINEAR 线性:u‘v

CvSVM::POLY 多项式:(r*u'v + coef0)^degree

CvSVM::RBF RBF函数:exp(-r|u-v|^2)

CvSVM::SIGMOID sigmoid函数:tanh(r*u'v + coef0)

成员变量degree针对多项式核函数degree的设置,gamma针对多项式/rbf/sigmoid核函数的设置,coef0针对多项式/sigmoid核函数的设置,Cvalue为损失函数,在C-SVC、e-SVR、v-SVR中有效,nu设置v-SVC、一类SVM和v-SVR参数,p为设置e-SVR中损失函数的值,class_weightsC_SVC的权重,term_crit为SVM训练过程的终止条件。其中默认值degree = 0,gamma = 1,coef0 = 0,Cvalue = 1,nu = 0,p = 0,class_weights = 0

3)训练SVM

调用CvSVM::train函数建立SVM模型,第一个参数为训练数据,第二个参数为分类结果,最后一个参数即CvSVMParams

4)用这个SVM进行分类

调用函数CvSVM::predict实现分类

5)获得支持向量

除了分类,也可以得到SVM的支持向量,调用函数CvSVM::get_support_vector_count获得支持向量的个数,CvSVM::get_support_vector获得对应的索引编号的支持向量。

实现代码如下:运行步骤

[cpp]  view plain copy print ?
  1. // step 1:   
  2. float labels[4] = {1.0, -1.0, -1.0, -1.0};  
  3. Mat labelsMat(3, 1, CV_32FC1, labels);  
  4.   
  5. float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };  
  6. Mat trainingDataMat(3, 2, CV_32FC1, trainingData);  
  7.   
  8. // step 2:   
  9. CvSVMParams params;  
  10. params.svm_type = CvSVM::C_SVC;  
  11. params.kernel_type = CvSVM::LINEAR;  
  12. params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);  
  13.   
  14. // step 3:   
  15. CvSVM SVM;  
  16. SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);  
  17.   
  18. // step 4:   
  19. Vec3b green(0, 255, 0), blue(255, 0, 0);  
  20. for (int i=0; i<image.rows; i++)  
  21. {  
  22.     for (int j=0; j<image.cols; j++)  
  23.     {  
  24.         Mat sampleMat = (Mat_<float>(1,2) << i,j);  
  25.         float response = SVM.predict(sampleMat);  
  26.   
  27.         if (fabs(response-1.0) < 0.0001)  
  28.         {  
  29.             image.at<Vec3b>(j, i) = green;  
  30.         }  
  31.         else if (fabs(response+1.0) < 0.001)  
  32.         {  
  33.             image.at<Vec3b>(j, i) = blue;  
  34.         }  
  35.     }  
  36. }  
  37.   
  38. // step 5:   
  39. int c = SVM.get_support_vector_count();  
  40.   
  41. for (int i=0; i<c; i++)  
  42. {  
  43.     const float* v = SVM.get_support_vector(i);  
  44. }  

 

实验代码1:颜色分类

[cpp]  view plain copy print ?
  1. //利用SVM解决2维空间向量的3级分类问题      
  2. #include "stdafx.h"      
  3. #include "cv.h"      
  4. #include "highgui.h"             
  5. #include <ML.H>      
  6. #include <TIME.H>      
  7.      
  8. #include <CTYPE.H>      
  9.     
  10. #include <IOSTREAM>      
  11. using namespace std;     
  12. int main(int argc, char **argv)     
  13. {     
  14.         int size = 400;         //图像的长度和宽度      
  15.         const int s = 1000;          //试验点个数(可更改!!)      
  16.         int i, j, sv_num;     
  17.         IplImage *img;     
  18.         CvSVM svm = CvSVM();    //★★★      
  19.         CvSVMParams param;     
  20.         CvTermCriteria criteria;//停止迭代的标准      
  21.         CvRNG rng = cvRNG(time(NULL));     
  22.         CvPoint pts[s];         //定义1000个点      
  23.         float data[s*2];        //点的坐标      
  24.         int res[s];             //点的所属类      
  25.         CvMat data_mat, res_mat;     
  26.         CvScalar rcolor;     
  27.         const float *support;     
  28.         // (1)图像区域的确保和初始化      
  29.         img= cvCreateImage(cvSize(size, size), IPL_DEPTH_8U, 3);     
  30.         cvZero(img);     
  31.         //确保画像区域,并清0(用黑色作初始化处理)。      
  32.          
  33.         // (2)学习数据的生成      
  34.         for (i= 0; i< s; i++) {     
  35.             pts[i].x= cvRandInt(&rng) % size;   //用随机整数赋值      
  36.             pts[i].y= cvRandInt(&rng) % size;     
  37.             if (pts[i].y> 50 * cos(pts[i].x* CV_PI/ 100) + 200) {     
  38.                 cvLine(img, cvPoint(pts[i].x- 2, pts[i].y- 2), cvPoint(pts[i].x+ 2, pts[i].y+ 2), CV_RGB(255, 0, 0));     
  39.                 cvLine(img, cvPoint(pts[i].x+ 2, pts[i].y- 2), cvPoint(pts[i].x- 2, pts[i].y+ 2), CV_RGB(255, 0, 0));     
  40.                 res[i] = 1;     
  41.             }     
  42.             else {     
  43.                 if (pts[i].x> 200) {     
  44.                     cvLine(img, cvPoint(pts[i].x- 2, pts[i].y- 2), cvPoint(pts[i].x+ 2, pts[i].y+ 2), CV_RGB(0, 255, 0));     
  45.                     cvLine(img, cvPoint(pts[i].x+ 2, pts[i].y- 2), cvPoint(pts[i].x- 2, pts[i].y+ 2), CV_RGB(0, 255, 0));     
  46.                     res[i] = 2;     
  47.                 }     
  48.                 else {     
  49.                     cvLine(img, cvPoint(pts[i].x- 2, pts[i].y- 2), cvPoint(pts[i].x+ 2, pts[i].y+ 2), CV_RGB(0, 0, 255));     
  50.                     cvLine(img, cvPoint(pts[i].x+ 2, pts[i].y- 2), cvPoint(pts[i].x- 2, pts[i].y+ 2), CV_RGB(0, 0, 255));     
  51.                     res[i] = 3;     
  52.                 }     
  53.             }     
  54.         }     
  55.         //生成2维随机训练数据,并将其值放在CvPoint数据类型的数组pts[ ]中。      
  56.          
  57.         // (3)学习数据的显示      
  58.         cvNamedWindow("SVM", CV_WINDOW_AUTOSIZE);     
  59.         cvShowImage("SVM", img);     
  60.         cvWaitKey(0);     
  61.          
  62.         // (4)学习参数的生成      
  63.         for (i= 0; i< s; i++) {     
  64.             data[i* 2] = float (pts[i].x) / size;     
  65.             data[i* 2 + 1] = float (pts[i].y) / size;     
  66.         }     
  67.         cvInitMatHeader(&data_mat, s, 2, CV_32FC1, data);     
  68.         cvInitMatHeader(&res_mat, s, 1, CV_32SC1, res);     
  69.         criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);     
  70.         param= CvSVMParams (CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);     
  71.         /*   
  72.             SVM种类:CvSVM::C_SVC   
  73.             Kernel的种类:CvSVM::RBF   
  74.             degree:10.0(此次不使用)   
  75.             gamma:8.0   
  76.             coef0:1.0(此次不使用)   
  77.             C:10.0   
  78.             nu:0.5(此次不使用)   
  79.             p:0.1(此次不使用)   
  80.             然后对训练数据正规化处理,并放在CvMat型的数组里。   
  81.                                                                 */     
  82.          
  83.          
  84.         //☆☆☆☆☆☆☆☆☆(5)SVM学习☆☆☆☆☆☆☆☆☆☆☆☆      
  85.         svm.train(&data_mat, &res_mat, NULL, NULL, param);//☆      
  86.         //☆☆利用训练数据和确定的学习参数,进行SVM学习☆☆☆☆          
  87.          
  88.         // (6)学习结果的绘图      
  89.         for (i= 0; i< size; i++) {     
  90.             for (j= 0; j< size; j++) {     
  91.                 CvMat m;     
  92.                 float ret = 0.0;     
  93.                 float a[] = { float (j) / size, float (i) / size };     
  94.                 cvInitMatHeader(&m, 1, 2, CV_32FC1, a);     
  95.                 ret= svm.predict(&m);     
  96.                 switch ((int) ret) {     
  97.                     case 1:     
  98.                         rcolor= CV_RGB(100, 0, 0);     
  99.                         break;     
  100.                     case 2:     
  101.                         rcolor= CV_RGB(0, 100, 0);     
  102.                         break;     
  103.                     case 3:     
  104.                         rcolor= CV_RGB(0, 0, 100);     
  105.                         break;     
  106.                 }     
  107.                 cvSet2D(img, i, j, rcolor);     
  108.             }     
  109.         }     
  110.         //为了显示学习结果,通过输入图像区域的所有像素(特征向量)并进行分类。然后对输入像素用所属等级的颜色绘图。      
  111.          
  112.         // (7)训练数据的再绘制      
  113.         for (i= 0; i< s; i++) {     
  114.             CvScalar rcolor;     
  115.             switch (res[i]) {     
  116.                 case 1:     
  117.                     rcolor= CV_RGB(255, 0, 0);     
  118.                     break;     
  119.                 case 2:     
  120.                     rcolor= CV_RGB(0, 255, 0);     
  121.                     break;     
  122.                 case 3:     
  123.                     rcolor= CV_RGB(0, 0, 255);     
  124.                     break;     
  125.             }     
  126.             cvLine(img, cvPoint(pts[i].x- 2, pts[i].y- 2), cvPoint(pts[i].x+ 2, pts[i].y+ 2), rcolor);     
  127.             cvLine(img, cvPoint(pts[i].x+ 2, pts[i].y- 2), cvPoint(pts[i].x- 2, pts[i].y+ 2), rcolor);     
  128.         }     
  129.         //将训练数据在结果图像上重复的绘制出来。      
  130.          
  131.         // (8)支持向量的绘制      
  132.         sv_num= svm.get_support_vector_count();     
  133.         for (i= 0; i< sv_num; i++) {     
  134.             support = svm.get_support_vector(i);     
  135.             cvCircle(img, cvPoint((int) (support[0] * size), (int) (support[1] * size)), 5, CV_RGB(200, 200, 200));     
  136.         }     
  137.         //用白色的圆圈对支持向量作标记。      
  138.          
  139.         // (9)图像的显示       
  140.         cvNamedWindow("SVM", CV_WINDOW_AUTOSIZE);     
  141.         cvShowImage("SVM", img);     
  142.         cvWaitKey(0);     
  143.         cvDestroyWindow("SVM");     
  144.         cvReleaseImage(&img);     
  145.         return 0;     
  146.         //显示实际处理结果的图像,直到某个键被按下为止。      
  147.     }    


实验代码2:用MIT人脸库检测,效果实在不好,检测结果全是人脸或者全都不是人脸。原因应该是图像检测没有做好应该用HoG等特征首先检测,在进行分类训练,不特征不明显,肯定分类效果并不好。

[cpp]  view plain copy print ?
  1. //////////////////////////////////////////////////////////////////////////  
  2. // File Name: pjSVM.cpp  
  3. // Author:   easyfov([email protected])  
  4. // Company: Lida Optical and Electronic Co.,Ltd.  
  5. //http://apps.hi.baidu.com/share/detail/32719017  
  6. //////////////////////////////////////////////////////////////////////////  
  7.   
  8. #include <cv.h>  
  9. #include <highgui.h>  
  10. #include <ml.h>  
  11.   
  12. #include <iostream>  
  13. #include <fstream>  
  14. #include <string>  
  15. #include <vector>  
  16. using namespace std;  
  17.   
  18. #define WIDTH 20  
  19. #define HEIGHT 20  
  20.   
  21. int main( /*int argc, char** argv*/ )  
  22. {  
  23.     vector<string> img_path;  
  24.     vector<int> img_catg;  
  25.     int nLine = 0;  
  26.     string buf;  
  27.     ifstream svm_data( "E:/SVM_DATA.txt" );  
  28.   
  29.     while( svm_data )  
  30.     {  
  31.         if( getline( svm_data, buf ) )  
  32.         {  
  33.             nLine ++;  
  34.             if( nLine % 2 == 0 )  
  35.             {  
  36.                  img_catg.push_back( atoi( buf.c_str() ) );//atoi将字符串转换成整型,标志(0,1)  
  37.             }  
  38.             else  
  39.             {  
  40.                 img_path.push_back( buf );//图像路径  
  41.             }  
  42.         }  
  43.     }  
  44.     svm_data.close();//关闭文件  
  45.   
  46.     CvMat *data_mat, *res_mat;  
  47.     int nImgNum = nLine / 2;            //读入样本数量  
  48.     ////样本矩阵,nImgNum:横坐标是样本数量, WIDTH * HEIGHT:样本特征向量,即图像大小  
  49.     data_mat = cvCreateMat( nImgNum, WIDTH * HEIGHT, CV_32FC1 );  
  50.     cvSetZero( data_mat );  
  51.     //类型矩阵,存储每个样本的类型标志  
  52.     res_mat = cvCreateMat( nImgNum, 1, CV_32FC1 );  
  53.     cvSetZero( res_mat );  
  54.   
  55.     IplImage *srcImg, *sampleImg;  
  56.     float b;  
  57.     DWORD n;  
  58.   
  59.     for( string::size_type i = 0; i != img_path.size(); i++ )  
  60.     {  
  61.        srcImg = cvLoadImage( img_path[i].c_str(), CV_LOAD_IMAGE_GRAYSCALE );  
  62.        if( srcImg == NULL )  
  63.        {  
  64.             cout<<" can not load the image: "<<img_path[i].c_str()<<endl;  
  65.             continue;  
  66.        }  
  67.   
  68.        cout<<" processing "<<img_path[i].c_str()<<endl;  
  69.   
  70.        sampleImg = cvCreateImage( cvSize( WIDTH, HEIGHT ), IPL_DEPTH_8U, 1 );//样本大小(WIDTH, HEIGHT)  
  71.        cvResize( srcImg, sampleImg );//改变图像大小  
  72.   
  73.        cvSmooth( sampleImg, sampleImg );    //降噪  
  74.        //生成训练数据  
  75.        n = 0;  
  76.         forint ii = 0; ii < sampleImg->height; ii++ )  
  77.         {  
  78.             forint jj = 0; jj < sampleImg->width; jj++, n++ )  
  79.             {  
  80.                  b = (float)((int)((uchar)( sampleImg->imageData + sampleImg->widthStep * ii + jj )) / 255.0 );  
  81.                  cvmSet( data_mat, (int)i, n, b );  
  82.             }  
  83.         }  
  84.         cvmSet( res_mat, i, 0, img_catg[i] );  
  85.         cout<<" end processing "<<img_path[i].c_str()<<" "<<img_catg[i]<<endl;  
  86.     }  
  87.   
  88.   
  89.     CvSVM svm = CvSVM();  
  90.     CvSVMParams param;  
  91.     CvTermCriteria criteria;  
  92.     criteria = cvTermCriteria( CV_TERMCRIT_EPS, 1000, FLT_EPSILON );  
  93.     param = CvSVMParams( CvSVM::C_SVC, CvSVM::RBF, 10.0, 0.09, 1.0, 10.0, 0.5, 1.0, NULL, criteria );  
  94.      /*   
  95.             SVM种类:CvSVM::C_SVC   
  96.             Kernel的种类:CvSVM::RBF   
  97.             degree:10.0(此次不使用)   
  98.             gamma:8.0   
  99.             coef0:1.0(此次不使用)   
  100.             C:10.0   
  101.             nu:0.5(此次不使用)   
  102.             p:0.1(此次不使用)   
  103.             然后对训练数据正规化处理,并放在CvMat型的数组里。   
  104.                                                                 */     
  105.     //☆☆☆☆☆☆☆☆☆(5)SVM学习☆☆☆☆☆☆☆☆☆☆☆☆      
  106.     svm.train( data_mat, res_mat, NULL, NULL, param );  
  107.     //☆☆利用训练数据和确定的学习参数,进行SVM学习☆☆☆☆  
  108.     svm.save( "SVM_DATA.xml" );  
  109.   
  110.   
  111.     //检测样本  
  112.     IplImage *tst, *tst_tmp;  
  113.     vector<string> img_tst_path;  
  114.     ifstream img_tst( "E:/SVM_TEST.txt" );  
  115.     while( img_tst )  
  116.     {  
  117.         if( getline( img_tst, buf ) )  
  118.         {  
  119.             img_tst_path.push_back( buf );  
  120.         }  
  121.     }  
  122.     img_tst.close();  
  123.   
  124.     CvMat *tst_mat = cvCreateMat( 1, WIDTH*HEIGHT, CV_32FC1 );  
  125.     char line[512];  
  126.     ofstream predict_txt( "SVM_PREDICT.txt" );  
  127.     for( string::size_type j = 0; j != img_tst_path.size(); j++ )  
  128.     {  
  129.         tst = cvLoadImage( img_tst_path[j].c_str(), CV_LOAD_IMAGE_GRAYSCALE );  
  130.         if( tst == NULL )  
  131.         {  
  132.              cout<<" can not load the image: "<<img_tst_path[j].c_str()<<endl;  
  133.                continue;  
  134.    }  
  135.    tst_tmp = cvCreateImage( cvSize( WIDTH, HEIGHT ), IPL_DEPTH_8U, 1 );  
  136.    cvResize( tst, tst_tmp );  
  137.    cvSmooth( tst_tmp, tst_tmp );  
  138.    n = 0;  
  139.    for(int ii = 0; ii < tst_tmp->height; ii++ )  
  140.    {  
  141.      for(int jj = 0; jj < tst_tmp->width; jj++, n++ )  
  142.      {  
  143.          b = (float)(((int)((uchar)tst_tmp->imageData+tst_tmp->widthStep*ii+jj))/255.0);  
  144.          cvmSet( tst_mat, 0, n, (double)b );  
  145.      }  
  146.    }  
  147.   
  148.    int ret = svm.predict( tst_mat );  
  149.    sprintf( line, "%s %d\r\n", img_tst_path[j].c_str(), ret );  
  150.    predict_txt<<line;  
  151. }  
  152. predict_txt.close();  
  153.   
  154. cvReleaseImage( &srcImg );  
  155. cvReleaseImage( &sampleImg );  
  156. cvReleaseImage( &tst );  
  157. cvReleaseImage( &tst_tmp );  
  158. cvReleaseMat( &data_mat );  
  159. cvReleaseMat( &res_mat );  
  160.   
  161. return 0;  
  162. }  


其中

G:/program/pjSVM/face/1.png
0
G:/program/pjSVM/face/2.png
0
G:/program/pjSVM/face/3.png
0
G:/program/pjSVM/face/4.png
0
G:/program/pjSVM/face/5.png
0
G:/program/pjSVM/face/6.png
0
G:/program/pjSVM/face/7.png
0
G:/program/pjSVM/face/8.png
0
G:/program/pjSVM/face/9.png
0
G:/program/pjSVM/face/10.png
0
G:/program/pjSVM/face/11.png
0
G:/program/pjSVM/face/12.png
0
G:/program/pjSVM/face/13.png
0
G:/program/pjSVM/face/14.png
0
G:/program/pjSVM/face/15.png
1
G:/program/pjSVM/face/16.png
1
G:/program/pjSVM/face/17.png
1
G:/program/pjSVM/face/18.png
1
G:/program/pjSVM/face/19.png
1
G:/program/pjSVM/face/20.png
1
G:/program/pjSVM/face/21.png
1
G:/program/pjSVM/face/22.png
1
G:/program/pjSVM/face/23.png
1
G:/program/pjSVM/face/24.png
1
G:/program/pjSVM/face/25.png
1
G:/program/pjSVM/face/26.png
1
G:/program/pjSVM/face/27.png
1
G:/program/pjSVM/face/28.png
1
G:/program/pjSVM/face/29.png
1
G:/program/pjSVM/face/30.png

1

SVM_TEST.txt中内容如下:

G:/program/pjSVM/try_face/5.png
G:/program/pjSVM/try_face/9.png
G:/program/pjSVM/try_face/11.png
G:/program/pjSVM/try_face/15.png
G:/program/pjSVM/try_face/2.png
G:/program/pjSVM/try_face/30.png
G:/program/pjSVM/try_face/17.png
G:/program/pjSVM/try_face/21.png
G:/program/pjSVM/try_face/24.png
G:/program/pjSVM/try_face/27.png

PS:txt操作简单方式:http://blog.csdn.net/lytwell/article/details/6029503

你可能感兴趣的:(学习OpenCV——SVM)