Apache Spark MLlib Study Notes (7): MLlib Decision Tree Source Code Analysis, Part 3

The previous post traced Spark MLlib's decision tree implementation down to the run method of RandomForest; this post dissects that method in detail.
As noted there, the input is first summarized into metadata, so we begin with the buildMetadata method.
DecisionTreeMetadata determines how many bins each feature gets in the various cases, classifying features as either ordered or unordered, and handling binary classification and regression as a single case.
As for the number of splits: for continuous features the values are first sampled, and the number of splits is the number of bins minus one. Categorical features are split into ordered and unordered cases: an ordered feature uses its number of categories as the bin count, while an unordered feature uses subsets of its categories as split candidates, with a guard (shown in the code below) that avoids the exponential blow-up in bins when a feature has many categories.
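Before reading the source, it helps to see where these two knobs surface in the public API. Below is a minimal, hypothetical sketch: the toy data, the sc SparkContext, and the chosen parameter values are assumptions for illustration; only maxBins and categoricalFeaturesInfo are the actual inputs that reach buildMetadata through Strategy.

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.rdd.RDD

// Hypothetical toy data: feature 0 is continuous, feature 1 is categorical
// with 3 categories encoded as 0.0, 1.0, 2.0. sc is an existing SparkContext.
val input: RDD[LabeledPoint] = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(1.5, 0.0)),
  LabeledPoint(1.0, Vectors.dense(2.5, 1.0)),
  LabeledPoint(1.0, Vectors.dense(3.5, 2.0))))

val model = DecisionTree.trainClassifier(
  input,
  2,            // numClasses
  Map(1 -> 3),  // categoricalFeaturesInfo: feature 1 has 3 categories
  "gini",       // impurity
  4,            // maxDepth
  32)           // maxBins -> becomes strategy.maxBins inside buildMetadata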

private[tree] object DecisionTreeMetadata extends Logging {

  /**
   * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters.
   * Features are classified as either ordered or unordered.
   */
  def buildMetadata(
      input: RDD[LabeledPoint],
      strategy: Strategy,
      numTrees: Int,
      featureSubsetStrategy: String): DecisionTreeMetadata = {

    val numFeatures = input.take(1)(0).features.size  // number of features
    val numExamples = input.count()  // total number of training examples
    val numClasses = strategy.algo match {  // number of target classes
      case Classification => strategy.numClasses  // equals the number of label values
      case Regression => 0  // meaningless for regression, so set to 0
    }

    // The key step below bounds the number of bins per feature:
    // it cannot exceed the smaller of strategy.maxBins and the number of examples.
    val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
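    // e.g. with strategy.maxBins = 32 but only 20 training examples,
    // maxPossibleBins becomes 20 and the warning below is logged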
    if (maxPossibleBins < strategy.maxBins) {
      logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
        s" (= number of training instances)")
    }

    // categoricalFeaturesInfo maps each categorical feature index to its number of categories.
    // No feature may have more categories than the maximum number of bins.
    if (strategy.categoricalFeaturesInfo.nonEmpty) {
      val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
      require(maxCategoriesPerFeature <= maxPossibleBins,
        s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
          s"in categorical features (= $maxCategoriesPerFeature)")
    }
    // Indices of categorical features that will be treated as unordered
    val unorderedFeatures = new mutable.HashSet[Int]()
    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
    if (numClasses > 2) {
      // More than two classes: multiclass classification
      val maxCategoriesForUnorderedFeature =
        ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
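      // e.g. with maxPossibleBins = 32: floor(log2(32 / 2 + 1) + 1) = floor(5.09) = 5,
      // so features with at most 5 categories are unordered: numUnorderedBins(5) = 30 <= 32,
      // while arity 6 would already need 62 bins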
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        // Decide if some categorical features should be treated as unordered features,
        // which require 2 * ((1 << numCategories - 1) - 1) bins.
        // We do this check with log values to prevent overflows in case numCategories is large.
        // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
        if (numCategories <= maxCategoriesForUnorderedFeature) {
          unorderedFeatures.add(featureIndex)
          numBins(featureIndex) = numUnorderedBins(numCategories)
        } else {
          numBins(featureIndex) = numCategories
        }
      }
    } else {
      // Binary classification or regression: all categorical features stay ordered
      strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
        numBins(featureIndex) = numCategories
      }
    }
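    // At this point numBins(f) is: maxPossibleBins for continuous features,
    // numCategories for ordered categorical features, and
    // numUnorderedBins(numCategories) for unordered categorical features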

    // The parameters below exist for random forests; for the single decision tree
    // analyzed here, "auto" resolves to "all", i.e. every feature is considered.
    val _featureSubsetStrategy = featureSubsetStrategy match {
      case "auto" =>
        if (numTrees == 1) {
          "all"
        } else {
          if (strategy.algo == Classification) {
            "sqrt"
          } else {
            "onethird"
          }
        }
      case _ => featureSubsetStrategy
    }
    val numFeaturesPerNode: Int = _featureSubsetStrategy match {
      case "all" => numFeatures
      case "sqrt" => math.sqrt(numFeatures).ceil.toInt
      case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
      case "onethird" => (numFeatures / 3.0).ceil.toInt
    }

    new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
      strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
      strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
  }

  def buildMetadata(
      input: RDD[LabeledPoint],
      strategy: Strategy): DecisionTreeMetadata = {
    buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
  }

  /**
   * Number of bins for an unordered categorical feature with the given arity
   * (number of categories). Such a feature has math.pow(2, arity - 1) - 1
   * candidate splits, and each split contributes two bins.
   */
  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)

}
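To make the growth of numUnorderedBins concrete, here is a small standalone check (the assertion values are worked arithmetic, not taken from the Spark source):

// 2^(arity - 1) - 1 candidate splits, each contributing two bins
def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)

assert(numUnorderedBins(3) == 6)   // 3 splits
assert(numUnorderedBins(4) == 14)  // 7 splits
assert(numUnorderedBins(5) == 30)  // 15 splits
// arity 6 would need 62 bins, which is why buildMetadata only treats
// low-arity categorical features as unordered in the multiclass case

This exponential growth is exactly what the maxCategoriesForUnorderedFeature guard protects against: unordered treatment is only affordable for features with few categories.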
