1. Algorithm
I really could not find much material on the REPTree algorithm itself. It may be a decision-tree variant that Weka invented, or an alias for some other decision-tree method. According to the class's Javadoc: "Fast decision tree learner. Builds a decision/regression tree using information gain/variance and prunes it using reduced-error pruning (with backfitting). Only sorts values for numeric attributes once. Missing values are dealt with by splitting the corresponding instances into pieces (i.e. as in C4.5)."
From this we can roughly tell that, compared with C4.5, REPTree adds a backfitting step; it sorts numeric attribute values only once (recall that J48, Weka's C4.5 implementation, re-sorts within every data subset); and it handles missing values the same way C4.5 does: the instance is sent down the different paths in pieces and the results are combined by weight.
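To make that last point concrete: when an instance reaches a node that splits on an attribute whose value it is missing, the instance goes down every branch with a weight proportional to the fraction of training weight that followed that branch, and the subtree predictions are merged accordingly. Below is a minimal illustrative sketch of the merge step (the helper name and arrays are hypothetical, not Weka's actual code; props corresponds to what the tree stores as m_Prop):

// Hypothetical helper, for illustration only: combine the class
// distributions returned by each subtree for an instance whose value of
// the split attribute is missing, weighting each branch by the training
// proportion that went down it.
static double[] combineMissing(double[][] subtreeDists, double[] props) {
    double[] combined = new double[subtreeDists[0].length];
    for (int b = 0; b < subtreeDists.length; b++) { // each branch
        for (int c = 0; c < combined.length; c++) { // each class
            combined[c] += props[b] * subtreeDists[b][c];
        }
    }
    return combined;
}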
具体和C4.5的比较将在代码分析之后给出一个总结。
2. buildClassifier
This is the "famous" main entry point for training a classifier; almost every write-up analyzing a classifier's source code starts from this method.
public void buildClassifier(Instances data) throws Exception {

    // Routine check that the given dataset can be handled;
    // REPTree supports nearly all attribute types
    getCapabilities().testWithFail(data);

    // Remove instances with a missing class value: they can be used
    // neither for training nor for backfitting
    data = new Instances(data);
    data.deleteWithMissingClass();

    Random random = new Random(m_Seed);

    m_zeroR = null;
    if (data.numAttributes() == 1) {
      // With a single attribute, that attribute must be the class itself,
      // so no real classifier can be trained; fall back to the most basic
      // model, ZeroR (its classification method was covered in the
      // previous post)
      m_zeroR = new ZeroR();
      m_zeroR.buildClassifier(data);
      return;
    }

    // Randomize and stratify
    data.randomize(random); // shuffle the instances
    if (data.classAttribute().isNominal()) {
      // For a nominal class, also stratify so that each fold gets
      // roughly the same class distribution
      data.stratify(m_NumFolds);
    }

    // If pruning is enabled, split the data into a train set and a prune
    // set; otherwise the train set alone suffices
    Instances train = null;
    Instances prune = null;
    if (!m_NoPruning) {
      // A cross-validation style split yields the train and prune sets
      train = data.trainCV(m_NumFolds, 0, random);
      prune = data.testCV(m_NumFolds, 0);
    } else {
      train = data;
    }

    // Two helper arrays. The first dimension is meaningless (a 3-D array
    // used as a 2-D one), the second indexes attributes, and the third
    // holds the sorted instance indices (order statistics)
    int[][][] sortedIndices = new int[1][train.numAttributes()][0]; // instance indices
    double[][][] weights = new double[1][train.numAttributes()][0]; // weights of those instances
    double[] vals = new double[train.numInstances()]; // scratch array used for sorting
    for (int j = 0; j < train.numAttributes(); j++) {
      if (j != train.classIndex()) {
        weights[0][j] = new double[train.numInstances()];
        if (train.attribute(j).isNominal()) {

          // For a nominal attribute, the "sorting" simply moves
          // instances with missing values to the end
          sortedIndices[0][j] = new int[train.numInstances()];
          int count = 0;
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (!inst.isMissing(j)) {
              sortedIndices[0][j][count] = i;
              weights[0][j][count] = inst.weight();
              count++;
            }
          }
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            if (inst.isMissing(j)) {
              sortedIndices[0][j][count] = i;
              weights[0][j][count] = inst.weight();
              count++;
            }
          }
        } else {

          // A numeric attribute gets a real sort (done once, here)
          for (int i = 0; i < train.numInstances(); i++) {
            Instance inst = train.instance(i);
            vals[i] = inst.value(j);
          }
          sortedIndices[0][j] = Utils.sort(vals);
          for (int i = 0; i < train.numInstances(); i++) {
            weights[0][j][i] = train.instance(sortedIndices[0][j][i]).weight();
          }
        }
      }
    }

    // Compute the class distribution of the training set
    double[] classProbs = new double[train.numClasses()];
    double totalWeight = 0, totalSumSquared = 0;
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      if (data.classAttribute().isNominal()) {
        // Nominal class: simple weighted counting
        classProbs[(int) inst.classValue()] += inst.weight();
        totalWeight += inst.weight();
      } else {
        // Numeric class: accumulate the weighted sum (the mean is taken
        // later) and the weighted sum of squares (for the variance)
        classProbs[0] += inst.classValue() * inst.weight();
        totalSumSquared += inst.classValue() * inst.classValue() * inst.weight();
        totalWeight += inst.weight();
      }
    }

    m_Tree = new Tree(); // create the root node of the tree
    double trainVariance = 0; // variance of the training set's class values
    if (data.classAttribute().isNumeric()) {
      trainVariance = m_Tree.singleVariance(classProbs[0], totalSumSquared,
        totalWeight) / totalWeight;
      classProbs[0] /= totalWeight; // turn the accumulated sum into the mean
    }

    // Build the actual tree (quite a long parameter list)
    m_Tree.buildTree(sortedIndices, weights, train, totalWeight, classProbs,
      new Instances(train, 0), m_MinNum, m_MinVarianceProp * trainVariance, 0,
      m_MaxDepth);

    // Insert pruning data and perform reduced-error pruning
    if (!m_NoPruning) {
      m_Tree.insertHoldOutSet(prune); // hand the hold-out set to the tree
      m_Tree.reducedErrorPrune();     // reduced-error pruning
      m_Tree.backfitHoldOutSet();     // backfit on the hold-out data
    }
  }
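As a quick usage aside, training a REPTree through the Weka API takes only a few lines. Here is a minimal sketch; the dataset path is just an example, and the setters shown correspond to the m_NumFolds, m_MinNum, m_MaxDepth, and m_NoPruning fields used above:

import weka.classifiers.trees.REPTree;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class REPTreeDemo {
    public static void main(String[] args) throws Exception {
        // Load a dataset; the path is only an example
        Instances data = DataSource.read("iris.arff");
        data.setClassIndex(data.numAttributes() - 1);

        REPTree tree = new REPTree();
        tree.setNumFolds(3);  // folds used to carve out the prune set
        tree.setMinNum(2.0);  // minimum total instance weight per leaf
        tree.setMaxDepth(-1); // -1 means no depth limit
        // tree.setNoPruning(true); // skip reduced-error pruning and backfitting
        tree.buildClassifier(data);
        System.out.println(tree);
    }
}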
3. Tree.buildTree
Tree is an inner class of REPTree; its tree-building method takes quite a few parameters.
protected void buildTree(int[][][] sortedIndices, double[][][] weights,
                         Instances data, double totalWeight,
                         double[] classProbs, Instances header,
                         double minNum, double minVariance,
                         int depth, int maxDepth) throws Exception {

    // Parameters: (1) instance indices sorted per attribute, (2) the weights
    // matching those indices, (3) the training data, (4) the total weight,
    // (5) the class distribution, (6) the header, (7) the minimum number of
    // instances per node, (8) the minimum variance, (9) the current depth
    // (0-based), (10) the maximum depth

    m_Info = header; // store the header first

    if (data.classAttribute().isNumeric()) {
      m_HoldOutDist = new double[2]; // holds the hold-out distribution
    } else {
      m_HoldOutDist = new double[data.numClasses()];
    }

    // Check whether there is any data. The data has at least two attributes
    // (with only one, the caller would have used the m_zeroR model), and
    // this if merely guarantees that helpIndex never points at the class
    // attribute
    int helpIndex = 0;
    if (data.classIndex() == 0) {
      helpIndex = 1;
    }
    if (sortedIndices[0][helpIndex].length == 0) {
      // No data: return right away
      if (data.classAttribute().isNumeric()) {
        // Why length 2? By convention element 0 holds the variance and
        // element 1 the weight
        m_Distribution = new double[2];
      } else {
        m_Distribution = new double[data.numClasses()];
      }
      m_ClassProbs = null;
      sortedIndices[0] = null;
      weights[0] = null;
      return;
    }

    // The class "variance" (actually variance * total weight); only
    // meaningful for a numeric class. The computation follows.
    double priorVar = 0;
    if (data.classAttribute().isNumeric()) {

      // Every sortedIndices[0][i] is just a different permutation of the
      // same instance indices; helpIndex merely avoids the class index
      double totalSum = 0, totalSumSquared = 0, totalSumOfWeights = 0;
      for (int i = 0; i < sortedIndices[0][helpIndex].length; i++) {
        Instance inst = data.instance(sortedIndices[0][helpIndex][i]);
        totalSum += inst.classValue() * weights[0][helpIndex][i];
        totalSumSquared += inst.classValue() * inst.classValue()
          * weights[0][helpIndex][i];
        totalSumOfWeights += weights[0][helpIndex][i];
      }
      priorVar = singleVariance(totalSum, totalSumSquared, totalSumOfWeights);
    }

    // Copy the class distribution
    m_ClassProbs = new double[classProbs.length];
    System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);

    // There are four stopping conditions:
    // 1. the total weight (roughly the number of instances, since weights
    //    default to 1) is below twice minNum (minNum defaults to 2);
    // 2. a nominal class where all instances fall into a single class;
    // 3. a numeric class whose variance is below minVariance (by default
    //    0.001 of the original variance, as the calling code shows);
    // 4. the maximum depth has been reached.
    if ((totalWeight < (2 * minNum)) ||
        (data.classAttribute().isNominal() &&
         Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
                  Utils.sum(m_ClassProbs))) ||
        (data.classAttribute().isNumeric() &&
         ((priorVar / totalWeight) < minVariance)) ||
        ((m_MaxDepth >= 0) && (depth >= maxDepth))) {

      // Make this node a leaf
      m_Attribute = -1;
      if (data.classAttribute().isNominal()) {

        // Nominal class: store the class distribution
        m_Distribution = new double[m_ClassProbs.length];
        for (int i = 0; i < m_ClassProbs.length; i++) {
          m_Distribution[i] = m_ClassProbs[i];
        }
        Utils.normalize(m_ClassProbs);
      } else {

        // Numeric class: store variance and weight as the "distribution"
        m_Distribution = new double[2];
        m_Distribution[0] = priorVar;
        m_Distribution[1] = totalWeight;
      }
      sortedIndices[0] = null;
      weights[0] = null;
      return;
    }

    // Search for the best split
    double[] vals = new double[data.numAttributes()]; // gain achieved by each attribute
    double[][][] dists = new double[data.numAttributes()][0][0]; // class distribution under each attribute
    double[][] props = new double[data.numAttributes()][0]; // class proportions derived from dists
    double[][] totalSubsetWeights = new double[data.numAttributes()][0]; // weight of each subset per attribute
    double[] splits = new double[data.numAttributes()]; // split point per attribute (NaN for nominal ones)
    if (data.classAttribute().isNominal()) {

      // The nominal-class case first
      for (int i = 0; i < data.numAttributes(); i++) {
        if (i != data.classIndex()) {
          // Compute split point, proportions and distributions
          splits[i] = distribution(props, dists, i, sortedIndices[0][i],
                                   weights[0][i], totalSubsetWeights, data);
          // Compute the information gain
          vals[i] = gain(dists[i], priorVal(dists[i]));
        }
      }
    } else {

      // A numeric class uses variance reduction instead of information
      // gain, since entropy is only defined for nominal classes. (Aside:
      // one wonders why this if-else was not placed inside the loop.)
      for (int i = 0; i < data.numAttributes(); i++) {
        if (i != data.classIndex()) {
          splits[i] = numericDistribution(props, dists, i, sortedIndices[0][i],
                                          weights[0][i], totalSubsetWeights,
                                          data, vals);
        }
      }
    }

    // Pick the attribute with the largest gain as the split attribute
    m_Attribute = Utils.maxIndex(vals);
    int numAttVals = dists[m_Attribute].length;

    // A subset only counts as valid if its weight reaches minNum
    int count = 0;
    for (int i = 0; i < numAttVals; i++) {
      if (totalSubsetWeights[m_Attribute][i] >= minNum) {
        count++;
      }
      if (count > 1) {
        break;
      }
    }

    // A split is only valid if at least two subsets are valid
    if (Utils.gr(vals[m_Attribute], 0) && (count > 1)) {

      // Set split point, proportions, and temp arrays
      m_SplitPoint = splits[m_Attribute];
      m_Prop = props[m_Attribute];
      double[][] attSubsetDists = dists[m_Attribute];
      double[] attTotalSubsetWeights = totalSubsetWeights[m_Attribute];

      // Release memory
      vals = null;
      dists = null;
      props = null;
      totalSubsetWeights = null;
      splits = null;

      // Compute the sorted indices of each subset
      int[][][][] subsetIndices =
        new int[numAttVals][1][data.numAttributes()][0];
      double[][][][] subsetWeights =
        new double[numAttVals][1][data.numAttributes()][0];
      splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
                sortedIndices[0], weights[0], data);

      // Release memory
      sortedIndices[0] = null;
      weights[0] = null;

      m_Successors = new Tree[numAttVals];
      for (int i = 0; i < numAttVals; i++) {
        m_Successors[i] = new Tree(); // build the child node
        m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data,
                                  attTotalSubsetWeights[i], attSubsetDists[i],
                                  header, minNum, minVariance, depth + 1,
                                  maxDepth);

        // Release memory once more
        attSubsetDists[i] = null;
      }
    } else {

      // Fewer than two valid subsets: treat this node as a leaf
      m_Attribute = -1;
      sortedIndices[0] = null;
      weights[0] = null;
    }

    // Store the distribution for the later classification step (used as-is
    // when neither pruning nor backfitting replaces it)
    if (data.classAttribute().isNominal()) {
      m_Distribution = new double[m_ClassProbs.length];
      for (int i = 0; i < m_ClassProbs.length; i++) {
        m_Distribution[i] = m_ClassProbs[i];
      }
      Utils.normalize(m_ClassProbs);
    } else {
      m_Distribution = new double[2];
      m_Distribution[0] = priorVar;
      m_Distribution[1] = totalWeight;
    }
  }
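One helper worth spelling out is singleVariance, which both methods rely on. Judging from how its result is used (priorVar / totalWeight is compared against minVariance, and buildClassifier divides by totalWeight to obtain trainVariance), it returns the weighted variance scaled by the total weight. An illustrative reconstruction under that reading (a sketch, not necessarily a verbatim copy of Weka's code):

// For a weighted sum s, a weighted sum of squares sS and a total weight W,
// the value sS - s*s/W equals the weighted variance multiplied by W,
// which is why the callers above divide the result by the total weight.
protected double singleVariance(double s, double sS, double weight) {
    return sS - ((s * s) / weight);
}

This also explains the earlier comment that priorVar is really "variance * weight": the division by the total weight only happens at the point of comparison.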
(To be continued)