Apache Spark MLlib Study Notes (5): Source Code Analysis of MLlib Decision Tree Algorithms

Starting with this chapter, we analyze the source code of Spark MLlib's decision tree implementation.
First, let's look at the official Java example of using a decision tree; the path is /home/yangqiao/codes/spark/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTree.java
For convenience, part of the analysis is given directly as comments in the code:

public final class JavaDecisionTree {

  public static void main(String[] args) {
    String datapath = "data/mllib/sample_libsvm_data.txt"; // path of the sample data file shipped with Spark
    if (args.length == 1) {
      datapath = args[0]; // a custom data file can be passed in as the first program argument
    } else if (args.length > 1) { // too many arguments
      System.err.println("Usage: JavaDecisionTree <libsvm format data file>");
      System.exit(1);
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    // Spark configuration; here it only sets the application name
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    // JavaSparkContext is the main entry point of a Spark program: it connects to the
    // Spark cluster and is used to create RDDs and share variables on the cluster.
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache();
    /* loadLibSVMFile parses LIBSVM-format labeled data into an RDD[LabeledPoint];
       cache() persists the RDD in memory so it is not recomputed on each reuse */
    // Count how many distinct classes the data contains
    Integer numClasses = data.map(new Function<LabeledPoint, Double>() {
      @Override public Double call(LabeledPoint p) {
        return p.label();
      }
    }).countByValue().size();
    /* map applies the function to every element of the RDD and returns a new RDD; here the
       function extracts the label. countByValue then returns a (value, count) map of the
       distinct label values, so its size() is the number of classes. */
    // Set parameters
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 32;
    // Use Gini impurity as the split criterion, a maximum tree depth of 5,
    // and at most 32 bins when discretizing continuous features
    // Train a DecisionTree model for classification
    final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // What follows uses the model for prediction and evaluation; we analyze that later,
    // since the focus here is the model-building process
    JavaPairRDD<Double, Double> predictionAndLabel =
      data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
        }
      });
    Double trainErr =
      1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
        @Override public Boolean call(Tuple2<Double, Double> pl) {
          return !pl._1().equals(pl._2());
        }
      }).count() / data.count();
    System.out.println("Training error: " + trainErr);
    System.out.println("Learned classification tree model:\n" + model);

    // Train a DecisionTree model for regression.
    impurity = "variance";
    final DecisionTreeModel regressionModel = DecisionTree.trainRegressor(data,
        categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Evaluate model on training instances and compute training error
    JavaPairRDD<Double, Double> regressorPredictionAndLabel =
      data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<Double, Double>(regressionModel.predict(p.features()), p.label());
        }
      });
    Double trainMSE =
      regressorPredictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override public Double call(Tuple2<Double, Double> pl) {
          Double diff = pl._1() - pl._2();
          return diff * diff;
        }
      }).reduce(new Function2<Double, Double, Double>() {
        @Override public Double call(Double a, Double b) {
          return a + b;
        }
      }) / data.count();
    System.out.println("Training Mean Squared Error: " + trainMSE);
    System.out.println("Learned regression tree model:\n" + regressionModel);

    sc.stop();
  }
}
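The classification model above uses "gini" as its impurity measure (and the regression model uses "variance"). As a reminder of what Gini impurity computes, here is a small standalone sketch; it illustrates the formula only and is not MLlib's internal implementation:

```java
import java.util.Map;

public class GiniDemo {
  // Gini impurity: 1 - sum over classes k of p_k^2,
  // where p_k is the fraction of samples belonging to class k.
  public static double gini(Map<Double, Long> labelCounts) {
    long total = labelCounts.values().stream().mapToLong(Long::longValue).sum();
    double sumOfSquares = 0.0;
    for (long count : labelCounts.values()) {
      double p = (double) count / total;
      sumOfSquares += p * p;
    }
    return 1.0 - sumOfSquares;
  }

  public static void main(String[] args) {
    // An evenly split binary node is maximally impure: 1 - (0.25 + 0.25) = 0.5
    System.out.println(gini(Map.of(0.0, 50L, 1.0, 50L))); // 0.5
    // A pure node (only one class) has zero impurity
    System.out.println(gini(Map.of(1.0, 100L))); // 0.0
  }
}
```

During training, the tree picks the split that most reduces impurity from the parent node to its children; for regression, MLlib replaces Gini with the variance of the labels in each node.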

From the code above, the core model-building call is:

 final DecisionTreeModel model = DecisionTree.trainClassifier(data, numClasses,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);

In other words, building the model comes down to calling the trainClassifier method. To see what trainClassifier actually does, we need to dig into the source code.
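One argument worth noting before diving into the source is categoricalFeaturesInfo: it maps a feature's index to the number of categories that feature can take, and an empty map (as in the example above) means every feature is treated as continuous. A minimal sketch of populating it, with hypothetical feature indices and arities:

```java
import java.util.HashMap;

public class CategoricalFeaturesInfoDemo {
  public static void main(String[] args) {
    // feature index -> number of categories (arity); the indices and
    // arities below are made-up examples, not from the sample data
    HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
    categoricalFeaturesInfo.put(0, 2);  // feature 0 is binary
    categoricalFeaturesInfo.put(4, 10); // feature 4 takes one of 10 values
    // Categorical values must be encoded as 0, 1, ..., arity-1
    // in the feature vectors before training.
    System.out.println(categoricalFeaturesInfo.size()); // 2
  }
}
```

This map would be passed to trainClassifier in place of the empty HashMap used in the example.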
Open the source file at the following path:
/home/yangqiao/codes/spark/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
We will focus on the DecisionTree.scala file first; the detailed analysis continues in the next post of this series.
