libsvm代码阅读:关于svm_train函数分析

在svm中,训练是一个十分重要的步骤,下面我们来看看svm的train部分。

在libsvm中的svm_train中分别有回归和分类两部分,我只对其中分类做介绍。

分类的步骤如下:

  • 统计类别总数,同时记录类别的标号,统计每个类的样本数目
  • 将属于相同类的样本分组,连续存放
  • 计算权重C
  • 训练n(n-1)/2 个模型
    • 初始化nozero数组,便于统计SV
    • //初始化概率数组
    • 训练过程中,需要重建子数据集,样本的特征不变,但样本的类别要改为+1/-1
    • //如有必要,先调用svm_binary_svc_probability
    • 训练子数据集svm_train_one
    • 统计一下nozero,如果nozero已经是真,就不变,否则改为真
  • 输出模型
    • 主要是填充svm_model
  • 清除内存

函数中调用过程如下:

svm_train-->svm_train_one-->solve_c_svc(for example)-->s.Solve

[cpp]   view plain copy 在CODE上查看代码片
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. //  
  2. // Interface functions  
  3. //重点函数:svm训练函数  
  4. //根据选择的算法,来组织参加训练的分样本,以及进行训练结果的保存。其中会对样本进行初步的统计。  
  5. svm_model *svm_train(const svm_problem *prob, const svm_parameter *param)  
  6. {  
  7.     svm_model *model = Malloc(svm_model,1);//#define Malloc(type,n) (type *)malloc((n)*sizeof(type))  
  8.     model->param = *param;  
  9.     model->free_sv = 0;  // XXX  
  10.   
  11.     if(param->svm_type == ONE_CLASS ||  
  12.        param->svm_type == EPSILON_SVR ||  
  13.        param->svm_type == NU_SVR)  
  14.     {  
  15.         // regression or one-class-svm  
  16.         model->nr_class = 2;  
  17.         model->label = NULL;  
  18.         model->nSV = NULL;  
  19.         model->probA = NULL; model->probB = NULL;  
  20.         model->sv_coef = Malloc(double *,1);  
  21.   
  22.         if(param->probability &&   
  23.            (param->svm_type == EPSILON_SVR ||  
  24.             param->svm_type == NU_SVR))  
  25.         {  
  26.             model->probA = Malloc(double,1);  
  27.             model->probA[0] = svm_svr_probability(prob,param);  
  28.         }  
  29.   
  30.         decision_function f = svm_train_one(prob,param,0,0);  
  31.         model->rho = Malloc(double,1);  
  32.         model->rho[0] = f.rho;  
  33.   
  34.         int nSV = 0;  
  35.         int i;  
  36.         for(i=0;i<prob->l;i++)  
  37.             if(fabs(f.alpha[i]) > 0) ++nSV;  
  38.         model->l = nSV;  
  39.         model->SV = Malloc(svm_node *,nSV);  
  40.         model->sv_coef[0] = Malloc(double,nSV);  
  41.         model->sv_indices = Malloc(int,nSV);  
  42.         int j = 0;  
  43.         for(i=0;i<prob->l;i++)  
  44.             if(fabs(f.alpha[i]) > 0)  
  45.             {  
  46.                 model->SV[j] = prob->x[i];  
  47.                 model->sv_coef[0][j] = f.alpha[i];  
  48.                 model->sv_indices[j] = i+1;  
  49.                 ++j;  
  50.             }         
  51.         free(f.alpha);  
  52.     }  
  53.     else  
  54.     {  
  55.         // classification  
  56.         int l = prob->l;  
  57.         int nr_class;  
  58.         int *label = NULL;  
  59.         int *start = NULL;  
  60.         int *count = NULL;  
  61.         int *perm = Malloc(int,l);  
  62.   
  63.         // group training data of the same class对训练样本进行处理,同类整合到一起  
  64.         svm_group_classes(prob,&nr_class,&label,&start,&count,perm);  
  65.         if(nr_class == 1)   
  66.             info("WARNING: training data in only one class. See README for details.\n");  
  67.           
  68.         svm_node **x = Malloc(svm_node *,l);  
  69.         int i;  
  70.         for(i=0;i<l;i++)  
  71.             x[i] = prob->x[perm[i]];  
  72.   
  73.         // calculate weighted C  
  74.   
  75.         double *weighted_C = Malloc(double, nr_class);  
  76.         for(i=0;i<nr_class;i++)  
  77.             weighted_C[i] = param->C;  
  78.         for(i=0;i<param->nr_weight;i++)  
  79.         {     
  80.             int j;  
  81.             for(j=0;j<nr_class;j++)  
  82.                 if(param->weight_label[i] == label[j])  
  83.                     break;  
  84.             if(j == nr_class)  
  85.                 fprintf(stderr,"WARNING: class label %d specified in weight is not found\n", param->weight_label[i]);  
  86.             else  
  87.                 weighted_C[j] *= param->weight[i];  
  88.         }  
  89.   
  90.         // train k*(k-1)/2 models  
  91.           
  92.         bool *nonzero = Malloc(bool,l);  
  93.         for(i=0;i<l;i++)  
  94.             nonzero[i] = false;  
  95.         decision_function *f = Malloc(decision_function,nr_class*(nr_class-1)/2);  
  96.   
  97.         double *probA=NULL,*probB=NULL;  
  98.         if (param->probability)  
  99.         {  
  100.             probA=Malloc(double,nr_class*(nr_class-1)/2);  
  101.             probB=Malloc(double,nr_class*(nr_class-1)/2);  
  102.         }  
  103.   
  104.         int p = 0;  
  105.         for(i=0;i<nr_class;i++)  
  106.             for(int j=i+1;j<nr_class;j++)  
  107.             {  
  108.                 svm_problem sub_prob;  
  109.                 int si = start[i], sj = start[j];  
  110.                 int ci = count[i], cj = count[j];  
  111.                 sub_prob.l = ci+cj;  
  112.                 sub_prob.x = Malloc(svm_node *,sub_prob.l);  
  113.                 sub_prob.y = Malloc(double,sub_prob.l);  
  114.                 int k;  
  115.                 for(k=0;k<ci;k++)  
  116.                 {  
  117.                     sub_prob.x[k] = x[si+k];  
  118.                     sub_prob.y[k] = +1;  
  119.                 }  
  120.                 for(k=0;k<cj;k++)  
  121.                 {  
  122.                     sub_prob.x[ci+k] = x[sj+k];  
  123.                     sub_prob.y[ci+k] = -1;  
  124.                 }  
  125.   
  126.                 if(param->probability)  
  127.                     svm_binary_svc_probability(&sub_prob,param,weighted_C[i],weighted_C[j],probA[p],probB[p]);  
  128.   
  129.                 f[p] = svm_train_one(&sub_prob,param,weighted_C[i],weighted_C[j]);  
  130.                 for(k=0;k<ci;k++)  
  131.                     if(!nonzero[si+k] && fabs(f[p].alpha[k]) > 0)  
  132.                         nonzero[si+k] = true;  
  133.                 for(k=0;k<cj;k++)  
  134.                     if(!nonzero[sj+k] && fabs(f[p].alpha[ci+k]) > 0)  
  135.                         nonzero[sj+k] = true;  
  136.                 free(sub_prob.x);  
  137.                 free(sub_prob.y);  
  138.                 ++p;  
  139.             }  
  140.   
  141.         // build output  
  142.   
  143.         model->nr_class = nr_class;  
  144.           
  145.         model->label = Malloc(int,nr_class);  
  146.         for(i=0;i<nr_class;i++)  
  147.             model->label[i] = label[i];  
  148.           
  149.         model->rho = Malloc(double,nr_class*(nr_class-1)/2);  
  150.         for(i=0;i<nr_class*(nr_class-1)/2;i++)  
  151.             model->rho[i] = f[i].rho;  
  152.   
  153.         if(param->probability)  
  154.         {  
  155.             model->probA = Malloc(double,nr_class*(nr_class-1)/2);  
  156.             model->probB = Malloc(double,nr_class*(nr_class-1)/2);  
  157.             for(i=0;i<nr_class*(nr_class-1)/2;i++)  
  158.             {  
  159.                 model->probA[i] = probA[i];  
  160.                 model->probB[i] = probB[i];  
  161.             }  
  162.         }  
  163.         else  
  164.         {  
  165.             model->probA=NULL;  
  166.             model->probB=NULL;  
  167.         }  
  168.   
  169.         int total_sv = 0;  
  170.         int *nz_count = Malloc(int,nr_class);  
  171.         model->nSV = Malloc(int,nr_class);  
  172.         for(i=0;i<nr_class;i++)  
  173.         {  
  174.             int nSV = 0;  
  175.             for(int j=0;j<count[i];j++)  
  176.                 if(nonzero[start[i]+j])  
  177.                 {     
  178.                     ++nSV;  
  179.                     ++total_sv;  
  180.                 }  
  181.             model->nSV[i] = nSV;  
  182.             nz_count[i] = nSV;  
  183.         }  
  184.           
  185.         info("Total nSV = %d\n",total_sv);  
  186.   
  187.         model->l = total_sv;  
  188.         model->SV = Malloc(svm_node *,total_sv);  
  189.         model->sv_indices = Malloc(int,total_sv);  
  190.         p = 0;  
  191.         for(i=0;i<l;i++)  
  192.             if(nonzero[i])  
  193.             {  
  194.                 model->SV[p] = x[i];  
  195.                 model->sv_indices[p++] = perm[i] + 1;  
  196.             }  
  197.   
  198.         int *nz_start = Malloc(int,nr_class);  
  199.         nz_start[0] = 0;  
  200.         for(i=1;i<nr_class;i++)  
  201.             nz_start[i] = nz_start[i-1]+nz_count[i-1];  
  202.   
  203.         model->sv_coef = Malloc(double *,nr_class-1);  
  204.         for(i=0;i<nr_class-1;i++)  
  205.             model->sv_coef[i] = Malloc(double,total_sv);  
  206.   
  207.         p = 0;  
  208.         for(i=0;i<nr_class;i++)  
  209.             for(int j=i+1;j<nr_class;j++)  
  210.             {  
  211.                 // classifier (i,j): coefficients with  
  212.                 // i are in sv_coef[j-1][nz_start[i]...],  
  213.                 // j are in sv_coef[i][nz_start[j]...]  
  214.   
  215.                 int si = start[i];  
  216.                 int sj = start[j];  
  217.                 int ci = count[i];  
  218.                 int cj = count[j];  
  219.                   
  220.                 int q = nz_start[i];  
  221.                 int k;  
  222.                 for(k=0;k<ci;k++)  
  223.                     if(nonzero[si+k])  
  224.                         model->sv_coef[j-1][q++] = f[p].alpha[k];  
  225.                 q = nz_start[j];  
  226.                 for(k=0;k<cj;k++)  
  227.                     if(nonzero[sj+k])  
  228.                         model->sv_coef[i][q++] = f[p].alpha[ci+k];  
  229.                 ++p;  
  230.             }  
  231.           
  232.         free(label);  
  233.         free(probA);  
  234.         free(probB);  
  235.         free(count);  
  236.         free(perm);  
  237.         free(start);  
  238.         free(x);  
  239.         free(weighted_C);  
  240.         free(nonzero);  
  241.         for(i=0;i<nr_class*(nr_class-1)/2;i++)  
  242.             free(f[i].alpha);  
  243.         free(f);  
  244.         free(nz_count);  
  245.         free(nz_start);  
  246.     }  
  247.     return model;  
  248. }  

你可能感兴趣的:(libsvm代码阅读:关于svm_train函数分析)