Spark MLlib Programming

Constructing the dataset

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val rawData = sc.textFile("...")
val data = rawData.map { line =>
  // All columns but the last are features; the last column is the label.
  val row = line.split(',').map(_.toDouble)
  val featVec = Vectors.dense(row.init)
  val label = row.last
  LabeledPoint(label, featVec)
}
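The per-line parsing can be exercised on its own, without a SparkContext. A minimal sketch, assuming comma-separated numeric columns with the label last (the `parseLine` helper name is ours, not from the original):

```scala
// Parse one CSV line: all columns but the last are features, the last is the label.
def parseLine(line: String): (Array[Double], Double) = {
  val row = line.split(',').map(_.trim.toDouble)
  (row.init, row.last)
}

val (features, label) = parseLine("1.0,2.5,0.0,1.0")
// features = Array(1.0, 2.5, 0.0), label = 1.0
```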

Constructing the training, cross-validation (CV), and test sets

val Array(trainData, cvData, testData) = 
  data.randomSplit(Array(.8, .1, .1))
trainData.cache()
cvData.cache()
testData.cache()
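`randomSplit` assigns each element independently to one of the splits with probability proportional to its weight, so an 80/10/10 split is approximate, not exact. A plain-Scala sketch of that semantics (our own reimplementation for illustration, not Spark's code):

```scala
import scala.util.Random

// Assign each element to a bucket with probability proportional to its weight,
// mirroring RDD.randomSplit's per-element behaviour.
def randomSplit[T](data: Seq[T], weights: Array[Double], seed: Long): Array[Seq[T]] = {
  val rng = new Random(seed)
  val total = weights.sum
  // Cumulative normalised weights, e.g. Array(0.0, 0.8, 0.9, 1.0).
  val cumulative = weights.map(_ / total).scanLeft(0.0)(_ + _)
  val buckets = Array.fill(weights.length)(Seq.newBuilder[T])
  for (x <- data) {
    val r = rng.nextDouble()
    // Last cumulative bound <= r identifies the bucket; clamp for fp safety.
    val idx = math.min(cumulative.lastIndexWhere(_ <= r), weights.length - 1)
    buckets(idx) += x
  }
  buckets.map(_.result())
}

val parts = randomSplit(1 to 1000, Array(0.8, 0.1, 0.1), seed = 42L)
// parts(0) holds roughly 800 elements, parts(1) and parts(2) roughly 100 each
```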

Model training and model evaluation (metrics)

  • MulticlassMetrics — multi-class evaluation (confusion matrix, per-class precision/recall)
  • BinaryClassificationMetrics — binary evaluation (ROC/AUC, precision-recall curve)
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]) = {
  val predsAndLabels = data.map(sample =>
    (model.predict(sample.features), sample.label))
  new MulticlassMetrics(predsAndLabels)
}
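MulticlassMetrics derives all of its figures from the (prediction, label) pairs. The headline number, overall accuracy, can be sketched in plain Scala (helper name ours):

```scala
// Overall accuracy: fraction of samples whose prediction equals the label,
// the same figure MulticlassMetrics exposes.
def accuracy(predsAndLabels: Seq[(Double, Double)]): Double = {
  val correct = predsAndLabels.count { case (pred, label) => pred == label }
  correct.toDouble / predsAndLabels.size
}

val pairs = Seq((1.0, 1.0), (0.0, 1.0), (2.0, 2.0), (1.0, 0.0))
accuracy(pairs)  // 0.5
```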

import org.apache.spark.mllib.tree.DecisionTree

// trainClassifier(input, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// maxDepth = 4 and maxBins = 100 here are example values.
val model = DecisionTree.trainClassifier(trainData, numClasses, Map[Int, Int](), "gini", 4, 100)
val metrics = getMetrics(model, cvData)

Class distribution of a sample set

def classProb(data: RDD[LabeledPoint]) = {
  val countsByCategory = data.map(_.label).countByValue()
  val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
  val total = counts.sum.toDouble
  counts.map(_ / total)
}
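The same logic can be checked on a plain label sequence, replacing the RDD's `countByValue()` with a `groupBy`:

```scala
// Class distribution from a plain label sequence: count per label,
// order by label value, normalise to probabilities.
def classProb(labels: Seq[Double]): Array[Double] = {
  val countsByCategory = labels.groupBy(identity).map { case (k, v) => (k, v.size) }
  val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
  val total = counts.sum.toDouble
  counts.map(_ / total)
}

classProb(Seq(0.0, 0.0, 1.0, 1.0, 1.0, 2.0))
// probabilities for labels 0.0, 1.0, 2.0, summing to 1
```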

Choosing hyper-parameters (evaluated on the CV set)

val evaluations =
  for (impurity <- Array("gini", "entropy");
       depth    <- Array(1, 20);
       bins     <- Array(10, 300))
  yield {
    val model = DecisionTree.trainClassifier(trainData, numClasses, Map[Int, Int](), impurity, depth, bins)
    val predsAndLabels = cvData.map(sample => (model.predict(sample.features), sample.label))
    val accuracy = new MulticlassMetrics(predsAndLabels).accuracy  // .precision on Spark < 2.0
    ((impurity, depth, bins), accuracy)
  }

evaluations.sortBy(_._2).reverse.foreach(println)
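The grid itself is an ordinary Scala for-comprehension: 2 impurities x 2 depths x 2 bin counts = 8 candidate settings. A Spark-free sketch of the same loop, where `scoreModel` is a hypothetical stand-in for training a tree and measuring CV accuracy:

```scala
// Dummy scoring function standing in for train-then-evaluate; its exact
// values are arbitrary, only the grid/sort mechanics matter here.
def scoreModel(impurity: String, depth: Int, bins: Int): Double =
  (if (impurity == "entropy") 0.01 else 0.0) + depth * 0.001 + bins * 0.0001

val evaluations =
  for (impurity <- Array("gini", "entropy");
       depth    <- Array(1, 20);
       bins     <- Array(10, 300))
  yield ((impurity, depth, bins), scoreModel(impurity, depth, bins))

// Sort descending by score; the head is the best hyper-parameter setting.
val best = evaluations.sortBy(_._2).reverse.head
```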
