Running XGBoost on Spark: the Scala interface

Overview

XGBoost can run on Spark. I am using XGBoost 0.7, which currently runs only on Spark 2.0 and above.

After building the jar, install it into your local Maven repository:

 

 
mvn install:install-file -Dfile=xgboost4j-spark-0.7-jar-with-dependencies.jar -DgroupId=ml.dmlc -DartifactId=xgboost4j-spark -Dversion=0.7 -Dpackaging=jar

 

 

Add the dependencies to your pom.xml:

 

 
<dependency>
  <groupId>ml.dmlc</groupId>
  <artifactId>xgboost4j-spark</artifactId>
  <version>0.7</version>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-core_2.10</artifactId>
  <version>2.0.0</version>
</dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
  <artifactId>spark-mllib_2.10</artifactId>
  <version>2.0.0</version>
</dependency>
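If your project builds with sbt instead of Maven, the same coordinates can be declared in build.sbt. A minimal sketch, assuming the jar was installed into the local Maven repository as above (sbt needs Resolver.mavenLocal to see it):

// build.sbt -- pick up the locally installed xgboost4j-spark jar
resolvers += Resolver.mavenLocal

libraryDependencies ++= Seq(
  "ml.dmlc" % "xgboost4j-spark" % "0.7",
  "org.apache.spark" % "spark-core_2.10" % "2.0.0",
  "org.apache.spark" % "spark-mllib_2.10" % "2.0.0"
)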

 

 


The RDD interface:

 

 
package com.meituan.spark_xgboost

import org.apache.log4j.{ Level, Logger }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors

object XgboostR {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"
    val train = MLUtils.loadLibSVMFile(spark.sparkContext, path + trainString)
    val test = MLUtils.loadLibSVMFile(spark.sparkContext, path + testString)
    // Convert the mllib LabeledPoints to the ml LabeledPoints (with dense
    // vectors) that trainWithRDD expects in 0.7.
    val traindata = train.map { x =>
      val f = x.features.toArray
      val v = x.label
      LabeledPoint(v, Vectors.dense(f))
    }
    val testdata = test.map { x =>
      val f = x.features.toArray
      Vectors.dense(f)
    }

    val numRound = 15

    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric on validation data, for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default is 6, valid range [1, inf)
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default is 0
      "objective" -> "binary:logistic", // learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads XGBoost uses; defaults to the maximum available
    ).toMap
    println(paramMap)

    // The fourth argument is the number of XGBoost workers (here 55); all of
    // them must run concurrently, so keep it within the cores Spark can
    // actually schedule when using a local master.
    val model = XGBoost.trainWithRDD(traindata, paramMap, numRound, 55, null, null, useExternalMemory = false, Float.NaN)
    println("success")

    val result = model.predict(testdata)
    result.take(10).foreach(println)
    spark.stop()
  }
}
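After training, the model can be persisted and reloaded for later scoring. A minimal sketch that would slot in before spark.stop() above, assuming the 0.7 saveModelAsHadoopFile / loadModelFromHadoopFile API; the output path is a made-up example:

// Continues from main above; "/tmp/xgboost-model" is a hypothetical path.
implicit val sc: org.apache.spark.SparkContext = spark.sparkContext
model.saveModelAsHadoopFile("/tmp/xgboost-model")
val loaded = XGBoost.loadModelFromHadoopFile("/tmp/xgboost-model")
loaded.predict(testdata).take(10).foreach(println)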


 

The DataFrame interface:

 

 
package com.meituan.spark_xgboost

import org.apache.log4j.{ Level, Logger }
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.{ SparkSession, Row }

object XgboostD {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
    val spark = SparkSession.builder.master("local").appName("example").
      config("spark.sql.warehouse.dir", s"file:///Users/shuubiasahi/Documents/spark-warehouse").
      config("spark.sql.shuffle.partitions", "20").getOrCreate()
    spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    val path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/"
    val trainString = "agaricus.txt.train"
    val testString = "agaricus.txt.test"

    // The libsvm data source produces "label" and "features" columns; rename
    // the feature column so it matches the featureCol passed to trainWithDataFrame.
    val train = spark.read.format("libsvm").load(path + trainString).toDF("label", "feature")
    val test = spark.read.format("libsvm").load(path + testString).toDF("label", "feature")

    val numRound = 15

    // "objective" -> "reg:linear", // learning task and objective, for regression
    // "eval_metric" -> "rmse",     // evaluation metric on validation data, for regression
    val paramMap = List(
      "eta" -> 1f,
      "max_depth" -> 5, // maximum depth of a tree; default is 6, valid range [1, inf)
      "silent" -> 1, // 0 prints runtime messages, 1 runs silently; default is 0
      "objective" -> "binary:logistic", // learning task and objective
      "lambda" -> 2.5,
      "nthread" -> 1 // number of threads XGBoost uses; defaults to the maximum available
    ).toMap
    val model = XGBoost.trainWithDataFrame(train, paramMap, numRound, 45, obj = null, eval = null, useExternalMemory = false, Float.NaN, "feature", "label")
    val predict = model.transform(test)

    val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
      .rdd
      .map { case Row(score: Double, label: Double) => (score, label) }

    // Get the AUC from the (score, label) pairs.
    val metric = new BinaryClassificationMetrics(scoreAndLabels)
    val auc = metric.areaUnderROC()
    println("auc:" + auc)

    spark.stop()
  }
}
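AUC is threshold-free; for a quick accuracy number, the same scoreAndLabels RDD can be turned into hard predictions. A minimal sketch that would slot in before spark.stop() above, assuming the prediction column holds binary:logistic probabilities (the 0.5 cutoff is my assumption, not part of the original code):

// Threshold the probability scores at 0.5 and compare against the labels.
val accuracy = scoreAndLabels.map { case (score, label) =>
  if ((if (score > 0.5) 1.0 else 0.0) == label) 1.0 else 0.0
}.mean()
println("accuracy: " + accuracy)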
