Mllib系列之决策树

决策树是机器学习领域的经典算法之一,这里借鉴了一位博友的博客文章<决策树算法学习笔记> http://blog.sina.com.cn/s/blog_8095e51d01013chj.html.
决策树并不需要很强的数学知识,理解上面也比较直观.首先看如下的一组数据:(各个属性的含义:年龄,0:青年,1:中年. 身高,0:高,1:低.收入,0:低,1:高, 满意度,0:不满意,1:满意)

客户ID    年龄  身高  收入        满意度
001     0       0        1   1
002     0       0        0   0
003     1   1    0   0
004     1   0    0   0
005     0   1    0   0
006     0   1    1   1

上述数据是一个经典分类问题的数据,一共有两个类别(满意度),两种离散属性(故障原因,故障类型),一种连续属性(保障时长).对于连续的属性,要实现进行离散化,选取的分裂点保证其gini指数最大即可(这里不进行较为详细的说明,请自行查阅资料)
1 分裂点的选取:
直观上,分类点应该选取最纯的特征属性.这里的纯指的是该属性能够最大划分训练集.比如:某一个属性能够最大程度的区分数据集,这里的收属性就是最佳的选择,收入=1,全部都是满意,收入=0,全部都是不满意.
评估分裂点的指标有:增益指数(GA)和基尼指数(GINI).基尼指数要优于增益指数,主要是因为增益指数倾向于属性值多的属性.
2 剪枝:
剪枝过程是决策树不可缺少的一个步骤,可能预防过拟合现象的发生.
一些注意事项:
对于离散数据的分类,直接在叶子节点选取类别数目最多的类别即可;而对于连续数据的分类(回归),要取该叶子节点数据的平均值.
下面的例子直接来自于spark官网,这里写了一些注释方便阅读:

package zhangluoyang.sparkDemo

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// $example off$
import org.apache.spark.{SparkConf, SparkContext}

object DecisionTreeClassificationExample {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DecisionTreeClassificationExample").setMaster("local")
    val sc = new SparkContext(conf)

    // $example on$
    // Load and parse the data file.
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))

    val numClasses = 2
    // 用于表示每一个属性的取值数目, 比如:1->20,表示第1个特征有20种类别
    // 连续数据为空
    //全为空表示全部特征是连续属性
    val categoricalFeaturesInfo = Map[Int, Int]()
    // 选取分类点的方式
    val impurity = "gini"
    // 树最大深度
    val maxDepth = 5
    // 使用的最大的特征数目
    val maxBins = 32

    val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    // 计算误差
    val labelAndPreds = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }

    val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count()
    println("Test Error = " + testErr)
    println("Learned classification tree model:\n" + model.toDebugString)

    // 保存模型文件
    model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
    val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
    // $example off$
  }
}
// scalastyle:on println

你可能感兴趣的:(spark,机器学习,MLlib)