下面这是opencv官方文档中的代码,我加了一部分注释:
1 #include "stdafx.h" 2 #include "opencv2/core/core.hpp" 3 #include "highgui.h" 4 #include "ml.h" 5 6 using namespace cv; 7 8 int _tmain(int argc, _TCHAR* argv[]) 9 { 10 // 11 int width = 512, height = 512; 12 Mat image = Mat::zeros(height, width, CV_8UC3); 13 14 // set up training data 15 float labels[4] = {1.0, 1.0, -1.0, -1.0}; 16 Mat labelsMat(4, 1, CV_32FC1, labels); 17 18 float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} }; 19 Mat trainingDataMat(4, 2, CV_32FC1, trainingData); 20 21 // set up SVM's parameters,具体参数设置请看下文 22 CvSVMParams params; 23 params.svm_type = CvSVM::C_SVC; 24 params.kernel_type = CvSVM::LINEAR; 25 params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6); 26 27 // train the svm 28 CvSVM SVM; 29 SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params); 30 31 Vec3b green(0,255,0), blue(255,0,0); 32 33 // show the decision region given by the SVM 34 for (int i = 0; i < image.rows; ++ i) 35 { 36 for (int j = 0; j < image.cols; ++ j) 37 { 38 Mat sampleMat = (Mat_<float>(1,2) << i,j); 39 40 // predict 函数使用训练好的SVM模型对一个输入的样本进行分类 41 float response = SVM.predict(sampleMat); 42 43 if (response == 1) 44 { 45 // 注意这里是(j,i),不是(i,j) 46 image.at(j,i) = green; 47 } 48 else 49 { 50 // 同上 51 image.at (j,i) = blue; 52 } 53 } 54 } 55 56 int thickness = -1; 57 int lineType = 8; 58 59 circle(image, Point(501, 10), 5, Scalar( 0, 0, 0), thickness, lineType); 60 circle(image, Point(255, 10), 5, Scalar( 0, 0, 0), thickness, lineType); 61 circle(image, Point(501, 255), 5, Scalar(255,255,255), thickness, lineType); 62 circle(image, Point( 10, 501), 5, Scalar(255,255,255), thickness, lineType); 63 64 // show support vectors 65 thickness = 2; 66 lineType = 8; 67 68 // 获得当前的支持向量的个数 69 int c = SVM.get_support_vector_count(); 70 71 for (int i = 0; i < c; ++ i) 72 { 73 const float* v = SVM.get_support_vector(i); 74 circle( image, Point( (int) v[0], (int) v[1]), 6, Scalar(128, 128, 128), thickness, lineType); 75 } 76 77 imwrite("result.png", image); // save the image 78 79 imshow("SVM Simple Example", image); // show it to the user 80 waitKey(0); 81 return 0; 82 }
这里说一下CvSVMParams中的参数设置
1 CV_SVM 中的参数设置 2 3 svm_type: 4 CvSVM::C_SVC C-SVC 5 CvSVM::NU_SVC v-SVC 6 SvSVM::ONE_CLASS 一类SVM 7 CvSVM::EPS_SVR e-SVR 8 CvSVM::NU_SVR v-SVR 9 10 kernel_type: 11 CvSVM::LINEAR 线性:u*v 12 CvSVM::POLY 多项式(r*u'v + coef0)^degree 13 CvSVM::RBF RBF函数: exp(-r|u-v|^2) 14 CvSVM::SIGMOID sigmoid函数: tanh(r*u'v + coef0) 15 16 成员变量 17 degree: 针对多项式核函数degree的设置 18 gamma: 针对多项式/rbf/sigmoid核函数的设置 19 coef0: 针对多项式/sigmoid核函数的设置 20 Cvalue: 为损失函数,在C-SVC、e-SVR、v-SVR中有效 21 nu: 设置v-SVC、一类SVM和v-SVR参数 22 p: 为设置e-SVR中损失函数的值 23 class_weights: C_SVC的权重 24 term_crit: 为SVM训练过程的终止条件。 25 其中默认值 degree = 0, 26 gamma = 1, 27 coef0 = 0, 28 Cvalue = 1, 29 nu = 0, 30 p = 0, 31 class_weights = 0