OpenCV学习笔记——OpenCV中支持向量机(SVM)模块源代码分析

/****************************************************************************************
                                 Support Vector Machines                              
****************************************************************************************/

// SVM training parameters.
// Real SVM use cases go beyond the linearly separable example (non-linearly
// separable data, choice of kernel, regression, ...), so the model has to be
// configured before training. All of those settings live in CvSVMParams.
// NOTE: every CV_PROP_RW field is declared on its own line on purpose (the
// binding generator parses these declarations); keep one field per line.
struct CV_EXPORTS_W_MAP CvSVMParams
{
    // Default constructor. The default values are assigned in the
    // implementation, which is not shown here.
    CvSVMParams();
    // Fully specified constructor; the arguments presumably map one-to-one
    // onto the fields below (Cvalue -> C) — confirm against the implementation.
    CvSVMParams( int svm_type, int kernel_type,
                 double degree, double gamma, double coef0,
                 double Cvalue, double nu, double p,
                 CvMat* class_weights, CvTermCriteria term_crit );
    // svm_type: the SVM formulation, one of the CvSVM type constants:
    //   C_SVC     - classifier that allows imperfect separation of classes,
    //               with outliers penalised by the factor C.
    //   NU_SVC    - classifier with imperfect separation; the parameter nu
    //               (in [0,1]) replaces C. The larger nu, the smoother the
    //               decision boundary.
    //   ONE_CLASS - one-class SVM: all training data are drawn from a single
    //               class, and the SVM builds a boundary separating the region
    //               this class occupies in feature space from everything else.
    //   EPS_SVR   - regression: the distance between the training feature
    //               vectors and the fitted hyperplane must be less than p;
    //               outliers are penalised by C.
    //   NU_SVR    - regression that uses nu instead of p (see the per-field
    //               notes on C, nu and p below).
    CV_PROP_RW int         svm_type;
    // kernel_type: the kernel K(x, y), one of the CvSVM kernel constants:
    //   LINEAR  - no mapping into a higher-dimensional space; linear
    //             discrimination (or regression) is done in the original
    //             feature space. This is the fastest option.
    //             K(x, y) = x^T y
    //   POLY    - polynomial kernel:
    //             K(x, y) = (gamma * x^T y + coef0)^degree
    //   RBF     - radial basis function, a good choice in most cases:
    //             K(x, y) = exp(-gamma * |x - y|^2)
    //   SIGMOID - sigmoid kernel:
    //             K(x, y) = tanh(gamma * x^T y + coef0)
    CV_PROP_RW int         kernel_type;
    // Kernel parameters: degree, gamma and coef0 appear in the kernel
    // formulas above; each is used only by the kernels listed in its comment.
    CV_PROP_RW double      degree; // for poly
    CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
    CV_PROP_RW double      coef0;  // for poly/sigmoid
    // Parameters of the SVM optimisation problem itself (C, nu, p); each is
    // used only by the SVM types listed in its comment.
    CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
    // Optional weights assigned to particular classes. They are multiplied by
    // C, so they scale the misclassification penalty per class: the larger a
    // class's weight, the larger the penalty for misclassifying its samples.
    // NOTE(review): unlike the other fields this one is not CV_PROP_RW, and it
    // is a raw CvMat* whose ownership is not visible here — presumably the
    // caller keeps the matrix alive during training; confirm.
    CvMat*      class_weights; // for CV_SVM_C_SVC
    // Termination criteria of the iterative training procedure, which solves
    // a partial case of a constrained quadratic optimisation problem.
    CV_PROP_RW CvTermCriteria term_crit; // termination criteria
};

// Kernel-function object used by the SVM. The concrete kernel is chosen via
// the `calc_func` pointer-to-member, whose type `Calc` matches the signatures
// of the calc_linear / calc_rbf / calc_poly / calc_sigmoid members below.
// NOTE: the declaration order of the virtual functions defines the vtable
// layout; do not reorder them.
struct CV_EXPORTS CvSVMKernel
{
    // Pointer-to-member type shared by all kernel implementations.
    // Presumably fills results[k] = K(vecs[k], another) for k in
    // [0, vec_count), each vector having vec_size components — the
    // implementation is not shown here; confirm.
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
    // Constructors. The second one receives the SVM parameters and the kernel
    // implementation to use (presumably forwarding to create()).
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    // (Re)initialises the kernel with the given parameters and kernel
    // implementation; returns a success flag.
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
    // Virtual destructor.
    virtual ~CvSVMKernel();
    // Resets the kernel object (presumably clears params/calc_func — confirm).
    virtual void clear();
    // Evaluates the kernel between `another` and `vcount` vectors of length
    // `n`, writing into `results`; presumably dispatches through calc_func.
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
    // SVM parameters supplying kernel_type, degree, gamma and coef0.
    // Non-owning pointer (const, no ownership visible here) — presumably the
    // CvSVMParams must outlive the kernel; confirm.
    const CvSVMParams* params;
    // Selected kernel implementation (one of the calc_* members).
    Calc calc_func;
    // Shared helper for the dot-product based kernels; the alpha/beta
    // arguments presumably form alpha * (x . y) + beta — implementation not
    // shown, confirm.
    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );

    // Individual kernel implementations (LINEAR, RBF, POLY, SIGMOID); each
    // has the signature described by Calc.
    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};


// Node of the solver's kernel-row cache: a doubly linked list element that
// points at one row of kernel values. CvSVMSolver holds a list head
// (`lru_list`; the name suggests least-recently-used order) and an array of
// these nodes (`rows`).
struct CvSVMKernelRow
{
    CvSVMKernelRow *prev, *next;  // neighbouring nodes in the list
    float *data;                  // row of cached kernel values (ownership not visible here)
};


// Summary of one solver run, filled in by the CvSVMSolver::solve_* methods
// through their CvSVMSolutionInfo& out-parameter.
struct CvSVMSolutionInfo
{
    double obj, rho;                     // presumably objective value and offset term — confirm
    double upper_bound_p, upper_bound_n; // presumably bounds for the positive / negative side — confirm
    double r;                            // for Solver_NU
};

// Solver for the SVM training optimisation problem. Each supported
// formulation has its own entry point (solve_c_svc, solve_nu_svc,
// solve_one_class, solve_eps_svr, solve_nu_svr). The shared iteration
// (solve_generic) is customised through three pointer-to-member hooks:
// working-set selection, kernel-row retrieval and rho computation.
// NOTE: the declaration order of the virtual functions and data members
// defines the class layout; do not reorder them.
class CV_EXPORTS CvSVMSolver
{
public:
    // Hook that selects the next working pair (i, j). The bool result
    // presumably signals convergence — implementation not shown; confirm.
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    // Hook that returns kernel row i (optionally written to dst); `existed`
    // presumably reports whether the row was already cached — confirm.
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    // Hook that computes the offset rho (and r, used by the nu variants).
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    // Constructs the solver; takes the same arguments as create().
    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    // (Re)initialises the solver for `count` samples with `var_count`
    // features each.
    //   y                       - per-sample labels (schar)
    //   alpha, alpha_count      - coefficient array and its length
    //   Cp, Cn                  - penalties for the positive / negative class
    //                             (stored as C[1] / C[0])
    //   storage                 - memory storage (presumably for temporary buffers)
    //   kernel                  - kernel evaluator
    //   get_row, select_working_set, calc_rho - formulation-specific hooks
    // Returns a success flag.
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual ~CvSVMSolver();

    // Resets the solver state.
    virtual void clear();
    // Runs the common iterative optimisation and reports the result in `si`.
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    // Formulation-specific entry points. Each one solves for `alpha` and
    // fills `si`, returning a success flag.
    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );

    // Regression variants take real-valued targets (const float* y).
    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    // Kernel-row cache access: get_row_base presumably looks up / allocates
    // cached row i (reporting via *_existed whether it was present), and
    // get_row returns row i, possibly copied into dst — confirm.
    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    // Problem size and kernel-row cache geometry.
    int sample_count;
    int var_count;
    int cache_size;
    int cache_line_size;
    const float** samples;
    const CvSVMParams* params;
    CvMemStorage* storage;
    // Head of the kernel-row cache list (the name suggests LRU order).
    CvSVMKernelRow lru_list;
    CvSVMKernelRow* rows;

    int alpha_count;

    // G: presumably the gradient of the objective (one entry per alpha) — confirm.
    double* G;
    double* alpha;

    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;
    // b: presumably the linear term of the objective — confirm.
    double* b;
    float* buf[2];
    // Stopping tolerance and iteration limit (presumably taken from
    // params->term_crit — confirm).
    double eps;
    int max_iter;
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    // Currently selected hooks (see the typedefs above).
    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    // Default and nu-formulation implementations of the
    // SelectWorkingSet / CalcRho hooks.
    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );

    // GetRow implementations for classification, one-class and regression.
    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};


// One decision function of a trained model: coefficients over a subset of
// the support vectors plus an offset. CvSVM keeps these in its
// `decision_func` member (presumably one per binary sub-problem — confirm).
struct CvSVMDecisionFunc
{
    double  rho;       // offset term (sign convention not visible here)
    int     sv_count;  // presumably the length of alpha and sv_index
    double *alpha;     // coefficients, presumably one per referenced support vector
    int    *sv_index;  // presumably indices into CvSVM's support-vector table — confirm
};


// SVM model
// CvSVM is the support vector machine model, derived from the base class
// CvStatModel.
// NOTE: the CV_WRAP declarations are parsed by the binding generator, and
// the declaration order of virtual functions/data members defines the class
// layout; do not reorder or merge them.
class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
    // SVM type.
    // The SVC types are classifiers; the SVR types perform regression.
    //   C_SVC     - classifier that allows imperfect separation, with outliers
    //               penalised by the factor C.
    //   NU_SVC    - classifier with imperfect separation; nu (in [0,1])
    //               replaces C. The larger nu, the smoother the boundary.
    //   ONE_CLASS - one-class SVM: all training data come from a single class
    //               and the SVM builds a boundary separating the region that
    //               class occupies in feature space from everything else.
    //   EPS_SVR   - regression: the distance between the training vectors and
    //               the fitted hyperplane must be less than p; outliers are
    //               penalised by C.
    //   NU_SVR    - regression that uses nu instead of p.
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };

    // SVM kernel type. Four kernels are provided:
    //   LINEAR  - linear
    //   POLY    - polynomial
    //   RBF     - radial basis function
    //   SIGMOID - sigmoid
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };

    // SVM params type: identifiers of the tunable parameters, passed to
    // get_default_grid() to obtain a search grid (as train_auto's default
    // arguments do).
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
    // Default constructor and destructor.
    CV_WRAP CvSVM();
    virtual ~CvSVM();
    // Constructor taking training data; presumably trains immediately with
    // the same arguments as train() — confirm.
    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );
    // Trains the SVM and builds the model. Follows the generic
    // CvStatModel::train() conventions with these restrictions:
    //   - only the CV_ROW_SAMPLE (one sample per row) data layout is supported;
    //   - input variables are all ordered;
    //   - all parameters are collected in a CvSVMParams structure.
    // Returns a success flag.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );
    // Trains the SVM with automatically chosen ("optimal") parameters. It
    // searches the given grids (Cgrid, gammaGrid, pGrid, nuGrid, coeffGrid,
    // degreeGrid) using kfold-fold evaluation (default 10). `balanced`
    // presumably requests class-balanced folds — confirm.
    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );
    // Predicts the response for a single input sample using the trained SVM
    // (the class label for classifiers). With returnDFVal=true it presumably
    // returns the decision-function value instead — confirm.
    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    // Batch prediction: presumably one response per row of `samples`,
    // written into `results` — confirm.
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results ) const;

    // cv::Mat overloads of the constructor, train, train_auto and predict
    // declared above; the CV_WRAP-marked ones are exported to the bindings.
    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    // Batch prediction; exported to the bindings under the name predict_all.
    CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;
    // Number of support vectors in the trained model.
    CV_WRAP virtual int get_support_vector_count() const;
    // Pointer to the i-th support vector (presumably
    // 0 <= i < get_support_vector_count()).
    virtual const float* get_support_vector(int i) const;
    // Returns a copy of the model's parameters.
    virtual CvSVMParams get_params() const { return params; };
    // Presumably releases the trained model — confirm.
    CV_WRAP virtual void clear();

    // Default search grid for the parameter identified by param_id
    // (one of C, GAMMA, P, NU, COEF, DEGREE).
    static CvParamGrid get_default_grid( int param_id );

    // Serialisation to / from a CvFileStorage.
    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    // Number of input variables the model uses: var_idx->cols when a variable
    // subset (var_idx) is set, otherwise var_all.
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:

    virtual bool set_params( const CvSVMParams& params );
    // Presumably trains one sub-problem with class penalties Cp/Cn, writing
    // the coefficients to `alpha` and the offset to `rho` — confirm.
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    // Presumably drives training for the given svm_type — confirm.
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    // Create the `kernel` and `solver` helper objects.
    virtual void create_kernel();
    virtual void create_solver();

    // Prediction on a raw feature row of length row_len.
    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // NOTE(review): the name suggests a post-training optimisation specific
    // to the LINEAR kernel; the implementation is not shown — confirm.
    void optimize_linear_svm();

    // Model state:
    //   params        - training parameters (returned by get_params())
    //   class_labels  - presumably the distinct class labels — confirm
    //   var_all       - total variable count (used by get_var_count())
    //   sv, sv_total  - support vectors and their number
    //   var_idx       - optional active-variable subset (used by get_var_count())
    //   class_weights - per-class weights (see CvSVMParams)
    //   decision_func - decision function(s) of the trained model
    //   storage       - memory storage for model data
    //   solver, kernel - helper objects created by create_solver()/create_kernel()
    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;

private:
    // Copying is disabled: the copy constructor and copy assignment are
    // declared private (and not defined in the visible code).
    CvSVM(const CvSVM&);
    CvSVM& operator = (const CvSVM&);
};

你可能感兴趣的:(OpenCV学习笔记——OpenCV中支持向量机(SVM)模块源代码分析)