Spark MLlib 入门学习笔记 - GradientBoostedTree和随机森林

GradientBoostedTree

参考文档。

train(input: RDD[LabeledPoint], boostingStrategy: BoostingStrategy): GradientBoostedTreesModel
Method to train a gradient boosting model.
input Training dataset: RDD of org.apache.spark.mllib.regression.LabeledPoint. For classification, labels should take values {0, 1, ..., numClasses-1}. For regression, labels are real numbers.
boostingStrategy Configuration options for the boosting algorithm.
returns GradientBoostedTreesModel that can be used for prediction.
测试代码,使用kyphosis数据集。

package classify

import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

/**
 * Trains a gradient-boosted-trees binary classifier with Spark MLlib and
 * reports its precision on a held-out split of the input file.
 *
 * Each input line is expected to look like `<label> <int> <int> <int>`,
 * separated by single spaces, where the label is `absent` or `present`.
 */
object GDTree {

  /**
   * Converts one space-separated text line into a [[LabeledPoint]].
   *
   * The first field is the class label (`absent` -> 0.0, `present` -> 1.0);
   * fields 2-4 are parsed as integers and become the feature vector.
   * Any fields after the fourth are ignored.
   *
   * @param line a single record of the input file
   * @return the labeled feature vector for that record
   * @throws IllegalArgumentException if the line has fewer than four fields
   *                                  or the label is not `absent`/`present`
   * @throws NumberFormatException    if a feature field is not an integer
   */
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(" ")
    require(parts.length >= 4, s"Expected a label and 3 integer features but got: '$line'")

    // `match` is an expression: bind its value instead of mutating a var, and
    // fail with a descriptive message rather than a bare MatchError.
    val label = parts(0) match {
      case "absent"  => 0.0
      case "present" => 1.0
      case other =>
        throw new IllegalArgumentException(s"Unknown label '$other' in line: '$line'")
    }
    val features: Vector = Vectors.dense(parts(1).toInt, parts(2).toInt, parts(3).toInt)
    LabeledPoint(label, features)
  }

  /**
   * Program entry point.
   *
   * @param args `args(0)` is the Spark master URL, `args(1)` is the path of
   *             the input data file
   */
  def main(args: Array[String]): Unit = {
    require(args.length >= 2, "Usage: GDTree <master> <input-path>")

    val conf = new SparkConf().setMaster(args(0)).setAppName("Iris")
    val sc = new SparkContext(conf)
    try {
      // Skip blank lines (e.g. a trailing newline) so they do not abort the job.
      val data = sc.textFile(args(1)).filter(_.trim.nonEmpty).map(parseLine(_))

      // 70% training / 30% test, with a fixed seed so the split is repeatable.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 11L)

      val boostingStrategy = BoostingStrategy.defaultParams("Classification")
      boostingStrategy.numIterations = 3
      boostingStrategy.treeStrategy.numClasses = 2 // number of classes
      boostingStrategy.treeStrategy.maxDepth = 5 // maximum depth of each tree
      // Empty map: no feature is declared as categorical.
      boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()

      val model = GradientBoostedTrees.train(trainData, boostingStrategy)

      // Cached because the RDD is consumed twice below (printing and metrics).
      val predictionAndLabel =
        testData.map(p => (model.predict(p.features), p.label)).cache()

      // Collect first: foreach(println) directly on an RDD prints on the
      // executors, so nothing would show up on the driver when run on a cluster.
      predictionAndLabel.collect().foreach(println)

      val metrics = new MulticlassMetrics(predictionAndLabel)
      // NOTE(review): the no-arg `precision` is deprecated in newer Spark
      // releases in favour of `accuracy` — confirm against the Spark version in use.
      val precision = metrics.precision
      println("Precision = " + precision)
    } finally {
      // Always release the context, even when training or parsing fails.
      sc.stop()
    }
  }
}

随机森林,参考文档。

def trainRegressor(input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, seed: Int): RandomForestModel
Method to train a decision tree model for regression.
input Training dataset: RDD of org.apache.spark.mllib.regression.LabeledPoint. Labels are real numbers.
strategy Parameters for training each tree in the forest.
numTrees Number of trees in the random forest.
featureSubsetStrategy Number of features to consider for splits at each node. Supported values: "auto", "all", "sqrt", "log2", "onethird". If "auto" is set, this parameter is set based on numTrees: if numTrees == 1, set to "all"; if numTrees > 1 (forest) set to "onethird".
seed Random seed for bootstrapping and choosing feature subsets.
returns RandomForestModel that can be used for prediction.

测试代码,使用kyphosis数据集。注意:上面列出的是回归接口 trainRegressor 的文档,而下面的代码做的是分类,调用的是 RandomForest.trainClassifier。

package classify

import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest

/**
 * Trains a random-forest binary classifier with Spark MLlib and reports its
 * precision on a held-out split of the input file.
 *
 * Each input line is expected to look like `<label> <int> <int> <int>`,
 * separated by single spaces, where the label is `absent` or `present`.
 */
object RFDTree {

  /**
   * Converts one space-separated text line into a [[LabeledPoint]].
   *
   * The first field is the class label (`absent` -> 0.0, `present` -> 1.0);
   * fields 2-4 are parsed as integers and become the feature vector.
   * Any fields after the fourth are ignored.
   *
   * @param line a single record of the input file
   * @return the labeled feature vector for that record
   * @throws IllegalArgumentException if the line has fewer than four fields
   *                                  or the label is not `absent`/`present`
   * @throws NumberFormatException    if a feature field is not an integer
   */
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(" ")
    require(parts.length >= 4, s"Expected a label and 3 integer features but got: '$line'")

    // `match` is an expression: bind its value instead of mutating a var, and
    // fail with a descriptive message rather than a bare MatchError.
    val label = parts(0) match {
      case "absent"  => 0.0
      case "present" => 1.0
      case other =>
        throw new IllegalArgumentException(s"Unknown label '$other' in line: '$line'")
    }
    val features: Vector = Vectors.dense(parts(1).toInt, parts(2).toInt, parts(3).toInt)
    LabeledPoint(label, features)
  }

  /**
   * Program entry point.
   *
   * @param args `args(0)` is the Spark master URL, `args(1)` is the path of
   *             the input data file
   */
  def main(args: Array[String]): Unit = {
    require(args.length >= 2, "Usage: RFDTree <master> <input-path>")

    val conf = new SparkConf().setMaster(args(0)).setAppName("Iris")
    val sc = new SparkContext(conf)
    try {
      // Skip blank lines (e.g. a trailing newline) so they do not abort the job.
      val data = sc.textFile(args(1)).filter(_.trim.nonEmpty).map(parseLine(_))

      // 70% training / 30% test, with a fixed seed so the split is repeatable.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 11L)

      val numClasses = 2 // number of classes
      // Empty map: no feature is declared as categorical.
      val categoricalFeaturesInfo = Map[Int, Int]()
      val numTrees = 3 // number of trees in the forest
      val featuresSubsetStrategy = "auto"
      val impurity = "entropy" // impurity measure for splits ("gini" is the alternative)
      val maxDepth = 5 // maximum depth of each tree
      val maxBins = 3 // maximum number of bins used when splitting features
      val model = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo,
        numTrees, featuresSubsetStrategy, impurity, maxDepth, maxBins)

      // Cached because the RDD is consumed twice below (printing and metrics).
      val predictionAndLabel =
        testData.map(p => (model.predict(p.features), p.label)).cache()

      // Collect first: foreach(println) directly on an RDD prints on the
      // executors, so nothing would show up on the driver when run on a cluster.
      predictionAndLabel.collect().foreach(println)

      val metrics = new MulticlassMetrics(predictionAndLabel)
      // NOTE(review): the no-arg `precision` is deprecated in newer Spark
      // releases in favour of `accuracy` — confirm against the Spark version in use.
      val precision = metrics.precision
      println("Precision = " + precision)
    } finally {
      // Always release the context, even when training or parsing fails.
      sc.stop()
    }
  }
}

你可能感兴趣的:(Spark)