// Reading a Python-trained model from Java
package com.hikvision.tpse.mml.scala
import com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel
import org.apache.arrow.vector.types.pojo.ArrowType.Struct
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable
/**
 * Instance-side counterpart of the `LightGbmExplore` object.
 *
 * Holds a lazily created SLF4J logger and a `createSession` hook whose body is
 * currently empty.
 */
class LightGbmExplore {

  /** Logger named after the runtime class; built on first access and marked `@transient`. */
  @transient private[this] lazy val log: Logger = LoggerFactory.getLogger(getClass)

  /** Placeholder: performs no work at present and simply returns `Unit`. */
  def createSession(): Unit = ()
}
object LightGbmExplore {
/** Logger for this object, named after its runtime class; built lazily on first access. */
@transient private[this] lazy val log: Logger = LoggerFactory.getLogger(getClass)
/**
* Path for the first batch (the original comment calls it the "first-batch data path").
* The literal points at a `wugan_0514_model.txt` file and is the value `main` passes to
* `LightGBMClassificationModel.loadNativeModelFromFile`.
*/
private val modelPath_batch_1 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"
/**
* Path for the second batch (the original comment calls it the "second-batch data path").
* NOTE(review): this literal is identical to `modelPath_batch_1` — confirm whether a
* different file was intended for the second batch.
*/
private val modelPath_batch_2 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"
def main(args: Array[String]): Unit = {
//创建spark local连接
//fixme测试使用,部署到服务器可以删除.
System.setProperty("hadoop.home.dir", "C:\\app2\\hadoop-2.7.6")
val spark: SparkSession = SparkSession.builder().appName("test-LightGbm").config("spark.debug.maxToStringFields", "1000").master("local[*]").getOrCreate()
spark.sparkContext.setLogLevel("warn")
val model: LightGBMClassificationModel = com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel.loadNativeModelFromFile(modelPath_batch_1)
log.info("模型加载成功.")
val y: org.apache.spark.ml.linalg.Vector = new org.apache.spark.ml.linalg.DenseVector(Array(0, 6891.386295979998, 6119.465462649059, 7440.010157031022, Double.NaN, Double.NaN, Double.NaN))
log.info("预测.")
val sim: Double = model.getModel.score(y, raw = true, classification = true).max
println(sim)
//读取csv文件
val df: DataFrame = spark.sqlContext.read
.format("csv")
.option("header", value = true)
.load("A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\data_10.csv")
//Convert all columns to double type
//val colNames: Array[String] = df.columns
//import org.apache.spark.sql.functions._
//val cols: Array[Column] = colNames.map(f => col(f).cast(DoubleType))
//val data1: DataFrame = df.select(cols: _*)
//val colNames = df.columns
//var df1 = df
//import org.apache.spark.sql.functions._
//for (colName new org.apache.spark.ml.linalg.DenseVector(Array(row))).map(row => model.getModel.score(row, raw = true, classification = true)).collect()
//println(r)
//val array: Array[Seq[Any]] = data.collect().map(_.toSeq)
//println(array)
//import scala.collection.
//data.collect().foreach((row: Row) => {
// val x = row(2).asInstanceOf[Double]
// println(row(2))
// //val value: Any = row(1).asInstanceOf[Double]
// //println(value)
//})
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
val columns: Array[String] = df.columns
val colNames = df.columns
var df1 = df
import org.apache.spark.sql.types._
for (colName currentDF.withColumn(column, col(column).cast("Double")) }
//df2.printSchema()
//println(df("10.33.24.100_10.33.24.55"))
//df.foreach(r => {
// val lst: mutable.ListBuffer[Any] = new scala.collection.mutable.ListBuffer[Any]()
// for (i