Normal Bayes 分类器过程详解

OpenCV的机器学习类定义在ml.hpp文件中,基础类是CvStatModel,其他各种分类器从这里继承而来。

今天研究CvNormalBayesClassifier分类器。

1.类定义

在ml.hpp中有以下类定义:

[cpp]  view plain copy print ?
  1. class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel  
  2. {  
  3. public:  
  4.     CV_WRAP CvNormalBayesClassifier();  
  5.     virtual ~CvNormalBayesClassifier();  
  6.   
  7.     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,  
  8.         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );  
  9.   
  10.     virtual bool train( const CvMat* trainData, const CvMat* responses,  
  11.         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );  
  12.   
  13.     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;  
  14.     CV_WRAP virtual void clear();  
  15.   
  16.     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,  
  17.                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );  
  18.     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,  
  19.                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),  
  20.                        bool update=false );  
  21.     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;  
  22.   
  23.     virtual void write( CvFileStorage* storage, const char* name ) const;  
  24.     virtual void read( CvFileStorage* storage, CvFileNode* node );  
  25.   
  26. protected:  
  27.     int     var_count, var_all;  
  28.     CvMat*  var_idx;  
  29.     CvMat*  cls_labels;  
  30.     CvMat** count;  
  31.     CvMat** sum;  
  32.     CvMat** productsum;  
  33.     CvMat** avg;  
  34.     CvMat** inv_eigen_values;  
  35.     CvMat** cov_rotate_mats;  
  36.     CvMat*  c;  
  37. };  

2.示例

此类使用方法如下:(引用别人的代码,忘记出处了,非常抱歉这个。。。)

[cpp]  view plain copy print ?
  1. //openCV中贝叶斯分类器的API函数用法举例  
  2. //运行环境:win7 + VS2005 + openCV2.4.5  
  3.   
  4. #include "global_include.h"  
  5.   
  6. using namespace std;  
  7. using namespace cv;  
  8.   
  9. //10个样本特征向量维数为12的训练样本集,第一列为该样本的类别标签  
  10. double inputArr[10][13] =   
  11. {  
  12.      1,0.708333,1,1,-0.320755,-0.105023,-1,1,-0.419847,-1,-0.225806,0,1,   
  13.     -1,0.583333,-1,0.333333,-0.603774,1,-1,1,0.358779,-1,-0.483871,0,-1,  
  14.      1,0.166667,1,-0.333333,-0.433962,-0.383562,-1,-1,0.0687023,-1,-0.903226,-1,-1,  
  15.     -1,0.458333,1,1,-0.358491,-0.374429,-1,-1,-0.480916,1,-0.935484,0,-0.333333,  
  16.     -1,0.875,-1,-0.333333,-0.509434,-0.347032,-1,1,-0.236641,1,-0.935484,-1,-0.333333,  
  17.     -1,0.5,1,1,-0.509434,-0.767123,-1,-1,0.0534351,-1,-0.870968,-1,-1,  
  18.      1,0.125,1,0.333333,-0.320755,-0.406393,1,1,0.0839695,1,-0.806452,0,-0.333333,  
  19.      1,0.25,1,1,-0.698113,-0.484018,-1,1,0.0839695,1,-0.612903,0,-0.333333,  
  20.      1,0.291667,1,1,-0.132075,-0.237443,-1,1,0.51145,-1,-0.612903,0,0.333333,  
  21.      1,0.416667,-1,1,0.0566038,0.283105,-1,1,0.267176,-1,0.290323,0,1  
  22. };  
  23.   
  24. //一个测试样本的特征向量  
  25. double testArr[]=  
  26. {  
  27.     0.25,1,1,-0.226415,-0.506849,-1,-1,0.374046,-1,-0.83871,0,-1  
  28. };  
  29.   
  30.   
  31. int _tmain(int argc, _TCHAR* argv[])  
  32. {  
  33.     Mat trainData(10, 12, CV_32FC1);//构建训练样本的特征向量  
  34.     for (int i=0; i<10; i++)  
  35.     {  
  36.         for (int j=0; j<12; j++)  
  37.         {  
  38.             trainData.at<float>(i, j) = inputArr[i][j+1];  
  39.         }  
  40.     }  
  41.   
  42.     Mat trainResponse(10, 1, CV_32FC1);//构建训练样本的类别标签  
  43.     for (int i=0; i<10; i++)  
  44.     {  
  45.         trainResponse.at<float>(i, 0) = inputArr[i][0];  
  46.     }  
  47.   
  48.     CvNormalBayesClassifier nbc;  
  49.     bool trainFlag = nbc.train(trainData, trainResponse);//进行贝叶斯分类器训练  
  50.     if (trainFlag)  
  51.     {  
  52.         cout<<"train over..."<<endl;  
  53.         nbc.save("normalBayes.txt");  
  54.     }  
  55.     else  
  56.     {  
  57.         cout<<"train error..."<<endl;  
  58.         system("pause");  
  59.         exit(-1);  
  60.     }  
  61.   
  62.   
  63.     CvNormalBayesClassifier testNbc;  
  64.     testNbc.load("normalBayes.txt");  
  65.   
  66.     Mat testSample(1, 12, CV_32FC1);//构建测试样本  
  67.     for (int i=0; i<12; i++)  
  68.     {  
  69.         testSample.at<float>(0, i) = testArr[i];  
  70.     }  
  71.   
  72.     float flag = testNbc.predict(testSample);//进行测试  
  73.     cout<<"flag = "<<flag<<endl;  
  74.   
  75.     system("pause");  
  76.     return 0;  
  77. }  

3.步骤

两步走:

1.调用train函数训练分类器;

2.调用predict函数,判定测试样本的类别。

以上示例代码还延时了怎样使用save和load函数,使得训练好的分类器可以保存在文本中。

4.初始化

接下来,看CvNormalBayesClassifier类的无参数初始化:

[cpp]  view plain copy print ?
  1. CvNormalBayesClassifier::CvNormalBayesClassifier()  
  2. {  
  3.     var_count = var_all = 0;  
  4.     var_idx = 0;  
  5.     cls_labels = 0;  
  6.     count = 0;  
  7.     sum = 0;  
  8.     productsum = 0;  
  9.     avg = 0;  
  10.     inv_eigen_values = 0;  
  11.     cov_rotate_mats = 0;  
  12.     c = 0;  
  13.     default_model_name = "my_nb";  
  14. }  
还有另一种带参数的初始化形式:
[cpp]  view plain copy print ?
  1. CvNormalBayesClassifier::CvNormalBayesClassifier(  
  2.     const CvMat* _train_data, const CvMat* _responses,  
  3.     const CvMat* _var_idx, const CvMat* _sample_idx )  
  4. {  
  5.     var_count = var_all = 0;  
  6.     var_idx = 0;  
  7.     cls_labels = 0;  
  8.     count = 0;  
  9.     sum = 0;  
  10.     productsum = 0;  
  11.     avg = 0;  
  12.     inv_eigen_values = 0;  
  13.     cov_rotate_mats = 0;  
  14.     c = 0;  
  15.     default_model_name = "my_nb";  
  16.   
  17.     train( _train_data, _responses, _var_idx, _sample_idx );  
  18. }  
可见,带参数形式糅合了类的初始化和train函数。

另外,以Mat参数形式的对应函数版本,功能是一致的,只不过为了体现2.0以后版本的C++特性罢了。如下:

[cpp]  view plain copy print ?
  1. CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,  
  2.                         const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );  
  3. CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,  
  4.                    const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),  
  5.                    bool update=false );  
  6. CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;  

5.训练

下面开始分析train函数,分析CvMat格式参数的train函数,即:

[cpp]  view plain copy print ?
  1. bool train( const CvMat* trainData, const CvMat* responses,const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );  

在进入该函数之前,还要先回头看看CvNormalBayesClassifier类有哪些数据成员:

[cpp]  view plain copy print ?
  1. protected:  
  2.     int     var_count, var_all; //每个样本的特征维数、即变量数目,或者说trainData的列数目(在varIdx=0时)  
  3.     CvMat*  var_idx;        //特征子集的索引,可能特征数目为100,但是只用其中一部分训练  
  4.     CvMat*  cls_labels;     //类别数目  
  5.     CvMat** count;      //count[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的数目  
  6.     CvMat** sum;        //sum[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的累加和  
  7.     CvMat** productsum;     //productsum[0...(classNum-1)],每个元素是一个CvMat(rows=cols=var_count)指针,存储类内特征相关矩阵  
  8.     CvMat** avg;        //avg[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的平均值  
  9.     CvMat** inv_eigen_values;//inv_eigen_values[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的特征值的倒数  
  10.     CvMat** cov_rotate_mats;    //特征变量的协方差矩阵经过SVD奇异值分解后得到的特征向量矩阵  
  11.     CvMat*  c;  

这些数据成员,怎样使用呢?在train函数中见分晓:

[cpp]  view plain copy print ?
  1. bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,  
  2.                                     const CvMat* _var_idx, const CvMat* _sample_idx, bool update )  
  3. {  
  4.     const float min_variation = FLT_EPSILON;  
  5.     bool result = false;  
  6.     CvMat* responses   = 0;  
  7.     const float** train_data = 0;  
  8.     CvMat* __cls_labels = 0;  
  9.     CvMat* __var_idx = 0;  
  10.     CvMat* cov = 0;  
  11.   
  12.     CV_FUNCNAME( "CvNormalBayesClassifier::train" );  
  13.   
  14.     __BEGIN__;  
  15.   
  16.     int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;  
  17.     int s, c1, c2;  
  18.     const int* responses_data;  
  19.   
  20.     //1.整理训练数据  
  21.     CV_CALL( cvPrepareTrainData( 0,  
  22.         _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,  
  23.         _var_idx, _sample_idx, false, &train_data,  
  24.         &nsamples, &_var_count, &_var_all, &responses,  
  25.         &__cls_labels, &__var_idx ));  
  26.   
  27.     if( !update )   //如果是初始训练数据  
  28.     {  
  29.         const size_t mat_size = sizeof(CvMat*);  
  30.         size_t data_size;  
  31.   
  32.         clear();  
  33.   
  34.         var_idx = __var_idx;  
  35.         cls_labels = __cls_labels;  
  36.         __var_idx = __cls_labels = 0;  
  37.         var_count = _var_count;  
  38.         var_all = _var_all;  
  39.   
  40.         nclasses = cls_labels->cols;  
  41.         data_size = nclasses*6*mat_size;  
  42.   
  43.         CV_CALL( count = (CvMat**)cvAlloc( data_size ));  
  44.         memset( count, 0, data_size );          //count[cls]存储第cls类每个属性变量个数  
  45.                                         
  46.         sum             = count      + nclasses;//sum[cls]存储第cls类每个属性取值的累加和  
  47.         productsum      = sum        + nclasses;//productsum[cls]存储第cls类的协方差矩阵的乘积项sum(XiXj),cov(Xi,Xj)=sum(XiXj)-sum(Xi)E(Xj)  
  48.         avg             = productsum + nclasses;//avg[cls]存储第cls类的每个变量均值  
  49.         inv_eigen_values= avg        + nclasses;//inv_eigen_values[cls]存储第cls类的协方差矩阵的特征值  
  50.         cov_rotate_mats = inv_eigen_values         + nclasses;//存储第cls类的矩阵的特征值对应的特征向量  
  51.   
  52.         CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));  
  53.           
  54.         for( cls = 0; cls < nclasses; cls++ )    //对所有类别  
  55.         {  
  56.             CV_CALL(count[cls]            = cvCreateMat( 1, var_count, CV_32SC1 ));  
  57.             CV_CALL(sum[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));  
  58.             CV_CALL(productsum[cls]       = cvCreateMat( var_count, var_count, CV_64FC1 ));  
  59.             CV_CALL(avg[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));  
  60.             CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));  
  61.             CV_CALL(cov_rotate_mats[cls]  = cvCreateMat( var_count, var_count, CV_64FC1 ));  
  62.             CV_CALL(cvZero( count[cls] ));  
  63.             CV_CALL(cvZero( sum[cls] ));  
  64.             CV_CALL(cvZero( productsum[cls] ));  
  65.             CV_CALL(cvZero( avg[cls] ));  
  66.             CV_CALL(cvZero( inv_eigen_values[cls] ));  
  67.             CV_CALL(cvZero( cov_rotate_mats[cls] ));  
  68.         }  
  69.     }  
  70.     else    //如果是更新训练数据  
  71.     {  
  72.         // check that the new training data has the same dimensionality etc.  
  73.         if( _var_count != var_count || _var_all != var_all || !((!_var_idx && !var_idx) ||  
  74.             (_var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON)) )  
  75.             CV_ERROR( CV_StsBadArg,  
  76.             "The new training data is inconsistent with the original training data" );  
  77.   
  78.         if( cls_labels->cols != __cls_labels->cols ||  
  79.             cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )  
  80.             CV_ERROR( CV_StsNotImplemented,  
  81.             "In the current implementation the new training data must have absolutely "  
  82.             "the same set of class labels as used in the original training data" );  
  83.   
  84.         nclasses = cls_labels->cols;  
  85.     }  
  86.   
  87.     responses_data = responses->data.i;  
  88.     CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));  
  89.   
  90.     //2.处理训练数据,计算每一类的  
  91.     // process train data (count, sum , productsum)   
  92.     for( s = 0; s < nsamples; s++ )  
  93.     {  
  94.         cls = responses_data[s];  
  95.         int* count_data = count[cls]->data.i;  
  96.         double* sum_data = sum[cls]->data.db;  
  97.         double* prod_data = productsum[cls]->data.db;  
  98.         const float* train_vec = train_data[s];  
  99.   
  100.         for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )  
  101.         {  
  102.             double val1 = train_vec[c1];  
  103.             sum_data[c1] += val1;  
  104.             count_data[c1]++;  
  105.             for( c2 = c1; c2 < _var_count; c2++ )  
  106.                 prod_data[c2] += train_vec[c2]*val1;  
  107.         }  
  108.     }  
  109.   
  110.     //计算每一类的每个属性平均值、协方差矩阵  
  111.     // calculate avg, covariance matrix, c  
  112.     for( cls = 0; cls < nclasses; cls++ )    //对每一类  
  113.     {  
  114.         double det = 1;  
  115.         int i, j;  
  116.         CvMat* w = inv_eigen_values[cls];  
  117.         int* count_data = count[cls]->data.i;  
  118.         double* avg_data = avg[cls]->data.db;  
  119.         double* sum1 = sum[cls]->data.db;  
  120.   
  121.         cvCompleteSymm( productsum[cls], 0 );  
  122.   
  123.         for( j = 0; j < _var_count; j++ )    //计算当前类别cls的每个变量属性值的平均值  
  124.         {  
  125.             int n = count_data[j];  
  126.             avg_data[j] = n ? sum1[j] / n : 0.;  
  127.         }  
  128.   
  129.         count_data = count[cls]->data.i;  
  130.         avg_data = avg[cls]->data.db;  
  131.         sum1 = sum[cls]->data.db;  
  132.   
  133.         for( i = 0; i < _var_count; i++ )//计算当前类别cls的变量协方差矩阵,矩阵大小为_var_count * _var_count,注意协方差矩阵对称。  
  134.         {  
  135.             double* avg2_data = avg[cls]->data.db;  
  136.             double* sum2 = sum[cls]->data.db;  
  137.             double* prod_data = productsum[cls]->data.db + i*_var_count;  
  138.             double* cov_data = cov->data.db + i*_var_count;  
  139.             double s1val = sum1[i];  
  140.             double avg1 = avg_data[i];  
  141.             int _count = count_data[i];  
  142.   
  143.             for( j = 0; j <= i; j++ )  
  144.             {  
  145.                 double avg2 = avg2_data[j];  
  146.                 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * _count;  
  147.                 cov_val = (_count > 1) ? cov_val / (_count - 1) : cov_val;  
  148.                 cov_data[j] = cov_val;  
  149.             }  
  150.         }  
  151.   
  152.         CV_CALL( cvCompleteSymm( cov, 1 ));  
  153.         CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));  
  154.         CV_CALL( cvMaxS( w, min_variation, w ));  
  155.         for( j = 0; j < _var_count; j++ )  
  156.             det *= w->data.db[j];  
  157.   
  158.         CV_CALL( cvDiv( NULL, w, w ));  
  159.         c->data.db[cls] = det > 0 ? log(det) : -700;  
  160.     }  
  161.   
  162.     result = true;  
  163.   
  164.     __END__;  
  165.   
  166.     if( !result || cvGetErrStatus() < 0 )  
  167.         clear();  
  168.   
  169.     cvReleaseMat( &cov );  
  170.     cvReleaseMat( &__cls_labels );  
  171.     cvReleaseMat( &__var_idx );  
  172.     cvFree( &train_data );  
  173.   
  174.     return result;  
  175. }  
训练部分就此完成。

6.预测

下面看用于预测的predict函数的实现代码:

[cpp]  view plain copy print ?
  1. float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const  
  2. {  
  3.     float value = 0;  
  4.   
  5.     if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )  
  6.         CV_Error( CV_StsBadArg,  
  7.         "The input samples must be 32f matrix with the number of columns = var_all" );  
  8.   
  9.     if( samples->rows > 1 && !results )  
  10.         CV_Error( CV_StsNullPtr,  
  11.         "When the number of input samples is >1, the output vector of results must be passed" );  
  12.   
  13.     if( results )  
  14.     {  
  15.         if( !CV_IS_MAT(results) || (CV_MAT_TYPE(results->type) != CV_32FC1 &&  
  16.         CV_MAT_TYPE(results->type) != CV_32SC1) ||  
  17.         (results->cols != 1 && results->rows != 1) ||  
  18.         results->cols + results->rows - 1 != samples->rows )  
  19.         CV_Error( CV_StsBadArg, "The output array must be integer or floating-point vector "  
  20.         "with the number of elements = number of rows in the input matrix" );  
  21.     }  
  22.   
  23.     const int* vidx = var_idx ? var_idx->data.i : 0;  
  24.   
  25.     cv::parallel_for(cv::BlockedRange(0, samples->rows), predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,  
  26.                                                                       vidx, cls_labels, results, &value, var_count  
  27.     ));  
  28.   
  29.     return value;  
  30. }  
可以发现,预测部分核心代码是:
[cpp]  view plain copy print ?
  1. cv::parallel_for(cv::BlockedRange(0, samples->rows), predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,  
  2.                                                                       vidx, cls_labels, results, &value, var_count));  
parallel_for是用于并行支持的,可能会调用tbb模块。predict_body则是一个结构体,内部的()符号被重载,实现预测功能。其完整定义如下:

[cpp]  view plain copy print ?
  1. //predict函数调用predict_body结构体的()符号重载函数,实现基于贝叶斯的分类  
  2. struct predict_body   
  3. {  
  4.     predict_body(CvMat* _c, CvMat** _cov_rotate_mats, CvMat** _inv_eigen_values, CvMat** _avg,  
  5.                 const CvMat* _samples, const int* _vidx, CvMat* _cls_labels,  
  6.                 CvMat* _results, float* _value, int _var_count1)  
  7.     {  
  8.         c = _c;  
  9.         cov_rotate_mats = _cov_rotate_mats;  
  10.         inv_eigen_values = _inv_eigen_values;  
  11.         avg = _avg;  
  12.         samples = _samples;  
  13.         vidx = _vidx;  
  14.         cls_labels = _cls_labels;  
  15.         results = _results;  
  16.         value = _value;  
  17.         var_count1 = _var_count1;  
  18.     }  
  19.   
  20.     CvMat* c;  
  21.     CvMat** cov_rotate_mats;  
  22.     CvMat** inv_eigen_values;  
  23.     CvMat** avg;  
  24.     const CvMat* samples;  
  25.     const int* vidx;  
  26.     CvMat* cls_labels;  
  27.   
  28.     CvMat* results;  
  29.     float* value;  
  30.     int var_count1;  
  31.   
  32.     void operator()( const cv::BlockedRange& range ) const  
  33.     {  
  34.   
  35.         int cls = -1;  
  36.         int rtype = 0, rstep = 0;  
  37.         int nclasses = cls_labels->cols;  
  38.         int _var_count = avg[0]->cols;  
  39.   
  40.         if (results)  
  41.         {  
  42.             rtype = CV_MAT_TYPE(results->type);  
  43.             rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);  
  44.         }  
  45.         // allocate memory and initializing headers for calculating  
  46.         cv::AutoBuffer<double> buffer(nclasses + var_count1);  
  47.         CvMat diff = cvMat( 1, var_count1, CV_64FC1, &buffer[0] );  
  48.   
  49.         for(int k = range.begin(); k < range.end(); k += 1 )//对于每个输入测试样本  
  50.         {  
  51.             int ival;  
  52.             double opt = FLT_MAX;  
  53.   
  54.             for(int i = 0; i < nclasses; i++ )   //对于每一类别,计算其似然概率  
  55.             {  
  56.   
  57.                 double cur = c->data.db[i];  
  58.                 CvMat* u = cov_rotate_mats[i];  
  59.                 CvMat* w = inv_eigen_values[i];  
  60.   
  61.                 const double* avg_data = avg[i]->data.db;  
  62.                 const float* x = (const float*)(samples->data.ptr + samples->step*k);  
  63.   
  64.                 // cov = u w u'  -->  cov^(-1) = u w^(-1) u'  
  65.                 for(int j = 0; j < _var_count; j++ ) //计算特征相对于均值的偏移  
  66.                     diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];  
  67.   
  68.                 cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );  
  69.                 for(int j = 0; j < _var_count; j++ )//计算特征的联合概率  
  70.                 {  
  71.                     double d = diff.data.db[j];  
  72.                     cur += d*d*w->data.db[j];  
  73.                 }  
  74.   
  75.                 if( cur < opt )  //找到分类概率最大的  
  76.                 {  
  77.                     cls = i;  
  78.                     opt = cur;  
  79.                 }  
  80.                 // probability = exp( -0.5 * cur )   
  81.   
  82.             }//for(int i = 0; i < nclasses; i++ )  
  83.   
  84.             ival = cls_labels->data.i[cls];  
  85.             if( results )  
  86.             {  
  87.                 if( rtype == CV_32SC1 )  
  88.                     results->data.i[k*rstep] = ival;  
  89.                 else  
  90.                     results->data.fl[k*rstep] = (float)ival;  
  91.             }  
  92.             if( k == 0 )  
  93.                 *value = (float)ival;  
  94.   
  95.         }//for(int k = range.begin()...  
  96.   
  97.     }//void operator()...  
  98. };  

你可能感兴趣的:(Normal Bayes 分类器过程详解)