一、RandomTree算法
在网上搜了一下,并没有找到RandomTree的严格意义上的算法描述,因此我觉得RandomTree充其量只是一种构建树的思路,和普通决策树相比,RandomTree会随机的选择若干属性来进行构建而不是选取所有的属性。
Weka在实现上,对于随机属性的选取、生成分裂点的过程是这样的:
1、设置一个要选取的属性的数量K
2、在全域属性中无放回的对属性进行抽样
3、算出该属性的信息增益(注意不是信息增益率)
4、重复K次,选出信息增益最大的当分裂节点。
5、构建该节点的孩子子树。
二、具体代码实现
(1)buildClassifier
public void buildClassifier(Instances data) throws Exception { // 如果传入的K不合理,把K放到一个合理的范围里 if (m_KValue > data.numAttributes() - 1) m_KValue = data.numAttributes() - 1; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes()) + 1;//这个是K的默认值 // 判断一下该分类器是否有能力处理这个数据集,如果没能力直接就在testWithFail里抛异常退出了 getCapabilities().testWithFail(data); // 删除掉missClass data = new Instances(data); data.deleteWithMissingClass(); // 如果只有一列,就build一个ZeroR模型,之后就结束了。ZeroR模型分类是这样的:如果是连续型,总是返回期望,如果离散型,总是返回训练集中出现最多的那个 if (data.numAttributes() == 1) { System.err .println("Cannot build model (only class attribute present in data!), " + "using ZeroR model instead!"); m_zeroR = new weka.classifiers.rules.ZeroR(); m_zeroR.buildClassifier(data); return; } else { m_zeroR = null; } // 如果m_NumFlods大于0,则会把数据集分为两部分,一部分用于train,一部分用于test,也就是backfit
//分的方式和多折交叉验证是一样的,例如m_NumFlods是10的话,则train占90%,backfit占10% Instances train = null; Instances backfit = null; Random rand = data.getRandomNumberGenerator(m_randomSeed); if (m_NumFolds <= 0) { train = data; } else { data.randomize(rand); data.stratify(m_NumFolds); train = data.trainCV(m_NumFolds, 1, rand); backfit = data.testCV(m_NumFolds, 1); } // 生成所有的可选属性 int[] attIndicesWindow = new int[data.numAttributes() - 1]; int j = 0; for (int i = 0; i < attIndicesWindow.length; i++) { if (j == data.classIndex()) j++; // 忽略掉classIndex attIndicesWindow[i] = j++;//这段代码有点奇怪,i和j是相等的,为啥不用attIndicesWindow=i? } // 算出每个class的频率,也就是每个分类出现的次数(更正确的说法应该是权重,但权重默认都是1) double[] classProbs = new double[train.numClasses()]; for (int i = 0; i < train.numInstances(); i++) { Instance inst = train.instance(i); classProbs[(int) inst.classValue()] += inst.weight(); } // Build tree m_Tree = new Tree(); m_Info = new Instances(data, 0); m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);//调用tree的build方法,在后面单独分析 // Backfit if required if (backfit != null) { m_Tree.backfitData(backfit);//在后面单独分析 } }
(2)tree.buildTree
protected void buildTree(Instances data, double[] classProbs, int[] attIndicesWindow, Random random, int depth) throws Exception { //首先判断一下是否有instance,如果没有的话直接就返回 if (data.numInstances() == 0) { m_Attribute = -1; m_ClassDistribution = null; m_Prop = null; return; } m_ClassDistribution = classProbs.clone(); if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils.sum(m_ClassDistribution)) || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) { // 递归结束的条件有3个 1、instance数量小于2*m_Minnum 2、instance都已经在同一个类中 3、达到最大的深度
//前两个条件和j48的递归结束条件很相似,相关内容可参考我之前的几篇博客。 m_Attribute = -1; m_Prop = null; return; } double val = -Double.MAX_VALUE; double split = -Double.MAX_VALUE; double[][] bestDists = null; double[] bestProps = null; int bestIndex = 0; double[][] props = new double[1][0]; double[][][] dists = new double[1][0][0];//这个数组第一列只有下标为0的被用到,不知道为啥这么设计 int attIndex = 0;//存储被选择到的属性 int windowSize = attIndicesWindow.length;//存储目前可选择的属性的数量 int k = m_KValue;//k代表还能选择的属性的数量 boolean gainFound = false;//是否发现了一个有信息增益的节点 while ((windowSize > 0) && (k-- > 0 || !gainFound)) {//此循环退出条件有2个 1、没有节点可以选了 2、已经选了k个属性了并且找到了一个有用的属性 换句话说,如果K次迭代没有找到可以分裂的随机节点,循环也会继续下去
int chosenIndex = random.nextInt(windowSize);//随机选一个,生成下标 attIndex = attIndicesWindow[chosenIndex];//得到该属性的index //下面三行把选择的属性放到attIndicesWindow的末尾,然后把windowSize-1这样下个循环就不会选到了,也就是实现了无放回的抽取 attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1]; attIndicesWindow[windowSize - 1] = attIndex; windowSize--; double currSplit = distribution(props, dists, attIndex, data);//这个函数计算了在使用attIndex进行分裂所产生的分布,如果classIndex是连续值的话,还计算了分裂点,原理和J48的split一样,不在赘述。 double currVal = gain(dists[0], priorVal(dists[0]));//这个计算了信息增益 if (Utils.gr(currVal, 0)) gainFound = true;//如果信息增益大于0的话,说明节点有效,设置gainFound if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) { val = currVal; //如果信息增益大的话,则更新把attIndex更新为bestIndex,这是为了选取最优的节点(ID3)的方法 bestIndex = attIndex; split = currSplit; bestProps = props[0]; bestDists = dists[0]; } } m_Attribute = bestIndex; // Any useful split found? if (Utils.gr(val, 0)) { <span style="white-space:pre"> </span>//如果找到了一个分裂点,则在该分裂点的基础上构建子树
m_SplitPoint = split; m_Prop = bestProps; Instances[] subsets = splitData(data); m_Successors = new Tree[bestDists.length]; for (int i = 0; i < bestDists.length; i++) { m_Successors[i] = new Tree(); m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow, random, depth + 1);//注意这里传入的attIndicesWindow没有变,换句话说,每次迭代传入的可选属性集合是一样的,因此子节点在进行属性的random选择时,很有可能会选择到父节点已经选过的节点,但因为不产生信息增益,因此不会再次作为bestIndex,但会产生额外的计算量(我感觉还不少),这里还有一定的优化空间,同理j48也是这么实现的。 } boolean emptySuccessor = false; for (int i = 0; i < subsets.length; i++) { if (m_Successors[i].m_ClassDistribution == null) { emptySuccessor = true; break; } } if (!emptySuccessor) { m_ClassDistribution = null; } } else { //这个else是<span style="font-family: Arial, Helvetica, sans-serif;">Utils.gr(currVal, 0)这个条件的,代表没有选择到合适的分裂节点</span> m_Attribute = -1; } }
什么是Backfit?Backfit将改变已有tree节点及其子节点的class分布,而class分布将直接被用于实例的预测。
直接使用RandomTree有时会出现过拟合的现象(通过代码可以看到,和J48相比没有剪枝过程),因此通过传入一个新的数据集来backfit已有节点是一个解决过拟合的方法。
protected void backfitData(Instances data, double[] classProbs) throws Exception {
<span style="white-space:pre"> </span>//判断一下是否有数据 if (data.numInstances() == 0) { m_Attribute = -1; m_ClassDistribution = null; m_Prop = null; return; } m_ClassDistribution = classProbs.clone(); if (m_Attribute > -1) { // m_Attribut>-1代表不是leaf,可以看上面的buildTree得出这个结论 m_Prop = new double[m_Successors.length];//子节点数组的length也就是分类的类的数量
<span style="white-space:pre"> </span>//把传入的data用此节点算各类的频率 for (int i = 0; i < data.numInstances(); i++) { Instance inst = data.instance(i); if (!inst.isMissing(m_Attribute)) { if (data.attribute(m_Attribute).isNominal()) { m_Prop[(int) inst.value(m_Attribute)] += inst.weight(); } else { m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst .weight();//连续型只会分两类,小于splitPoint一类,大于是一类,和J48采用的策略相同 } } } if (Utils.sum(m_Prop) <= 0) { m_Attribute = -1;//如果data全部都是missingValue,则把此节点变成leaf节点 m_Prop = null; return; } // 归一化 Utils.normalize(m_Prop); // 根据本节点算出在data上进行分类的subset Instances[] subsets = splitData(data); for (int i = 0; i < subsets.length; i++) { // 递归的对孩子节点进行backfit double[] dist = new double[data.numClasses()]; for (int j = 0; j < subsets[i].numInstances(); j++) { dist[(int) subsets[i].instance(j).classValue()] += subsets[i] .instance(j).weight(); } m_Successors[i].backfitData(subsets[i], dist); } <span style="white-space:pre"> </span> if (getAllowUnclassifiedInstances()) { m_ClassDistribution = null; return; } <span style="white-space:pre"> </span>//如果某个子节点的分布为空的话,则父节点要保存分布,否则不需要持有分布。
<span style="white-space:pre"> </span>//为什么呢?因为使用RandomTree进行预测时会遍历节点的分布并进行累加,得到分布最大的class作为预测class,在J48的那篇博客中有分析 boolean emptySuccessor = false; for (int i = 0; i < subsets.length; i++) { if (m_Successors[i].m_ClassDistribution == null) { emptySuccessor = true; return; } } m_ClassDistribution = null; } }
三、总结
对RandomForest的分析到这里就结束了,首先分析了RandomForest,接着分析了Bagging,最后分析了RandomTree。