For now I need to distill the interface of the ml module, so that later it will be easy to choose which classifier to use. I'm still pretty lost... I'll just jot things down as I learn them.
1. Take CvSVM as an example. Below is the definition of the CvSVM class:
class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
    // SVM type
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
    // SVM kernel type
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
    // SVM params type
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };

    CV_WRAP CvSVM();
    virtual ~CvSVM();

    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,  // these two parameters don't seem to be used much
                        CvSVMParams params=CvSVMParams() );

    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;

    CV_WRAP virtual int get_support_vector_count() const;
    virtual const float* get_support_vector(int i) const;
    virtual CvSVMParams get_params() const { return params; };
    CV_WRAP virtual void clear();

    static CvParamGrid get_default_grid( int param_id );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:
    virtual bool set_params( const CvSVMParams& params );
    virtual bool train1( int sample_count, int var_count, const float** samples,
                         const void* responses, double Cp, double Cn,
                         CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                           const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();

    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;
};
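One thing worth noting in the definition above: there are two predict() overloads, and the second one classifies a whole batch of samples in one call. A minimal sketch of how I understand it (assuming an already-trained CvSVM object named svm; the matrix sizes are made up, and the full training example follows below):

// Batch prediction with the predict( samples, results ) overload listed above:
// each row of 'samples' is one feature vector, 'results' receives one label per row.
CvMat* samples = cvCreateMat( 3, 2, CV_32FC1 );   // 3 samples, 2 features each
CvMat* results = cvCreateMat( 3, 1, CV_32FC1 );
cvZero( samples );                                // fill all rows with real feature vectors instead
CV_MAT_ELEM( *samples, float, 0, 0 ) = 400.f;     // sample 0, feature 0
CV_MAT_ELEM( *samples, float, 0, 1 ) = 20.f;      // sample 0, feature 1
svm.predict( samples, results );
float labelOfSample0 = results->data.fl[0];       // predicted class label of the first row
cvReleaseMat( &samples );
cvReleaseMat( &results );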
A minimal training/prediction example (this is essentially the OpenCV "Introduction to SVM" tutorial sample):

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;

int main()
{
    // Data for visual representation
    int width = 512, height = 512;
    Mat image = Mat::zeros(height, width, CV_8UC3);

    // Set up training data
    float labels[4] = {1.0, -1.0, -1.0, -1.0};
    Mat labelsMat(4, 1, CV_32FC1, labels);              // corresponds to the responses argument of the train() interface
    float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
    Mat trainingDataMat(4, 2, CV_32FC1, trainingData);  // corresponds to the trainData argument of the train() interface

    // Set up SVM's parameters
    CvSVMParams params;
    params.svm_type    = CvSVM::C_SVC;
    params.kernel_type = CvSVM::LINEAR;
    params.term_crit   = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);

    // Train the SVM
    CvSVM SVM;
    SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params);

    Vec3b green(0, 255, 0), blue(255, 0, 0);
    // Show the decision regions given by the SVM
    for (int i = 0; i < image.rows; ++i)
        for (int j = 0; j < image.cols; ++j)
        {
            Mat sampleMat = (Mat_<float>(1, 2) << j, i);  // the feature vector is (x, y) = (column, row)
            float response = SVM.predict(sampleMat);

            if (response == 1)
                image.at<Vec3b>(i, j) = green;
            else if (response == -1)
                image.at<Vec3b>(i, j) = blue;
        }

    // Show the training data
    int thickness = -1;
    int lineType  = 8;
    circle(image, Point(501,  10), 5, Scalar(  0,   0,   0), thickness, lineType);
    circle(image, Point(255,  10), 5, Scalar(255, 255, 255), thickness, lineType);
    circle(image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType);
    circle(image, Point( 10, 501), 5, Scalar(255, 255, 255), thickness, lineType);

    // Show support vectors
    thickness = 2;
    lineType  = 8;
    int c = SVM.get_support_vector_count();
    for (int i = 0; i < c; ++i)
    {
        const float* v = SVM.get_support_vector(i);
        circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(128, 128, 128), thickness, lineType);
    }

    imwrite("result.png", image);          // save the image
    imshow("SVM Simple Example", image);   // show it to the user
    waitKey(0);
}
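A follow-up note: once trained, the model can be written to disk and reloaded later without retraining. A minimal sketch, assuming the SVM object trained in the example above; save()/load() come from the CvStatModel base class, so they don't show up in the class definition pasted earlier, and the file name here is made up:

// Persist the trained SVM to XML, then restore it into a fresh object.
SVM.save( "svm_model.xml" );

CvSVM loadedSVM;
loadedSVM.load( "svm_model.xml" );
float response = loadedSVM.predict( (Mat_<float>(1, 2) << 400, 20) );  // use it exactly like the original object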
2. Now the boosted-cascade side (the traincascade application). The per-stage training entry point:

bool CvCascadeBoost::train( const CvFeatureEvaluator* _featureEvaluator,  // holds sum, tilted, the feature positions, etc.
                            int _numSamples,
                            int _precalcValBufSize, int _precalcIdxBufSize,
                            const CvCascadeBoostParams& _params )
bool CvBoostTree::train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvBoost* _ensemble )
And the generic training signature that the ml classifiers follow (from the CvStatModel documentation; the bracketed arguments are optional and vary per algorithm):

virtual bool train( const CvMat* train_data,
                    [int tflag,] ...,
                    const CvMat* responses, ...,
                    [const CvMat* var_idx,] ...,
                    [const CvMat* sample_idx,] ...,
                    [const CvMat* var_type,] ...,
                    [const CvMat* missing_mask,]
                    <misc_training_alg_params> ... ) = 0;
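This common shape is exactly what makes switching classifiers convenient: the training data and labels stay the same, and only the concrete class (plus its parameter struct and algorithm-specific extras) changes. A rough sketch reusing the trainingDataMat / labelsMat / sampleMat names from the SVM example above (trainAlternatives is just a name I made up; CvNormalBayesClassifier and CvKNearest are other CvStatModel-derived classifiers in the ml module, and their exact optional arguments differ):

#include <opencv2/core/core.hpp>
#include <opencv2/ml/ml.hpp>
using namespace cv;

// Same data, different classifier: only the concrete type and its parameters change.
void trainAlternatives( const Mat& trainingDataMat, const Mat& labelsMat, const Mat& sampleMat )
{
    // Normal Bayes classifier: no params struct needed
    CvNormalBayesClassifier bayes;
    bayes.train( trainingDataMat, labelsMat );
    float bayesResponse = bayes.predict( sampleMat );

    // K-nearest neighbours: the extra arguments are algorithm-specific
    CvKNearest knn;
    knn.train( trainingDataMat, labelsMat, Mat(), false /*isRegression*/, 3 /*maxK*/ );
    float knnResponse = knn.find_nearest( sampleMat, 3 );

    (void)bayesResponse; (void)knnResponse;  // use the responses the same way as SVM.predict()'s return value
}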
Some snippets from the traincascade internals, with what I understand them to do:

featureEvaluator->init( (CvFeatureParams*)featureParams, numPos + numNeg, cascadeParams.winSize );  // in CvCascadeClassifier::train: set up the feature evaluator for numPos + numNeg samples at the training window size

bool CvCascadeClassifier::updateTrainingSet( double& acceptanceRatio )  // fills the training set for the next stage; internally it calls featureEvaluator->setImage( img, isPositive ? 1 : 0, i ) for each sample

data = new CvCascadeBoostTrainData( _featureEvaluator, _numSamples, _precalcValBufSize, _precalcIdxBufSize, _params );  // in CvCascadeBoost::train: wrap the feature evaluator into the train-data object the boosted trees read from
CvCascadeBoostTree* tree = new CvCascadeBoostTree;
if( !tree->train( data, subsample_mask, this ) )  // train one weak classifier (a boosted tree)
{
    delete tree;
    break;
}
cvSeqPush( weak, &tree );  // append the weak classifier to the strong (stage) classifier
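For context, the fragment above sits inside the boosting loop of CvCascadeBoost::train. From memory, the surrounding loop looks roughly like the sketch below (the exact conditions may differ between OpenCV versions; update_weights, trim_weights and isErrDesired are the standard CvBoost / CvCascadeBoost steps):

do
{
    CvCascadeBoostTree* tree = new CvCascadeBoostTree;
    if( !tree->train( data, subsample_mask, this ) )   // train one weak classifier
    {
        delete tree;
        break;
    }
    cvSeqPush( weak, &tree );                           // add it to the stage classifier
    update_weights( tree );                             // AdaBoost re-weighting of the samples
    trim_weights();                                     // drop samples whose weight fell below the trim rate
    if( cvCountNonZero(subsample_mask) == 0 )           // nothing left to train on
        break;
}
while( !isErrDesired() && (weak->total < params.weak_count) );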