Apache Spark MLlib Study Notes (6): MLlib Decision Tree Algorithm Source Code Analysis, Part 2

The previous article showed that building a classification decision tree model means calling the trainClassifier method; this article analyzes what trainClassifier actually does.
Open the source file at the following path:
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
We focus on DecisionTree.scala first.
Locate the trainClassifier method; its code is:

  def trainClassifier(
      input: RDD[LabeledPoint],
      numClasses: Int,
      categoricalFeaturesInfo: Map[Int, Int],
      impurity: String,
      maxDepth: Int,
      maxBins: Int): DecisionTreeModel = {
    val impurityType = Impurities.fromString(impurity)
    train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
      categoricalFeaturesInfo)
  }
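
For context, here is a minimal sketch of how this entry point is typically called from user code; the SparkContext sc, the data path, and the parameter values below are illustrative assumptions rather than part of the source being analyzed:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.util.MLUtils

// Hypothetical training data: any RDD[LabeledPoint] works; `sc` is an existing SparkContext.
val trainingData = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")

val model = DecisionTree.trainClassifier(
  input = trainingData,
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),  // empty map: treat every feature as continuous
  impurity = "gini",
  maxDepth = 5,
  maxBins = 32)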

As we can see, trainClassifier simply delegates to the train method. Locate train and examine it:

  def train(
      input: RDD[LabeledPoint],
      algo: Algo,
      impurity: Impurity,
      maxDepth: Int,
      numClasses: Int,
      maxBins: Int,
      quantileCalculationStrategy: QuantileStrategy,
      categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
    val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
      quantileCalculationStrategy, categoricalFeaturesInfo)
    new DecisionTree(strategy).run(input)
  }

First, here is what the parameters above mean:

 @param input: training dataset, an RDD of LabeledPoint; labels must take values in {0, 1, ..., numClasses-1}.
 @param algo: Classification or Regression.
 @param impurity: criterion used to compute information gain; supported values are gini, entropy, and variance (variance is used for regression).
 @param maxDepth: maximum depth of the tree; depth 0 means a single root node, depth 1 means one root node plus two leaf nodes.
 @param numClasses: number of classes for classification; the default is 2.
 @param maxBins: maximum number of bins used when discretizing features, i.e. an upper limit on the number of split candidates per feature.
 @param quantileCalculationStrategy: algorithm used to compute the quantiles (candidate split thresholds).
 @param categoricalFeaturesInfo: map storing the arity of categorical features, with entries (n -> k) meaning feature n is categorical with k categories indexed
  0: {0, 1, ..., k-1} (see the sketch just below).
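
To make the categoricalFeaturesInfo / maxBins pair concrete, here is a small hedged sketch; the feature indices and arities are made up purely for illustration:

// Hypothetical: feature 0 is categorical with 3 categories {0, 1, 2} and feature 4 is
// categorical with 2 categories {0, 1}; every feature not listed is treated as continuous.
val categoricalFeaturesInfo = Map(0 -> 3, 4 -> 2)
// maxBins caps the number of split candidates per feature; for categorical features it
// must be at least the largest arity declared above (so here maxBins must be >= 3).
val maxBins = 32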

As this code shows, all of the parameters are first packaged into a Strategy object, which is then used to construct a DecisionTree, whose run method is finally called. Let's look at run first; its code is:

class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {

  strategy.assertValid()

  /**
   * Method to train a decision tree model over an RDD
   * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
   * @return DecisionTreeModel that can be used for prediction
   */
  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
    // Note: random seed will not be used since numTrees = 1.
    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
    val rfModel = rf.run(input)
    rfModel.trees(0)
  }
}
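
Note that run returns rfModel.trees(0), which is an ordinary DecisionTreeModel, so the caller can inspect and use it directly. A brief hedged sketch (the model variable and the feature vector are hypothetical):

import org.apache.spark.mllib.linalg.Vectors

// Hypothetical: `model` is the DecisionTreeModel returned by run / trainClassifier above.
println(s"depth = ${model.depth}, numNodes = ${model.numNodes}")
println(model.toDebugString)  // human-readable dump of the learned splits
val prediction = model.predict(Vectors.dense(0.3, 1.0, 0.0))  // score a single example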

As we can see, the actual training is ultimately carried out by RandomForest: in Spark MLlib a decision tree is implemented as a special case of a random forest containing a single tree (numTrees = 1). That is why the model is extracted with rfModel.trees(0), i.e. the tree at index 0, the only tree in the forest. Next, open the RandomForest source file; its path is
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
Since DecisionTree.run invokes RandomForest's run method, locate run and examine it:

def run(input: RDD[LabeledPoint]): RandomForestModel = {

  val timer = new TimeTracker()

  timer.start("total")

  timer.start("init")

  val retaggedInput = input.retag(classOf[LabeledPoint])
  val metadata =
    DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
  logDebug("algo = " + strategy.algo)
  logDebug("numTrees = " + numTrees)
  logDebug("seed = " + seed)
  logDebug("maxBins = " + metadata.maxBins)
  logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
  logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
  logDebug("subsamplingRate = " + strategy.subsamplingRate)

  // Find the splits and the corresponding bins (interval between the splits) using a sample
  // of the input data.
  timer.start("findSplitsBins")
  val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
  timer.stop("findSplitsBins")
  logDebug("numBins: feature: number of bins")
  logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
    s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
  }.mkString("\n"))

  // Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

  val withReplacement = if (numTrees > 1) true else false

  val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
    strategy.subsamplingRate, numTrees,
    withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

  // depth of the decision tree
  val maxDepth = strategy.maxDepth
  require(maxDepth <= 30,
    s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

  // Max memory usage for aggregates
  // TODO: Calculate memory usage more precisely.
  val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
  logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
  val maxMemoryPerNode = {
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
      Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
        .take(metadata.numFeaturesPerNode).map(_._2))
    } else {
      None
    }
    RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
  }
  require(maxMemoryPerNode <= maxMemoryUsage,
    s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
    " which is too small for the given features." +
    s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")

  timer.stop("init")

  /*
   * The main idea here is to perform group-wise training of the decision tree nodes thus
   * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
   * Each data sample is handled by a particular node (or it reaches a leaf and is not used
   * in lower levels).
   */

  // Create an RDD of node Id cache.
  // At first, all the rows belong to the root nodes (node Id == 1).
  val nodeIdCache = if (strategy.useNodeIdCache) {
    Some(NodeIdCache.init(
      data = baggedInput,
      numTrees = numTrees,
      checkpointInterval = strategy.checkpointInterval,
      initVal = 1))
  } else {
    None
  }

  // FIFO queue of nodes to train: (treeIndex, node)
  val nodeQueue = new mutable.Queue[(Int, Node)]()

  val rng = new scala.util.Random()
  rng.setSeed(seed)

  // Allocate and queue root nodes.
  val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
  Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

  while (nodeQueue.nonEmpty) {
    // Collect some nodes to split, and choose features for each node (if subsampling).
    // Each group of nodes may come from one or multiple trees, and at multiple levels.
    val (nodesForGroup, treeToNodeToIndexInfo) =
      RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
    // Sanity check (should never occur):
    assert(nodesForGroup.size > 0,
      s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

    // Choose node splits, and enqueue new nodes as needed.
    timer.start("findBestSplits")
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    timer.stop("findBestSplits")
  }

  baggedInput.unpersist()

  timer.stop("total")

  logInfo("Internal timing for DecisionTree:")
  logInfo(s"$timer")

  // Delete any remaining checkpoints used for node Id cache.
  if (nodeIdCache.nonEmpty) {
    try {
      nodeIdCache.get.deleteAllCheckpoints()
    } catch {
      case e: IOException =>
        logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
    }
  }

  val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
  new RandomForestModel(strategy.algo, trees)
}
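
Before moving on, here is a condensed sketch that ties the two files together, i.e. what actually happens when a classification tree is trained. The parameter values are illustrative, and the comments summarize the call chain traced above:

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical: `trainingData` is an RDD[LabeledPoint] prepared elsewhere.
// 1. trainClassifier / train bundle all parameters into a Strategy.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 5,
  numClasses = 2,
  maxBins = 32,
  quantileCalculationStrategy = Sort,
  categoricalFeaturesInfo = Map[Int, Int]())

// 2. DecisionTree.run hands the work to a RandomForest with numTrees = 1 and
//    featureSubsetStrategy = "all" (every feature is considered at every node).
// 3. RandomForest.run builds the metadata, finds the splits/bins, converts the input to the
//    TreePoint / BaggedPoint representation, and grows the tree group-wise via findBestSplits.
// 4. The single trained tree is then unwrapped from the forest with rfModel.trees(0).
val model = new DecisionTree(strategy).run(trainingData)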

As we can see, the first step is DecisionTreeMetadata.buildMetadata, which preprocesses the input data, so the natural next question is what buildMetadata actually does. That will be analyzed in detail in the next article.
