Mahout版本:0.7,hadoop版本:1.0.4,jdk:1.7.0_25 64bit。
接上篇,分析到OptIgSplitl类的computeSplit函数里面的numbericalSplit函数,看这个函数的输入参数data和attr,应该是针对data计算出一个和attr相关的值而已。往下看
double[] values = sortedValues(data, attr); ,这一句是干啥的?
private static double[] sortedValues(Data data, int attr) { double[] values = data.values(attr); Arrays.sort(values); return values; }sortedValues就是把data中第attr个属性的值全部取出来,然后排个序(attr从0开始)。比如这次debug得到的三个属性s[5,1,4],第5个属性的值全部排序后得到的值如下:
[0.0, 0.02, 0.03, 0.04, 0.05, 0.06, 0.08, 0.09, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.18, 0.19, 0.23, 0.31, 0.32, 0.33, 0.35, 0.37, 0.38, 0.39, 0.44, 0.45, 0.47, 0.48, 0.49, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.72, 0.73, 0.76, 0.81, 0.97, 1.1, 1.46, 1.68, 1.76, 2.7, 6.21]这里一共有59个值,为什么只有在data.values()函数中把相同的值都去除了,采用的HashSet存储的,这个可以在Data类的第193行看到。
然后到了initCounts(data, values);这个就是初始化三个数组的函数,具体代码如下:
void initCounts(Data data, double[] values) { counts = new int[values.length][data.getDataset().nblabels()]; countAll = new int[data.getDataset().nblabels()]; countLess = new int[data.getDataset().nblabels()]; }这里values.length就是59,data.getDataset().nblabels()就是6;
然后到了computeFrequencies(data, attr, values);这个函数主要计算了两个数组:counts[i][j],其中i表示0-58之间的一个数字,j表示0-5之间的一个数字。因为前面的214个数据删除了重复的才变为了59个,所以原始数据里面肯定是有重复的,这里就是计算这些重复的且它的label值要是一样的,比如原始数据如下(只写一列,因为这里只取了一列):
0.3 0
0.3 0
0.3 1
0.4 1
0.5 2
那么values=[0.3,0.4,0.5],counts[0][0]=2,counts[0][1]=1,counts[1][1]=1,counts[2][2]=1,其他counts的值都为0;countAll就是label的值分别相加,countAll[0]=2,countAll[1]=2,countAll[2]=1;
void computeFrequencies(Data data, int attr, double[] values) { Dataset dataset = data.getDataset(); for (int index = 0; index < data.size(); index++) { Instance instance = data.get(index); counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++; countAll[(int) dataset.getLabel(instance)]++; } }比如上面的glass.data数据在attr是5的情况下得到的counts(size为59)和countAll(size为6)如下:
继续往下:int size = data.size(); double hy = entropy(countAll, size);这里的size就是214,entropy是啥来的?
private static double entropy(int[] counts, int dataSize) { if (dataSize == 0) { return 0.0; } double entropy = 0.0; double invDataSize = 1.0 / dataSize; for (int count : counts) { if (count == 0) { continue; // otherwise we get a NaN } double p = count * invDataSize; entropy += -p * Math.log(p) / LOG2; } return entropy; }这个好像叫做熵的?看下它是如何计算的:
下面的公式中pi是每个label的重复值除以总数214的结果。
在继续往下面看:
double invDataSize = 1.0 / size; int best = -1; double bestIg = -1.0; // try each possible split value for (int index = 0; index < values.length; index++) { double ig = hy; // instance with attribute value < values[index] size = DataUtils.sum(countLess); ig -= size * invDataSize * entropy(countLess, size); // instance with attribute value >= values[index] size = DataUtils.sum(countAll); ig -= size * invDataSize * entropy(countAll, size); if (ig > bestIg) { bestIg = ig; best = index; } DataUtils.add(countLess, counts[index]); DataUtils.dec(countAll, counts[index]); }上面算到的hy是2.110138986672679。进入for循环,size = DataUtils.sum(countLess);由于第一次countLess全部值为0,所以size也为0,ig=ig-0;然后size = DataUtils.sum(countAll);这个size值为214;然后就是ig -= size * invDataSize * entropy(countAll, size); size*invDataSize不是1么,entropy(countAll,size)不就是hy么?ig前面不是把hy的值赋给它了么?所以ig=ig-hy=ig-ig=0?然后debug得出的答案是:4.440892098500626E-16。尼玛 ,还10的负十六次方。
然后是最后两行,这个是什么意思?运行前:
运行最后两行代码后,变为:
这样就不用我多说了吧,等于是把counts里面的第一条记录加到countLess中,然后再把countAll中相应的次数减去第一条记录。
下面的就是按照这种规律循环遍历最后得到一个 attr --> bestIg 、bestIndex 的对应关系,然后输出 return new Split(attr, bestIg, values[best]);
今晚就到这里吧,不要熬夜。。。
感觉好像更新小说的样子。。。最近看小说太不给力了,卤煮跟新好慢。不过写这个算法更新更慢。哎,看源码。。。
分享,成长,快乐
转载请注明blog地址:http://blog.csdn.net/fansy1990