OpenCV的支持向量机SVM的程序

转自http://blog.csdn.net/firefight/article/details/6400060

为了学习OPENCV SVM分类器, 参考网上的"利用SVM解决2维空间向量的分类问题"实现并改为C++代码,仅供参考

环境:OPENCV2.4.9 + VS2013

步骤:
1,生成随机的点,并按一定的空间分布将其归类
2,创建SVM并利用随机点样本进行训练
3,将整个空间按SVM分类结果进行划分,并显示支持向量

  copy

  1. #include "iostream"
    #include <opencv2/opencv.hpp>
    #include <opencv2/ml/ml.hpp>
    using namespace cv;
    using namespace std;
  2.    
  3. void drawCross(Mat &img, Point center, Scalar color)  
  4. {  
  5.     int col = center.x > 2 ? center.x : 2;  
  6.     int row = center.y> 2 ? center.y : 2;  
  7.   
  8.     line(img, Point(col -2, row - 2), Point(col + 2, row + 2), color);    
  9.     line(img, Point(col + 2, row - 2), Point(col - 2, row + 2), color);    
  10. }  
  11.   
  12. int newSvmTest(int rows, int cols, int testCount)  
  13. {  
  14.     if(testCount > rows * cols)  
  15.         return 0;  
  16.   
  17.     Mat img = Mat::zeros(rows, cols, CV_8UC3);  
  18.     Mat testPoint = Mat::zeros(rows, cols, CV_8UC1);  
  19.     Mat data = Mat::zeros(testCount, 2, CV_32FC1);  
  20.     Mat res = Mat::zeros(testCount, 1, CV_32SC1);  
  21.   
  22.     //Create random test points  
  23.     for (int i= 0; i< testCount; i++)   
  24.     {   
  25.         int row = rand() % rows;  
  26.         int col = rand() % cols;  
  27.   
  28.         if(testPoint.at<unsigned char>(row, col) == 0)  
  29.         {  
  30.             testPoint.at<unsigned char>(row, col) = 1;  
  31.             data.at<float>(i, 0) = float (col) / cols;   
  32.             data.at<float>(i, 1) = float (row) / rows;   
  33.         }  
  34.         else  
  35.         {  
  36.             i--;  
  37.             continue;  
  38.         }  
  39.   
  40.         if (row > ( 50 * cos(col * CV_PI/ 100) + 200) )  
  41.         {   
  42.             drawCross(img, Point(col, row), CV_RGB(255, 0, 0));  
  43.             res.at<unsigned int>(i, 0) = 1;   
  44.         }   
  45.         else   
  46.         {   
  47.             if (col > 200)   
  48.             {   
  49.                 drawCross(img, Point(col, row), CV_RGB(0, 255, 0));  
  50.                 res.at<unsigned int>(i, 0) = 2;   
  51.             }   
  52.             else   
  53.             {   
  54.                 drawCross(img, Point(col, row), CV_RGB(0, 0, 255));  
  55.                 res.at<unsigned int>(i, 0) = 3;   
  56.             }   
  57.         }   
  58.   
  59.     }  
  60.   
  61.     //Show test points  
  62.     imshow("随机数", img);  
  63.     
  64.   
  65.     /////////////START SVM TRAINNING//////////////////  
  66.     CvSVM svm = CvSVM();   
  67.     CvSVMParams param;   
  68.     CvTermCriteria criteria;  
  69.   
  70.     criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);   
  71.     param= CvSVMParams (CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);   
  72.   
  73.     svm.train(data, res, Mat(), Mat(), param);  
  74.   
  75.     for (int i= 0; i< rows; i++)   
  76.     {   
  77.         for (int j= 0; j< cols; j++)   
  78.         {   
  79.             Mat m = Mat::zeros(1, 2, CV_32FC1);  
  80.             m.at<float>(0,0) = float (j) / cols;  
  81.             m.at<float>(0,1) = float (i) / rows;  
  82.   
  83.             float ret = 0.0;   
  84.             ret = svm.predict(m);   
  85.             Scalar rcolor;   
  86.   
  87.             switch ((int) ret)   
  88.             {   
  89.                 case 1: rcolor= CV_RGB(100, 0, 0); break;   
  90.                 case 2: rcolor= CV_RGB(0, 100, 0); break;   
  91.                 case 3: rcolor= CV_RGB(0, 0, 100); break;   
  92.             }   
  93.   
  94.             line(img, Point(j,i), Point(j,i), rcolor);  
  95.         }   
  96.     }  
  97.   
  98.     imshow("分区效果", img);  
  99.     //waitKey(0);  
  100.   
  101.     //Show support vectors  
  102.     int sv_num= svm.get_support_vector_count();   
  103.     for (int i= 0; i< sv_num; i++)   
  104.     {   
  105.         const float* support = svm.get_support_vector(i);   
  106.         circle(img, Point((int) (support[0] * cols), (int) (support[1] * rows)), 5, CV_RGB(200, 200, 200));   
  107.     }  
  108.   
  109.     imshow("dst", img);  
  110.     waitKey(0);  
  111.   
  112.     return 0;  
  113. }  
  114.   
  115. int main(int argc, char** argv)  
  116. {  
  117.     return newSvmTest(400, 600, 100);  
  118. }  

OpenCV的支持向量机SVM的程序_第1张图片

OpenCV的支持向量机SVM的程序_第2张图片



你可能感兴趣的:(opencv,支持向量机)