A Brief Look at Decision Trees in Spark





[Figure 1]




Figure 2: Schematic of a decision tree









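Take the table above as an example. The outcome has two classes: (1) meet and (2) do not meet. The proportion of records in class $k$ is $p_k$ ($k = 1, 2, \dots, K$), where $K$ is the number of classes. In this example $K = 2$, so $k$ takes the values 1 and 2.

Here $p_1$ is the proportion of "meet" records and $p_2$ is the proportion of "do not meet" records. The information entropy is $\mathrm{Ent}(D) = -\sum_{k=1}^{K} p_k \log_2 p_k$. Entropy measures the purity of this probability space; put plainly, it tells us how decisive the outcomes are (for example, between owning and not owning a home, we can almost say that owning a home always leads to a meeting). Anyone who has studied information theory knows that entropy is a concave function of the probability distribution; I will not plot it here, only describe it briefly. Entropy reaches its maximum when the distribution is uniform, and it is 0 when every probability is either 0 or 1. In the two-class case it is 0 at both ends and 1 in the middle (at $p_1 = p_2 = 0.5$), which is exactly the concave shape. The larger the entropy, the greater the uncertainty; the smaller the entropy, the smaller the uncertainty and the purer the set.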
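To make the behavior of entropy concrete, here is a minimal Java sketch. The class name `EntropyDemo` and the sample distributions are my own illustrations, not part of Spark; the only convention assumed is the usual one that a term with $p_k = 0$ contributes 0.

```java
class EntropyDemo {
    /** Shannon entropy in bits; a term with p = 0 contributes 0 by convention. */
    public static double entropy(double[] p) {
        double h = 0.0;
        for (double pk : p) {
            if (pk > 0.0) {
                h -= pk * (Math.log(pk) / Math.log(2.0));
            }
        }
        return h;
    }

    public static void main(String[] args) {
        // Uniform two-class distribution: the maximum for K = 2.
        System.out.println(entropy(new double[] {0.5, 0.5}));
        // Pure distribution: no uncertainty at all.
        System.out.println(entropy(new double[] {1.0, 0.0}));
        // A skewed distribution lies somewhere in between.
        System.out.println(entropy(new double[] {0.9, 0.1}));
    }
}
```

Running it should show the largest value for the uniform case and 0 for the pure case, matching the description above.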


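Information gain is always defined with respect to one attribute. Take attribute 1, whether the person owns property in the city, as an example. For an attribute $a$ with $V$ possible values, the gain is $\mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} \mathrm{Ent}(D^v)$.

The property attribute has two values, yes and no, and each value selects a subset $D^v$ of all the records. Each of the two subsets contains 5 records, and every record is labeled either "meet" or "do not meet". First compute $\mathrm{Ent}(D^v)$ for each subset, then compute the information gain.

$\mathrm{Ent}(D^1) = 0$, $\mathrm{Ent}(D^2) = \dots$

$|D^1|/|D| = 0.5$ and $|D^2|/|D| = 0.5$. With these values the gain can be computed. Compute the information gain of every other attribute in the same way, pick the attribute with the largest gain as the branching point, and repeat the process on each branch to grow the decision tree.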
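The calculation can be sketched in a few lines of Java. The class `InfoGainDemo` is illustrative only, and the class counts in `main` are hypothetical: they merely respect the facts stated in the text (two subsets of 5 records, the first one pure), since the original table is not reproduced here.

```java
class InfoGainDemo {
    private static double log2(double x) {
        return Math.log(x) / Math.log(2.0);
    }

    /** Entropy in bits computed from per-class record counts. */
    public static double entropy(int[] counts) {
        int total = 0;
        for (int c : counts) total += c;
        if (total == 0) return 0.0;
        double h = 0.0;
        for (int c : counts) {
            if (c > 0) {
                double p = (double) c / total;
                h -= p * log2(p);
            }
        }
        return h;
    }

    /**
     * Gain(D, a) = Ent(D) - sum over v of |D^v|/|D| * Ent(D^v).
     * subsets[v][k] is the number of class-k records whose attribute takes value v.
     */
    public static double informationGain(int[][] subsets) {
        int numClasses = subsets[0].length;
        int[] parent = new int[numClasses];
        int total = 0;
        for (int[] s : subsets) {
            for (int k = 0; k < numClasses; k++) {
                parent[k] += s[k];
                total += s[k];
            }
        }
        double gain = entropy(parent);
        for (int[] s : subsets) {
            int size = 0;
            for (int c : s) size += c;
            gain -= (double) size / total * entropy(s);
        }
        return gain;
    }

    public static void main(String[] args) {
        // Hypothetical counts {meet, do not meet}: owners are pure, non-owners are mixed.
        int[][] byProperty = { {5, 0}, {2, 3} };
        System.out.println("Gain = " + informationGain(byProperty));
    }
}
```

An attribute that separates the classes perfectly gets the full parent entropy as its gain, while an attribute whose subsets have the same class mix as the parent gets a gain of 0.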


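Now let's focus on how continuous features are handled. What is a continuous feature? First consider a discrete one: whether someone owns property has only two values, yes or no, so it is a discrete (categorical) feature. Now take age: 12, 13, 16, 19, 30, 32, 45, 21, 78, 90, 50. At first glance you might object that these are not continuous at all; the values are isolated points, so aren't they discrete? You are right: from a signal-processing point of view this is discrete data. But if we treated it that way, every age would be its own category, and age alone would be enough to determine the final result. Is that a good idea? Clearly it is one-sided, so we need a way to turn the values into a small number of discrete intervals. In theory we could use every value as a split point, putting the records below it in one group and the records above it in the other. That is fine for a small data set, but it is clearly impractical for millions or hundreds of millions of records. Spark uses a sampling strategy instead.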




12, 14 ,16 ,11, 43, 32, 45, 56, 54, 89, 76
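As a rough illustration of the idea, the sketch below sorts a sample of values (the numbers listed above) and picks a handful of thresholds at evenly spaced ranks, so that each interval holds roughly the same number of samples. This is a simplified equal-frequency scheme of my own, not Spark's exact `findSplitsForContinuousFeature` logic; the class name and the choice of `maxBins = 4` are assumptions made for the example.

```java
import java.util.Arrays;

class QuantileSplitsDemo {
    /**
     * Returns at most maxBins - 1 ascending thresholds taken at evenly spaced
     * ranks of the sorted sample. Duplicate thresholds are dropped.
     */
    public static double[] candidateThresholds(double[] sample, int maxBins) {
        double[] sorted = sample.clone();
        Arrays.sort(sorted);
        int n = sorted.length;
        int numSplits = Math.min(maxBins - 1, n - 1);
        double[] out = new double[Math.max(numSplits, 0)];
        int count = 0;
        for (int i = 1; i <= numSplits; i++) {
            // Rank of the i-th cut when the sample is divided into numSplits + 1 parts.
            int idx = (int) ((long) i * n / (numSplits + 1)) - 1;
            double t = sorted[idx];
            if (count == 0 || t > out[count - 1]) {
                out[count++] = t;
            }
        }
        return Arrays.copyOf(out, count);
    }

    public static void main(String[] args) {
        double[] ages = {12, 14, 16, 11, 43, 32, 45, 56, 54, 89, 76};
        // With 4 bins we get 3 thresholds; a value v falls left of a threshold t when v <= t.
        System.out.println(Arrays.toString(candidateThresholds(ages, 4)));
    }
}
```

With thresholds in hand, a continuous feature is no harder to split on than an ordered categorical one: only the thresholds need to be evaluated, not every distinct value.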




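import java.io.IOException;
import java.util.HashMap;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import scala.Tuple2;

public static void medicalRandomForest(JavaRDD<LabeledPoint> train, JavaRDD<LabeledPoint> test) throws IOException {
		int numClasses = 2;
		// An empty map means every feature is treated as continuous
		HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
		int numTrees = 7;
		String featureSubStrategy = "auto"; // number of features considered at each node ("auto" decides based on numTrees)
		String impurity = "gini"; // impurity measure used when computing information gain
		int maxDepth = 6;
		int maxBins = 25; // maximum number of bins used to discretize a feature
		int seed = 12345; // random seed for choosing feature subsets
//		File dtResult = new File("rfMaxBins.txt");
//		FileWriter dtOut = new FileWriter(dtResult);
//		for(maxBins=3; maxBins<=100; maxBins++){
		final RandomForestModel rfModel = RandomForest.trainClassifier(train, numClasses,
				  categoricalFeaturesInfo, numTrees, featureSubStrategy, impurity, maxDepth, maxBins,
				  seed);
		// Pair each prediction with the true label
		JavaRDD<Tuple2<Object, Object>> scoresAndLabelTrain = train.map(line -> {
			double score = rfModel.predict(line.features());
			return new Tuple2<Object, Object>(score, line.label());
		});
		JavaRDD<Tuple2<Object, Object>> scoresAndLabelTest = test.map(line -> {
			double score = rfModel.predict(line.features());
			return new Tuple2<Object, Object>(score, line.label());
		});

		// Accuracy = 1 - (misclassified count / total count)
		double trainPre = 1.0 - (1.0 * scoresAndLabelTrain.filter(p1 -> {
			return !p1._1().equals(p1._2());
		}).count() / train.count());
		double testPre = 1.0 - (1.0 * scoresAndLabelTest.filter(p1 -> {
			return !p1._1().equals(p1._2());
		}).count() / test.count());
//		dtOut.write(maxBins+" "+trainPre+" "+testPre);
//		dtOut.write("\r\n");
		System.out.println("trainPre = "+ trainPre);
		System.out.println("testPre = " + testPre);
		BinaryClassificationMetrics metricTrain = new BinaryClassificationMetrics(scoresAndLabelTrain.rdd());
		System.out.println("Training set Area under ROC = "+ metricTrain.areaUnderROC());
		BinaryClassificationMetrics metricTest = new BinaryClassificationMetrics(scoresAndLabelTest.rdd());
		System.out.println("Test set Area under ROC = "+ metricTest.areaUnderROC());
		System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
		System.out.println("maxBins = " +maxBins + "*****random forest finished*****");
//		}
//		dtOut.close();
}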

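The key classes for random forests are org.apache.spark.mllib.tree.RandomForest and org.apache.spark.mllib.tree.model.RandomForestModel, which provide the trainClassifier and predict functions respectively.

As the demo above shows, the random forest is trained with the trainClassifier method of the RandomForest companion object. Its source code is as follows: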



 def trainClassifier(
 input: RDD[LabeledPoint],
 numClasses: Int,
 categoricalFeaturesInfo: Map[Int, Int],
 numTrees: Int,
 featureSubsetStrategy: String,
 impurity: String,
 maxDepth: Int,
 maxBins: Int,
 seed: Int = Utils.random.nextInt()): RandomForestModel = {
 val impurityType = Impurities.fromString(impurity)
 val strategy = new Strategy(Classification, impurityType, maxDepth,
 numClasses, maxBins, Sort, categoricalFeaturesInfo)
 // Delegates to the other overloaded trainClassifier
 trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed)
 }

The overloaded trainClassifier method is as follows:


 def trainClassifier(
 input: RDD[LabeledPoint],
 strategy: Strategy,
 numTrees: Int,
 featureSubsetStrategy: String,
 seed: Int): RandomForestModel = {
 require(strategy.algo == Classification,
 s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
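 // Create the RandomForest object in this method
 val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
 // Then call its run method, which takes an RDD[LabeledPoint] and returns a RandomForestModel instance
 rf.run(input)
 }

Stepping into the run method of RandomForest, the code is as follows: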


def run(input: RDD[LabeledPoint]): RandomForestModel = {

 val timer = new TimeTracker()


val retaggedInput = input.retag(classOf[LabeledPoint])
 val metadata =
 DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
 logDebug("algo = " + strategy.algo)
 logDebug("numTrees = " + numTrees)
 logDebug("seed = " + seed)
 logDebug("maxBins = " + metadata.maxBins)
 logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
 logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
 logDebug("subsamplingRate = " + strategy.subsamplingRate)

 // Find the splits and the corresponding bins (interval between the splits) using a sample
 // of the input data.
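 // For a categorical feature with M categories: if it is unordered, there are at most 2^(M-1) - 1 splits;
 // if it is ordered, there are at most M - 1 splits (numBins - 1)
 val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
 logDebug("numBins: feature: number of bins")
 logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
     s"\t$featureIndex\t${metadata.numBins(featureIndex)}"
   }.mkString("\n"))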

 // Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
 // Convert to a tree-oriented RDD; after the conversion every sample has been assigned to its bin according to the split thresholds
 val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)

 val withReplacement = if (numTrees > 1) true else false

 // convertToBaggedRDD gives each tree its own subsample of the data
 val baggedInput
 = BaggedPoint.convertToBaggedRDD(treeInput,
 strategy.subsamplingRate, numTrees,
 withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)

 // depth of the decision tree
 val maxDepth = strategy.maxDepth
 require(maxDepth <= 30,
 s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")

 // Max memory usage for aggregates
 // TODO: Calculate memory usage more precisely.
 val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
 val maxMemoryPerNode = {
   val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
     // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
     Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
       .take(metadata.numFeaturesPerNode).map(_._2))
   } else {
     None
   }
   RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
 }
 require(maxMemoryPerNode <= maxMemoryUsage,
 s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," +
 " which is too small for the given features." +
 s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}")


 /*
  * The main idea here is to perform group-wise training of the decision tree nodes thus
  * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup).
  * Each data sample is handled by a particular node (or it reaches a leaf and is not used
  * in lower levels).
  */

 // Create an RDD of node Id cache.
 // At first, all the rows belong to the root nodes (node Id == 1).
 // Optional node Id cache. Node Ids start at 1: 1 is the root of a tree, its left child is 2, its right child is 3, and so on
 val nodeIdCache = if (strategy.useNodeIdCache) {
   Some(NodeIdCache.init(
     data = baggedInput,
     numTrees = numTrees,
     checkpointInterval = strategy.checkpointInterval,
     initVal = 1))
 } else {
   None
 }
 // FIFO queue of nodes to train: (treeIndex, node)
 val nodeQueue = new mutable.Queue[(Int, Node)]()

 val rng = new scala.util.Random()

// Allocate and queue root nodes.
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
 // Enqueue (tree index, root node of that tree); tree indices start at 0 and root node indices at 1
 Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))

 while (nodeQueue.nonEmpty) {
 // Collect some nodes to split, and choose features for each node (if subsampling).
 // Each group of nodes may come from one or multiple trees, and at multiple levels.
 // Collect the nodes that need to be split, grouped by tree
 val (nodesForGroup, treeToNodeToIndexInfo) =
 RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
 // Sanity check (should never occur):
 assert(nodesForGroup.size > 0,
 s"RandomForest selected empty nodesForGroup. Error for unknown reason.")

 // Choose node splits, and enqueue new nodes as needed.
 DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
   treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
 }



 logInfo("Internal timing for DecisionTree:")

 // Delete any remaining checkpoints used for node Id cache.
 if (nodeIdCache.nonEmpty) {
   try {
     nodeIdCache.get.deleteAllCheckpoints()
   } catch {
     case e: IOException =>
       logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
   }
 }

 val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
 new RandomForestModel(strategy.algo, trees)
}

The code above is run, the core method of the RandomForest class. To determine the split points and the bin information it calls DecisionTree.findSplitsBins. Stepping into that method, we see the following code:


 /**
  * Returns splits and bins for decision tree calculation.
  * Continuous and categorical features are handled differently.
  * Continuous features:
  * For each feature, there are numBins - 1 possible splits representing the possible binary
  * decisions at each node in the tree.
  * This finds locations (feature values) for splits using a subsample of the data.
  * Categorical features:
  * For each feature, there is 1 bin per split.
  * Splits and bins are handled in 2 ways:
  * (a) "unordered features"
  * For multiclass classification with a low-arity feature
  * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
  * the feature is split based on subsets of categories.
  * (b) "ordered features"
  * For regression and binary classification,
  * and for multiclass classification with a high-arity feature,
  * there is one bin per category.
  * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
  * @param metadata Learning and dataset metadata
  * @return A tuple of (splits, bins).
  * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
  * of size (numFeatures, numSplits).
  * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]]
  * of size (numFeatures, numBins).
  */
 protected[tree] def findSplitsBins(
 input: RDD[LabeledPoint],
 metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = {

 logDebug("isMulticlass = " + metadata.isMulticlass)

 val numFeatures = metadata.numFeatures

 // Sample the input only if there are continuous features.
 // Check whether any of the features is continuous
 val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
 val sampledInput = if (hasContinuousFeatures) {
   // Calculate the number of samples for approximate quantile calculation.
   // Number of samples to draw: at least 10000
   val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
   val fraction = if (requiredSamples < metadata.numExamples) {
     requiredSamples.toDouble / metadata.numExamples
   } else {
     1.0
   }
   logDebug("fraction of data used for calculating quantiles = " + fraction)
   input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
 } else {
   new Array[LabeledPoint](0)
 }

 // Quantile strategy: Spark currently implements only one, Sort
 metadata.quantileStrategy match {
 case Sort =>
 val splits = new Array[Array[Split]](numFeatures)
 val bins = new Array[Array[Bin]](numFeatures)

 // Find all splits.
 // Iterate over all features.
 var featureIndex = 0
 while (featureIndex < numFeatures) {
 if (metadata.isContinuous(featureIndex)) {
 val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
 // findSplitsForContinuousFeature returns all split thresholds for a continuous feature
 val featureSplits = findSplitsForContinuousFeature(featureSamples,
 metadata, featureIndex)

 val numSplits = featureSplits.length
 val numBins = numSplits + 1
 logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")

 splits(featureIndex) = new Array[Split](numSplits)
 bins(featureIndex) = new Array[Bin](numBins)

 var splitIndex = 0
 while (splitIndex < numSplits) {
   val threshold = featureSplits(splitIndex)
   splits(featureIndex)(splitIndex) =
     new Split(featureIndex, threshold, Continuous, List())
   splitIndex += 1
 }
 // The first bin uses a dummy low split (threshold Double.MinValue) as its left boundary
 bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
   splits(featureIndex)(0), Continuous, Double.MinValue)

 splitIndex = 1
 while (splitIndex < numSplits) {
   bins(featureIndex)(splitIndex) =
     new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
       Continuous, Double.MinValue)
   splitIndex += 1
 }
 // The last bin uses a dummy high split (threshold Double.MaxValue) as its right boundary
 bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
   new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
 } else { // the feature is categorical
 val numSplits = metadata.numSplits(featureIndex)
 val numBins = metadata.numBins(featureIndex)
 // Categorical feature
 val featureArity = metadata.featureArity(featureIndex)
 if (metadata.isUnordered(featureIndex)) {
 // Unordered features
 // 2^(maxFeatureValue - 1) - 1 combinations
 splits(featureIndex) = new Array[Split](numSplits)
 var splitIndex = 0
 while (splitIndex < numSplits) {
 val categories: List[Double] =
 extractMultiClassCategories(splitIndex + 1, featureArity)
 splits(featureIndex)(splitIndex) =
 new Split(featureIndex, Double.MinValue, Categorical, categories)
     splitIndex += 1
   }
 } else {
   // Ordered features
   // Bins correspond to feature values, so we do not need to compute splits or bins
   // beforehand. Splits are constructed as needed during training.
   splits(featureIndex) = new Array[Split](0)
 }
 // For ordered features, bins correspond to feature values.
 // For unordered categorical features, there is no need to construct the bins.
 // since there is a one-to-one correspondence between the splits and the bins.
 bins(featureIndex) = new Array[Bin](0)
 }
 featureIndex += 1
 }
 (splits, bins)
 case MinMax =>
   throw new UnsupportedOperationException("minmax not supported yet.")
 case ApproxHist =>
   throw new UnsupportedOperationException("approximate histogram not supported yet.")
 }
 }
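
Two small pieces of the logic above are easy to try in isolation: the sampling fraction, and the lookup that places a continuous value into a bin once the thresholds are known. The Java sketch below is a standalone approximation; `SamplingAndBinningDemo` and its method names are hypothetical, and the linear scan stands in for whatever search Spark actually uses. The interval convention assumed here is that bin i holds values greater than the previous threshold and less than or equal to threshold i.

```java
class SamplingAndBinningDemo {
    /**
     * Fraction of rows to sample for the quantile estimate:
     * max(maxBins^2, 10000) rows, or everything if the data set is smaller.
     */
    public static double sampleFraction(int maxBins, long numExamples) {
        long required = Math.max((long) maxBins * maxBins, 10000L);
        return required < numExamples ? (double) required / numExamples : 1.0;
    }

    /**
     * Index of the bin that contains value, given ascending thresholds.
     * Bin 0 is everything up to thresholds[0]; the last bin is everything
     * above the final threshold.
     */
    public static int binIndex(double[] thresholds, double value) {
        int i = 0;
        while (i < thresholds.length && value > thresholds[i]) {
            i++;
        }
        return i;
    }

    public static void main(String[] args) {
        // One million rows with maxBins = 25: only 10000 rows are needed.
        System.out.println(sampleFraction(25, 1_000_000L));
        double[] thresholds = {12.0, 32.0, 54.0}; // illustrative thresholds
        System.out.println(binIndex(thresholds, 40.0));
    }
}
```

The point of the fraction is that the cost of finding thresholds stays bounded no matter how large the training set grows.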
Besides findSplitsBins, there is another very important method, DecisionTree.findBestSplits(), which searches for the best split. The key step inside it is the call to binsToBestSplit, whose code is as follows:


 /**
  * Find the best split for a node.
  * @param binAggregates Bin statistics.
  * @return tuple for best split: (Split, information gain, prediction at node)
  */
 private def binsToBestSplit(
 binAggregates: DTStatsAggregator, // DTStatsAggregator references an ImpurityAggregator, which holds the logic for computing impurity
 splits: Array[Array[Split]],
 featuresForNode: Option[Array[Int]],
 node: Node): (Split, InformationGainStats, Predict) = {

 // calculate predict and impurity if current node is top node
 val level = Node.indexToLevel(node.id)
 var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
   None
 } else {
   Some((node.predict, node.impurity))
 }

 // For each (feature, split), calculate the gain, and select the best (feature, split).
 val (bestSplit, bestSplitStats) =
   Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
     val featureIndex = if (featuresForNode.nonEmpty) {
       featuresForNode.get.apply(featureIndexIdx)
     } else {
       featureIndexIdx
     }
     val numSplits = binAggregates.metadata.numSplits(featureIndex)
 if (binAggregates.metadata.isContinuous(featureIndex)) {
 // Cumulative sum (scanLeft) of bin statistics.
 // Afterwards, binAggregates for a bin is the sum of aggregates for
 // that bin + all preceding bins.
 val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
 var splitIndex = 0
 while (splitIndex < numSplits) {
   binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
   splitIndex += 1
 }
 // Find best split.
 val (bestFeatureSplitIndex, bestFeatureGainStats) =
   Range(0, numSplits).map { case splitIdx =>
     // Impurity statistics of the left and right children
     val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
     val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
     rightChildStats.subtract(leftChildStats)
     // Compute the node's prediction and impurity once, if they are not known yet
     predictWithImpurity = Some(predictWithImpurity.getOrElse(
       calculatePredictImpurity(leftChildStats, rightChildStats)))
     // Information gain of this split, used to decide whether it is the best one
     val gainStats = calculateGainForSplit(leftChildStats,
       rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
     (splitIdx, gainStats)
   }.maxBy(_._2.gain)
 (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
 } else if (binAggregates.metadata.isUnordered(featureIndex)) { // unordered categorical feature
   // Unordered categorical feature
   val (leftChildOffset, rightChildOffset) =
     binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
   val (bestFeatureSplitIndex, bestFeatureGainStats) =
     Range(0, numSplits).map { splitIndex =>
       val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
       val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
       predictWithImpurity = Some(predictWithImpurity.getOrElse(
         calculatePredictImpurity(leftChildStats, rightChildStats)))
       val gainStats = calculateGainForSplit(leftChildStats,
         rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
       (splitIndex, gainStats)
     }.maxBy(_._2.gain)
   (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
 } else { // ordered categorical feature
 // Ordered categorical feature
 val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
 val numBins = binAggregates.metadata.numBins(featureIndex)

 /* Each bin is one category (feature value).
  * The bins are ordered based on centroidForCategories, and this ordering determines which
  * splits are considered. (With K categories, we consider K - 1 possible splits.)
  * centroidForCategories is a list: (category, centroid)
  */
 val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
   // For categorical variables in multiclass classification,
   // the bins are ordered by the impurity of their corresponding labels.
   Range(0, numBins).map { case featureValue =>
     val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
     val centroid = if (categoryStats.count != 0) {
       // calculate() returns the impurity of the labels in this bin
       categoryStats.calculate()
     } else {
       Double.MaxValue
     }
     (featureValue, centroid)
   }
 } else { // regression or binary classification
   // For categorical variables in regression and binary classification,
   // the bins are ordered by the centroid of their corresponding labels.
   Range(0, numBins).map { case featureValue =>
     val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
     val centroid = if (categoryStats.count != 0) {
       // predict returns the mean label of this bin, which serves as its centroid
       categoryStats.predict
     } else {
       Double.MaxValue
     }
     (featureValue, centroid)
   }
 }

 logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(","))

 // bins sorted by centroids
 val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)

 logDebug("Sorted centroids for categorical variable = " +
   categoriesSortedByCentroid.mkString(","))

 // Cumulative sum (scanLeft) of bin statistics.
 // Afterwards, binAggregates for a bin is the sum of aggregates for
 // that bin + all preceding bins.
 var splitIndex = 0
 while (splitIndex < numSplits) {
   val currentCategory = categoriesSortedByCentroid(splitIndex)._1
   val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
   binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
   splitIndex += 1
 }
 // lastCategory = index of bin with total aggregates for this (node, feature)
 val lastCategory = categoriesSortedByCentroid.last._1
 // Find best split.
 val (bestFeatureSplitIndex, bestFeatureGainStats) =
   Range(0, numSplits).map { splitIndex =>
     val featureValue = categoriesSortedByCentroid(splitIndex)._1
     val leftChildStats =
       binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
     val rightChildStats =
       binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
     rightChildStats.subtract(leftChildStats)
     predictWithImpurity = Some(predictWithImpurity.getOrElse(
       calculatePredictImpurity(leftChildStats, rightChildStats)))
     val gainStats = calculateGainForSplit(leftChildStats,
       rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
     (splitIndex, gainStats)
   }.maxBy(_._2.gain)
 val categoriesForSplit =
   categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
 val bestFeatureSplit =
   new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
 (bestFeatureSplit, bestFeatureGainStats)
 }
 }.maxBy(_._2.gain)

 (bestSplit, bestSplitStats, predictWithImpurity.get._1)
 }
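
The continuous-feature branch above boils down to a simple pattern: accumulate the per-bin class counts from left to right, obtain the right child by subtracting the left child from the total, and keep the split with the largest gain. Below is a minimal standalone Java sketch of that pattern using Gini impurity. It is not Spark code: `BestSplitDemo`, its methods, and the bin counts in `main` are all made up for illustration.

```java
class BestSplitDemo {
    /** Gini impurity 1 - sum of p_k^2, computed from per-class counts. */
    public static double gini(long[] counts) {
        long total = 0;
        for (long c : counts) total += c;
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (long c : counts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    /**
     * binCounts[b][k] is the number of class-k rows in bin b, with bins ordered
     * by feature value. Returns the index s of the best split, meaning bins
     * 0..s go left and the rest go right, or -1 if no split can be evaluated.
     */
    public static int bestSplit(long[][] binCounts) {
        int numBins = binCounts.length;
        if (numBins < 2) return -1;
        int numClasses = binCounts[0].length;
        long[] total = new long[numClasses];
        long totalCount = 0;
        for (long[] bin : binCounts) {
            for (int k = 0; k < numClasses; k++) {
                total[k] += bin[k];
                totalCount += bin[k];
            }
        }
        double parentImpurity = gini(total);
        long[] left = new long[numClasses]; // running cumulative sum over bins
        int best = -1;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (int s = 0; s < numBins - 1; s++) {
            long[] right = new long[numClasses];
            long leftCount = 0;
            long rightCount = 0;
            for (int k = 0; k < numClasses; k++) {
                left[k] += binCounts[s][k];
                right[k] = total[k] - left[k]; // right child = total minus left child
                leftCount += left[k];
                rightCount += right[k];
            }
            double gain = parentImpurity
                - (double) leftCount / totalCount * gini(left)
                - (double) rightCount / totalCount * gini(right);
            if (gain > bestGain) {
                bestGain = gain;
                best = s;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical counts {class 0, class 1} for four bins of one feature.
        long[][] bins = { {5, 0}, {4, 1}, {0, 5}, {0, 6} };
        System.out.println("best split after bin " + bestSplit(bins));
    }
}
```

Because only bin statistics are touched, the cost of evaluating a node depends on the number of bins rather than the number of rows, which is why maxBins matters so much in the demo at the top.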


