libsvm代码阅读:关于svm_group_classes函数分析

目前libsvm最新的version是3.17,主要的改变是在svm_group_classes函数中加了几行代码。官方的说明如下:

Version 3.17 released on April Fools' day, 2013. We slightly adjust the way class labels are handled internally. By default labels are ordered by their first occurrence in the training set. Hence for a set with -1/+1 labels, if -1 appears first, then internally -1 becomes +1. This has caused confusion. Now for data with -1/+1 labels, we specifically ensure that internally the binary SVM has positive data corresponding to the +1 instances. For developers, see changes in the subrouting svm_group_classes of svm.cpp. 

本文就对这个函数进行分析:

svm_group_classes函数的功能是:group training data of the same class

Important:如何将一堆数据归类到一起,同类的连续存储!可参考这个函数。

函数原型如下:

[cpp]   view plain copy 在CODE上查看代码片
<EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)  

主要的输入是prob这个指针,它指向svm_group_classes将要处理的样本数据集,另外几个形参是指针类型,可以相当于输出数据,其中:

  1. nr_class_ret——统计得出样本集的类别总数
  2. label_ret——指向存储类别标号的数组
  3. start_ret——指向存储每个类别的起始位置的数组
  4. count_tet——指向存储每个类别的样本个数的数组
  5. perm——指向原始数据的索引数组
下面,先看一部分代码,这部分代码中的for循环的功能:统计类别总数、将相应的相同类别y[i]赋到相应的label,并统计各个类别的样本数量count。
设一个例子:{ 有6个样本,总共4类,其中y[0]=y[1],y[2]=y[3],y[4],y[5] },则for循环的运行过程如下所示:
i=0  label[0]=y[0],           data_label[0]=0
i=1  label[0]=y[0]=y[1],   data_label[1]=0 count[0]=2
i=2  label[1]=y[2],           data_label[2]=1
i=3  label[1]=y[2]=y[3],   data_label[3]=1 count[1]=2
i=4  label[2]=y[4],           data_label[2]=2 count[2]=1
i=5  label[3]=y[5],           data_label[2]=3 count[3]=1

[cpp]   view plain copy 在CODE上查看代码片
<EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. // label: label name, start: begin of each class, count: #data of classes, perm: indices to the original data  
  2. // perm, length l, must be allocated before calling this subroutine  
  3. static void svm_group_classes(const svm_problem *prob, int *nr_class_ret, int **label_ret, int **start_ret, int **count_ret, int *perm)  
  4. {  
  5.     int l = prob->l;//样本总数  
  6.     int max_nr_class = 16;//不够的话,自动增长为原来的两倍(见下文)  
  7.     int nr_class = 0;  
  8.     int *label = Malloc(int,max_nr_class);//Malloc(type,n) (type *)malloc((n)*sizeof(type))  
  9.     int *count = Malloc(int,max_nr_class);  
  10.     int *data_label = Malloc(int,l);      
  11.     int i;  
  12.   
  13.     for(i=0;i<l;i++)  
  14.     {  
  15.         int this_label = (int)prob->y[i];//将类别赋给this_label  
  16.         int j;  
  17.         for(j=0;j<nr_class;j++)  
  18.         {  
  19.             if(this_label == label[j])//虽然刚开始label里面没值,但是第一步循环本内层也没有被运行  
  20.             {  
  21.                 ++count[j];  
  22.                 break;  
  23.             }  
  24.         }  
  25.         data_label[i] = j;  
  26.         if(j == nr_class)  
  27.         {  
  28.             if(nr_class == max_nr_class)  
  29.             {  
  30.                 max_nr_class *= 2;//扩大最大类别数  
  31.                 label = (int *)realloc(label,max_nr_class*sizeof(int));  
  32.                 count = (int *)realloc(count,max_nr_class*sizeof(int));  
  33.             }  
  34.             label[nr_class] = this_label;  
  35.             count[nr_class] = 1;//这个是1  
  36.             ++nr_class;  
  37.         }  
  38.     }  


本version更新部分:本部分主要是处理二类分类,当第一个出现的是-1时,负责把-1和+1的数据对调。

[cpp]   view plain copy 在CODE上查看代码片
<EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. //  
  2. // Labels are ordered by their first occurrence in the training set.   
  3. // However, for two-class sets with -1/+1 labels and -1 appears first,   
  4. // we swap labels to ensure that internally the binary SVM has positive data corresponding to the +1 instances.  
  5. //  
  6. if (nr_class == 2 && label[0] == -1 && label[1] == 1)  
  7. {  
  8.     swap(label[0],label[1]);  
  9.     swap(count[0],count[1]);  
  10.     for(i=0;i<l;i++)  
  11.     {  
  12.         if(data_label[i] == 0)  
  13.             data_label[i] = 1;  
  14.         else  
  15.             data_label[i] = 0;  
  16.     }  
  17. }  


下面这一部分代码是用来计算每个类别的起始位置start、以及各个样本分类后的在原始数据中的索引位置perm数组。其中perm[i]=j: i表示当前同类样本位置,j表示原始数据位置。

Important:如何将一堆数据归类到一起,同类的连续存储!可参考这个函数。

[cpp]   view plain copy 在CODE上查看代码片
<EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
  1. int *start = Malloc(int,nr_class);  
  2. start[0] = 0;  
  3. for(i=1;i<nr_class;i++)  
  4.     start[i] = start[i-1]+count[i-1];  
  5. for(i=0;i<l;i++)  
  6. {  
  7.     perm[start[data_label[i]]] = i;  
  8.     ++start[data_label[i]];  
  9. }  
  10. start[0] = 0;  
  11. for(i=1;i<nr_class;i++)  
  12.     start[i] = start[i-1]+count[i-1];  
  13.   
  14. *nr_class_ret = nr_class;  
  15. *label_ret = label;  
  16. *start_ret = start;  
  17. *count_ret = count;  
  18. free(data_lab

你可能感兴趣的:(libsvm代码阅读:关于svm_group_classes函数分析)