Spark Mlib- Decision Tree

Spark Mlib- Decision Tree
Q:决策树是什么?
A:决策树是模拟人类决策过程,将判断一件事情所要做的一系列决策的各种可能的集合,以数的形式展现出来,的一中树形图。
Q:决策树的结构是怎样的?
A:决策树与普通树一样,由节点和边组成。树中每一个节点都是一个属性(特征),或者说是对特征的判断。根据一个节点的判断结果,决策(预测)流程走向不同的子节点,或者直接到达叶节点,即决策(预测)结束,得到结果。
Q:决策树是怎么训练出来的?
A:典型的决策树的训练过程如下,以根据色泽、根蒂、敲声预测一个西瓜是否好瓜为例——

  1. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    1、输入一个数据集D2、生成一个节点node3、如果数据集中的样本全部属于同一类,比如西瓜样本全部是“好瓜”,那么node就是叶子节点(好瓜)4、如果样本中的数据集不属于同一类,比如西瓜中既有好瓜也有坏瓜,那就选择一个属性,把西瓜根据选好的属性分类。比如按照“纹理”属性,把西瓜分为清晰、模糊、稍糊三类。
  2. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    5、把第4步得到的几个数据子集作为输入数据,分别执行上面的第1到5步、直到不再执行第4步为止(也就是叶子节点全部构建完成,算法结束)。
    Q:整个决策树的训练算法很简洁,但是第4步的“选择一个属性,把西瓜根据选好的属性分类”,怎样来选择合适的属性呢?
    A:假设我们现在只有两种属性选择,一种是“色泽”、另一种种是“触感”。选择“色泽”,我们可以把西瓜分成三类——“浅白”全是坏瓜;“墨绿”全是好瓜;“青绿”85%是好瓜,15%是坏瓜。选择“触感”,我们可以把西瓜分为两类——“硬滑”一半是好瓜,一半是坏瓜;“软粘”40%是好瓜、60%是坏瓜。
    稍一思考,我们自然会选“色泽”作为本次的属性。因为“色泽”可以一下把很多好瓜和坏瓜区分开来,也就是说我们知道了一个西瓜的色泽后有很大几率正确判断它是好瓜还是坏瓜。而目前来说,知道“触感”却对我们的判断没什么帮助。因此,我们选择“能给我们带来更多信息”的属性,或者说能够“减少混淆程度”的属性。
    假设Pk是指当前集合中第k类样本占的比例。比如10个西瓜中纹理清晰、模糊、稍糊的各有3,3,4个,那么p1=3/10,p2=3/10,p3=4/10。于是根据信息的定义,我们有信息熵
  3. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    衡量那一个属性“能给我们带来更多信息”,我们用“信息增益”这个指标:
  4. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    信息增益越大,越能带来信息。因此每一次在上述算法第4步选择属性作为节点时,我们对待选属性都做一次信息增益的计算,选择信息增益最大的属性。
    Q:有哪些对上述决策树算法改进的方案呢?
    A:有两种思路——1、改进算法第4步中选择属性用的指标,比如用“增益率”或者“基尼系数”来代替“信息增益”。2、用“剪枝处理”,也就是去掉一些无用的分枝来降低“过拟合”的风险。
    Q:什么是增益率?什么是基尼系数?
    A:信息增益对于选项多的属性有偏好。如果我们把训练样本中每个西瓜编号,然后把编号也作为待选择属性,那么编号肯定能带来最多的信息,最大程度降低混淆程度,所以编号这个属性肯定会被选中。但是这然并卵,因为这样训练出来的算法对于新的样本根本不起作用。所以我们不会选择编号作为属性。类似的,如果有些属性的可选项特别多,比如色泽现在有浅白、白、青绿、绿、墨绿五个可选项,那么色泽被选中的几率比其他属性要大。所以可以考虑用增益率代替信息增益:
  5. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    基尼系数则是另一项可以考虑的指标:
  6. {failImgCache = [];}if(failImgCache.indexOf(src) == -1 && src.trim().length){failImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-error').removeClass('md-img-loaded'); " onload="var src = window.removeLastModifyQuery(this.getAttribute('src'));if(!src.trim()) return;if(loadedImgCache.indexOf(src) == -1 && src.trim().length){loadedImgCache.push(src);}$(this).closest('.md-image').addClass('md-img-loaded').removeClass('md-img-error');" style="box-sizing: border-box; border-width: 0px 4px 0px 2px; border-right-style: solid; border-left-style: solid; border-right-color: transparent; border-left-color: transparent; vertical-align: middle; max-width: 100%; cursor: default;">
    基尼系数表示从一堆样本中随机抽取两个样本,这 两个样本不同类的概率,也就刚是一个是好瓜一个是坏瓜的概率。这样,基尼系数越低,越意味着样本集中某一类比另一类样本多,也就是说,混淆程度越低。若选择某个属性后各个划分子集的基尼系数最低,那么就选择这个属性。
    Q:剪枝处理的过程是怎样进行的?
    A:剪枝分为预剪枝和后剪枝。预剪枝过程就是杂生成决策树时,用一些新的样本去测试刚刚训练好的节点,如果这个节点的存在并不能让决策树的泛化性能提高,也就是说分类精度或者其他衡量指标提高了,就把这个刚训练出来的节点舍弃。后剪枝过程是用一些新的样本来测试这课刚刚训练出来的决策树,从下往上开始测试,如果某个节点被换成叶节点后整一棵决策树的泛化性能提高了,,也就是说分类精度或者其他衡量指标提高了那么就直接替换掉。
    实现

import org.apache.spark.ml.Pipelineimport org.apache.spark.ml.classification.DecisionTreeClassificationModelimport org.apache.spark.ml.classification.DecisionTreeClassifierimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimport org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}​// Load the data stored in LIBSVM format as a DataFrame.val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")​// Index labels, adding metadata to the label column.// Fit on whole dataset to include all labels in index.val labelIndexer = new StringIndexer() .setInputCol("label") .setOutputCol("indexedLabel") .fit(data)// Automatically identify categorical features, and index them.val featureIndexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexedFeatures") .setMaxCategories(4) // features with > 4 distinct values are treated as continuous. .fit(data)​// Split the data into training and test sets (30% held out for testing).val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))​// Train a DecisionTree model.val dt = new DecisionTreeClassifier() .setLabelCol("indexedLabel") .setFeaturesCol("indexedFeatures")​// Convert indexed labels back to original labels.val labelConverter = new IndexToString() .setInputCol("prediction") .setOutputCol("predictedLabel") .setLabels(labelIndexer.labels)​// Chain indexers and tree in a Pipeline.val pipeline = new Pipeline() .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))​// Train model. This also runs the indexers.val model = pipeline.fit(trainingData)​// Make predictions.val predictions = model.transform(testData)​// Select example rows to display.predictions.select("predictedLabel", "label", "features").show(5)​// Select (prediction, true label) and compute test error.val evaluator = new MulticlassClassificationEvaluator() .setLabelCol("indexedLabel") .setPredictionCol("prediction") .setMetricName("accuracy")val accuracy = evaluator.evaluate(predictions)println("Test Error = " + (1.0 - accuracy))​val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]println("Learned classification tree model:\n" + treeModel.toDebugString)

Q文:http://www.jianshu.com/p/d925a064a13d

你可能感兴趣的:(Spark Mlib- Decision Tree)