利用xgboost4j-spark训练XGBoost分类模型的案例

package spark.xgb.test

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * Example driver: trains an XGBoost binary classifier (`binary:logistic`) with the
 * xgboost4j-spark DataFrame API on LibSVM-formatted training data, then prints the
 * model's output on a LibSVM-formatted test set.
 *
 * Originally created by zhaijianwei on 2017/12/7.
 */
object sparkWithDataFrame {

  private val Usage = "usage: program num_of_rounds num_workers training_path test_path"

  /**
   * Entry point.
   *
   * @param args exactly four arguments: number of boosting rounds (positive int),
   *             number of XGBoost workers (positive int), training-data path and
   *             test-data path (both loaded with Spark's `libsvm` data source)
   */
  def main(args: Array[String]): Unit = {
    if (args.length != 4) exitWithUsage()

    val numRound = parsePositiveInt(args(0), "num_of_rounds")
    val numWorkers = parsePositiveInt(args(1), "num_workers")
    val inputTrainPath = args(2)
    val inputTestPath = args(3)

    // Use Kryo serialization and register the Booster class with Kryo; unregistered
    // classes still work but Kryo must then store the full class name with each object.
    val sparkConf = new SparkConf()
      .setAppName("sparkWithDataFrame")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.registerKryoClasses(Array(classOf[Booster]))

    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()
    try {
      val trainDF = sparkSession.read.format("libsvm").load(inputTrainPath)
      val testDF = sparkSession.read.format("libsvm").load(inputTestPath)

      // Booster parameters; trainWithDataFrame takes them as Map[String, Any].
      val params: Map[String, Any] = Map(
        "eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "binary:logistic"
      )

      val xgbModel = XGBoost.trainWithDataFrame(
        trainDF, params, numRound, numWorkers, useExternalMemory = true)
      xgbModel.transform(testDF).show()
    } finally {
      // Release the session's resources even when loading or training fails.
      sparkSession.stop()
    }
  }

  /** Prints the usage line to stderr and terminates the JVM with exit status 1. */
  private def exitWithUsage(): Nothing = {
    System.err.println(Usage)
    sys.exit(1)
  }

  /**
   * Parses a command-line argument as a strictly positive Int.
   * On a malformed or non-positive value, reports the problem and exits via [[exitWithUsage]].
   */
  private def parsePositiveInt(raw: String, name: String): Int = {
    val parsed =
      try Some(raw.trim.toInt)
      catch { case _: NumberFormatException => None }
    parsed.filter(_ > 0).getOrElse {
      System.err.println(s"invalid $name: '$raw' (expected a positive integer)")
      exitWithUsage()
    }
  }
}

提交Spark作业的shell脚本：

# Submit the sparkWithDataFrame example to the cluster.
# Arguments passed to the driver: boosting rounds, XGBoost workers, train path, test path.
# NOTE: bash comments must start with '#'; a trailing '//...' after VAR="..." would be
# executed as a command and leave VAR unset in this shell.
numRound=100
num_workers=10
inputTrainPath="/tmp/zjw/agaricus.txt.train"  # HDFS path of the training data
inputValidPath="/tmp/zjw/agaricus.txt.test"   # HDFS path of the test data

spark-submit --class spark.xgb.test.sparkWithDataFrame \
    --num-executors 60 \
    --executor-memory 16g \
    --driver-memory 16g \
    --executor-cores 4 \
    --queue root.bdp_jdw_up \
    --jars ./jar/xgboost4j-0.7.jar,./jar/xgboost4j-spark-0.7.jar \
    ./jar/spark_prac-1.0-SNAPSHOT.jar "$numRound" "$num_workers" "$inputTrainPath" "$inputValidPath"

运行结果：
（图1：利用xgboost4j下的xgboost分类模型案例的运行结果截图）

你可能感兴趣的:(spark)