Weka算法Classifier-trees-RandomTree源码分析


一、RandomTree算法

在网上搜了一下,并没有找到RandomTree的严格意义上的算法描述,因此我觉得RandomTree充其量只是一种构建树的思路,和普通决策树相比,RandomTree会随机的选择若干属性来进行构建而不是选取所有的属性。

Weka在实现上,对于随机属性的选取、生成分裂点的过程是这样的:

1、设置一个要选取的属性的数量K

2、在全域属性中无放回的对属性进行抽样

3、算出该属性的信息增益(注意不是信息增益率)

4、重复K次,选出信息增益最大的当分裂节点。

5、构建该节点的孩子子树。


二、具体代码实现

(1)buildClassifier

 public void buildClassifier(Instances data) throws Exception {

    // 如果传入的K不合理,把K放到一个合理的范围里
    if (m_KValue > data.numAttributes() - 1)
      m_KValue = data.numAttributes() - 1;
    if (m_KValue < 1)
      m_KValue = (int) Utils.log2(data.numAttributes()) + 1;//这个是K的默认值

    // 判断一下该分类器是否有能力处理这个数据集,如果没能力直接就在testWithFail里抛异常退出了
    getCapabilities().testWithFail(data);

    // 删除掉missClass
    data = new Instances(data);
    data.deleteWithMissingClass();

    // 如果只有一列,就build一个ZeroR模型,之后就结束了。ZeroR模型分类是这样的:如果是连续型,总是返回期望,如果离散型,总是返回训练集中出现最多的那个
    if (data.numAttributes() == 1) {
      System.err
          .println("Cannot build model (only class attribute present in data!), "
              + "using ZeroR model instead!");
      m_zeroR = new weka.classifiers.rules.ZeroR();
      m_zeroR.buildClassifier(data);
      return;
    } else {
      m_zeroR = null;
    }

    // 如果m_NumFlods大于0,则会把数据集分为两部分,一部分用于train,一部分用于test,也就是backfit
    //分的方式和多折交叉验证是一样的,例如m_NumFlods是10的话,则train占90%,backfit占10%
    Instances train = null;
    Instances backfit = null;
    Random rand = data.getRandomNumberGenerator(m_randomSeed);
    if (m_NumFolds <= 0) {
      train = data;
    } else {
      data.randomize(rand);
      data.stratify(m_NumFolds);
      train = data.trainCV(m_NumFolds, 1, rand);
      backfit = data.testCV(m_NumFolds, 1);
    }

    // 生成所有的可选属性
    int[] attIndicesWindow = new int[data.numAttributes() - 1];
    int j = 0;
    for (int i = 0; i < attIndicesWindow.length; i++) {
      if (j == data.classIndex())
        j++; // 忽略掉classIndex
      attIndicesWindow[i] = j++;//这段代码有点奇怪,i和j是相等的,为啥不用attIndicesWindow=i?
    }

    // 算出每个class的频率,也就是每个分类出现的次数(更正确的说法应该是权重,但权重默认都是1)
    double[] classProbs = new double[train.numClasses()];
    for (int i = 0; i < train.numInstances(); i++) {
      Instance inst = train.instance(i);
      classProbs[(int) inst.classValue()] += inst.weight();
    }

    // Build tree
    m_Tree = new Tree();
    m_Info = new Instances(data, 0);
    m_Tree.buildTree(train, classProbs, attIndicesWindow, rand, 0);//调用tree的build方法,在后面单独分析

    // Backfit if required
    if (backfit != null) {
      m_Tree.backfitData(backfit);//在后面单独分析
    }
  }

这个Tree对象是RandomTree的一个子类,之前我还以为会复用其余的决策树模型(比如J48),但weka没这么做,很惊奇的是RandomTree和J48的作者还是同一个,不知道为啥这么设计。


(2)tree.buildTree

 protected void buildTree(Instances data, double[] classProbs,
        int[] attIndicesWindow, Random random, int depth) throws Exception {

      //首先判断一下是否有instance,如果没有的话直接就返回
      if (data.numInstances() == 0) {
        m_Attribute = -1;
        m_ClassDistribution = null;
        m_Prop = null;
        return;
      }

      m_ClassDistribution = classProbs.clone();

      if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
          || Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)],
              Utils.sum(m_ClassDistribution))
          || ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
        // 递归结束的条件有3个 1、instance数量小于2*m_Minnum  2、instance都已经在同一个类中 3、达到最大的深度
       //前两个条件和j48的递归结束条件很相似,相关内容可参考我之前的几篇博客。
        m_Attribute = -1;
        m_Prop = null;
        return;
      }

      double val = -Double.MAX_VALUE;
      double split = -Double.MAX_VALUE;
      double[][] bestDists = null;
      double[] bestProps = null;
      int bestIndex = 0;

      double[][] props = new double[1][0];
      double[][][] dists = new double[1][0][0];//这个数组第一列只有下标为0的被用到,不知道为啥这么设计

      int attIndex = 0;//存储被选择到的属性
      int windowSize = attIndicesWindow.length;//存储目前可选择的属性的数量
      int k = m_KValue;//k代表还能选择的属性的数量
      boolean gainFound = false;//是否发现了一个有信息增益的节点
      while ((windowSize > 0) && (k-- > 0 || !gainFound)) {//此循环退出条件有2个 1、没有节点可以选了 2、已经选了k个属性了并且找到了一个有用的属性 换句话说,如果K次迭代没有找到可以分裂的随机节点,循环也会继续下去
        int chosenIndex = random.nextInt(windowSize);//随机选一个,生成下标
        attIndex = attIndicesWindow[chosenIndex];//得到该属性的index

       //下面三行把选择的属性放到attIndicesWindow的末尾,然后把windowSize-1这样下个循环就不会选到了,也就是实现了无放回的抽取
        attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
        attIndicesWindow[windowSize - 1] = attIndex;
        windowSize--;

        double currSplit = distribution(props, dists, attIndex, data);//这个函数计算了在使用attIndex进行分裂所产生的分布,如果classIndex是连续值的话,还计算了分裂点,原理和J48的split一样,不在赘述。
        double currVal = gain(dists[0], priorVal(dists[0]));//这个计算了信息增益

        if (Utils.gr(currVal, 0))
          gainFound = true;//如果信息增益大于0的话,说明节点有效,设置gainFound

        if ((currVal > val) || ((currVal == val) && (attIndex < bestIndex))) {
          val = currVal; //如果信息增益大的话,则更新把attIndex更新为bestIndex,这是为了选取最优的节点(ID3)的方法
          bestIndex = attIndex;
          split = currSplit;
          bestProps = props[0];
          bestDists = dists[0];
        }
      }

      m_Attribute = bestIndex;

      // Any useful split found?
      if (Utils.gr(val, 0)) {

 <span style="white-space:pre">	</span>//如果找到了一个分裂点,则在该分裂点的基础上构建子树
        m_SplitPoint = split;
        m_Prop = bestProps;
        Instances[] subsets = splitData(data);
        m_Successors = new Tree[bestDists.length];
        for (int i = 0; i < bestDists.length; i++) {
          m_Successors[i] = new Tree();
          m_Successors[i].buildTree(subsets[i], bestDists[i], attIndicesWindow,
              random, depth + 1);//注意这里传入的attIndicesWindow没有变,换句话说,每次迭代传入的可选属性集合是一样的,因此子节点在进行属性的random选择时,很有可能会选择到父节点已经选过的节点,但因为不产生信息增益,因此不会再次作为bestIndex,但会产生额外的计算量(我感觉还不少),这里还有一定的优化空间,同理j48也是这么实现的。
        }

        boolean emptySuccessor = false;
        for (int i = 0; i < subsets.length; i++) {
          if (m_Successors[i].m_ClassDistribution == null) {
            emptySuccessor = true;
            break;
          }
        }
        if (!emptySuccessor) {
          m_ClassDistribution = null;
        }
      } else {

       //这个else是<span style="font-family: Arial, Helvetica, sans-serif;">Utils.gr(currVal, 0)这个条件的,代表没有选择到合适的分裂节点</span>

        m_Attribute = -1;
      }
    }

(3)tree.backfit

什么是Backfit?Backfit将改变已有tree节点及其子节点的class分布,而class分布将直接被用于实例的预测。

直接使用RandomTree有时会出现过拟合的现象(通过代码可以看到,和J48相比没有剪枝过程),因此通过传入一个新的数据集来backfit已有节点是一个解决过拟合的方法。

  protected void backfitData(Instances data, double[] classProbs)
        throws Exception {
<span style="white-space:pre">	</span>//判断一下是否有数据
      if (data.numInstances() == 0) {
        m_Attribute = -1;
        m_ClassDistribution = null;
        m_Prop = null;
        return;
      }

      m_ClassDistribution = classProbs.clone();

      if (m_Attribute > -1) {

        // m_Attribut>-1代表不是leaf,可以看上面的buildTree得出这个结论
        m_Prop = new double[m_Successors.length];//子节点数组的length也就是分类的类的数量
<span style="white-space:pre">	</span>//把传入的data用此节点算各类的频率
        for (int i = 0; i < data.numInstances(); i++) {
          Instance inst = data.instance(i);
          if (!inst.isMissing(m_Attribute)) {
            if (data.attribute(m_Attribute).isNominal()) {
              m_Prop[(int) inst.value(m_Attribute)] += inst.weight();
            } else {
              m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst
                  .weight();//连续型只会分两类,小于splitPoint一类,大于是一类,和J48采用的策略相同
            }
          }
        }

        if (Utils.sum(m_Prop) <= 0) {
          m_Attribute = -1;//如果data全部都是missingValue,则把此节点变成leaf节点
          m_Prop = null;
          return;
        }

        // 归一化
        Utils.normalize(m_Prop);

        // 根据本节点算出在data上进行分类的subset
        Instances[] subsets = splitData(data);

        for (int i = 0; i < subsets.length; i++) {

          // 递归的对孩子节点进行backfit
          double[] dist = new double[data.numClasses()];
          for (int j = 0; j < subsets[i].numInstances(); j++) {
            dist[(int) subsets[i].instance(j).classValue()] += subsets[i]
                .instance(j).weight();
          }

          m_Successors[i].backfitData(subsets[i], dist);
        }

<span style="white-space:pre">	</span>
        if (getAllowUnclassifiedInstances()) {
          m_ClassDistribution = null;
          return;
        }
<span style="white-space:pre">	</span>//如果某个子节点的分布为空的话,则父节点要保存分布,否则不需要持有分布。
 <span style="white-space:pre">	</span>//为什么呢?因为使用RandomTree进行预测时会遍历节点的分布并进行累加,得到分布最大的class作为预测class,在J48的那篇博客中有分析
        boolean emptySuccessor = false;
        for (int i = 0; i < subsets.length; i++) {
          if (m_Successors[i].m_ClassDistribution == null) {
            emptySuccessor = true;
            return;
          }
        }
        m_ClassDistribution = null;
      }
    }


三、总结

对RandomForest的分析到这里就结束了,首先分析了RandomForest,接着分析了Bagging,最后分析了RandomTree。



你可能感兴趣的:(源码,算法,机器学习,weka,分类器)