Reading a Python-trained model from Java
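
The model below is trained with LightGBM in Python and saved as a LightGBM native model file (the .txt files referenced below). On the JVM side it is loaded with mmlspark's LightGBMClassificationModel.loadNativeModelFromFile and used to score feature vectors: first a hand-built DenseVector, then rows read from a small CSV file. The exploration code (Scala, Spark local mode) follows.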

package com.hikvision.tpse.mml.scala

import com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable

class LightGbmExplore {
  @transient lazy private[this] val log: Logger = LoggerFactory.getLogger(this.getClass)

  def createSession(): Unit = {

  }
}

object LightGbmExplore {
  @transient lazy private[this] val log: Logger = LoggerFactory.getLogger(this.getClass)

  /**
   * Model path for the first batch of data
   */
  private val modelPath_batch_1 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"
  /**
   * Model path for the second batch of data
   */
  private val modelPath_batch_2 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"


  def main(args: Array[String]): Unit = {
    // Create a local Spark session
    // FIXME: for local testing only; remove when deploying to the server.
    System.setProperty("hadoop.home.dir", "C:\\app2\\hadoop-2.7.6")
    val spark: SparkSession = SparkSession.builder().appName("test-LightGbm").config("spark.debug.maxToStringFields", "1000").master("local[*]").getOrCreate()
    spark.sparkContext.setLogLevel("warn")
    val model: LightGBMClassificationModel = com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel.loadNativeModelFromFile(modelPath_batch_1)
    log.info("模型加载成功.")
    val y: org.apache.spark.ml.linalg.Vector = new org.apache.spark.ml.linalg.DenseVector(Array(0, 6891.386295979998, 6119.465462649059, 7440.010157031022, Double.NaN, Double.NaN, Double.NaN))
    log.info("预测.")
    val sim: Double = model.getModel.score(y, raw = true, classification = true).max
    println(sim)
    // Read the CSV file
    val df: DataFrame = spark.sqlContext.read
      .format("csv")
      .option("header", value = true)
      .load("A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\data_10.csv")

    //// Cast all columns to Double
    //val colNames: Array[String] = df.columns
    //import org.apache.spark.sql.functions._
    //val cols: Array[Column] = colNames.map(f => col(f).cast(DoubleType))
    //val data1: DataFrame = df.select(cols: _*)



    //// Inspect the result locally
    //data1.printSchema()
    //data1.show(1)

    val model2: LightGBMClassificationModel = com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel.loadNativeModelFromFile(modelPath_batch_2)
    //val r = data.map(row => new org.apache.spark.ml.linalg.DenseVector(Array(row))).map(row => model.getModel.score(row, raw = true, classification = true)).collect()
    //println(r)

    //val array: Array[Seq[Any]] = data.collect().map(_.toSeq)
    //println(array)

    //data.collect().foreach((row: Row) => {
    //  val x = row(2).asInstanceOf[Double]
    //  println(row(2))
    //  //val value: Any = row(1).asInstanceOf[Double]
    //  //println(value)
    //})


    // Cast every column (read from the CSV as strings) to Double
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DoubleType

    val colNames: Array[String] = df.columns
    var df1: DataFrame = df
    for (colName <- colNames) {
      df1 = df1.withColumn(colName, col(colName).cast(DoubleType))
    }
    df1.show()


    //val df2: DataFrame = columns.foldLeft(df) { (currentDF, column) => currentDF.withColumn(column, col(column).cast("Double")) }

    //df2.printSchema()

    //println(df("10.33.24.100_10.33.24.55"))


    //df.foreach(r => {
    //  val lst: mutable.ListBuffer[Any] = new scala.collection.mutable.ListBuffer[Any]()
    //  for (i <- 0 until r.length) {
    //    val x: Any = r.get(i)
    //    //val x: Double = r.getAs[Double](i)
    //    lst.append(x)
    //  }
    //  //println(lst)
    //  val arr: Array[Any] = lst.toArray
    //  for (xi: Int <- arr.indices) {
    //    val xxx: Double = arr(xi).asInstanceOf[Double]
    //    println(xxx)
    //  }


      //val arr2: Array[Double] = arr.asInstanceOf[Array[Double]]
      //val vec: org.apache.spark.ml.linalg.Vector = new org.apache.spark.ml.linalg.DenseVector(arr2)
      //val sim: Double = model2.getModel.score(vec, raw = true, classification = true).max
      //println(sim)


      //val doubles: Array[Double] = lst.toArray[Double]
      //val denseVec: org.apache.spark.ml.linalg.Vector = new org.apache.spark.ml.linalg.DenseVector(doubles)
      //model.getModel.score(denseVec, raw = true, classification = true)
    //})


  }
}
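
What the commented-out row-scoring attempts above were heading toward can be sketched as follows. This is an assumed cleanup, not code from the original exploration: it would live inside main once df1 (all columns cast to Double) and model2 are available, collects the small test DataFrame to the driver, and scores each row with the same booster score call used for the single hand-built vector.

    // Assumed sketch: collect() is only reasonable for a small local test file such as data_10.csv
    import org.apache.spark.ml.linalg.DenseVector
    import org.apache.spark.sql.Row

    val scores: Array[Double] = df1.collect().map { (row: Row) =>
      // After the cast every column is a Double; treat nulls as NaN, like the hand-built vector above
      val features: Array[Double] = (0 until row.length).map { i =>
        if (row.isNullAt(i)) Double.NaN else row.getDouble(i)
      }.toArray
      model2.getModel.score(new DenseVector(features), raw = true, classification = true).max
    }
    scores.foreach(println)

For anything larger than a test file, the booster would likely need to be broadcast and the scoring done inside a map over the DataFrame or RDD instead of after collect().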
