Training a Model in Python: Reading the Python-Trained Model from Java

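The idea: a LightGBM model trained in Python is saved as a native text model file (presumably via LightGBM's save_model), and mmlspark's LightGBMClassificationModel.loadNativeModelFromFile can load that file on the JVM and score feature vectors directly. Below is a minimal sketch of the two core calls, with a placeholder model path and made-up feature values; the full exploration code follows.

import com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel
import org.apache.spark.ml.linalg.DenseVector

// Load the native LightGBM model text file written out on the Python side (placeholder path).
val model = LightGBMClassificationModel.loadNativeModelFromFile("/path/to/model.txt")

// Score a single feature vector; raw = true asks for raw per-class scores.
val scores: Array[Double] = model.getModel.score(new DenseVector(Array(0.0, 1.5, 2.3)), raw = true, classification = true)

println(scores.max)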

package com.hikvision.tpse.mml.scala

import com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel

import org.apache.arrow.vector.types.pojo.ArrowType.Struct

import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.{DataFrame, SparkSession}

import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable

class LightGbmExplore {

@transient lazy private[this] val log: Logger = LoggerFactory.getLogger(this.getClass)

def createSession(): Unit = {

}

}

object LightGbmExplore {

@transient lazy private[this] val log: Logger = LoggerFactory.getLogger(this.getClass)

/**

* Model path for the first batch of data

*/

private val modelPath_batch_1 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"

/**

* Model path for the second batch of data

*/

private val modelPath_batch_2 = "A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\wugan_0514_model.txt"

def main(args: Array[String]): Unit = {

// Create a local Spark session

// FIXME: for local testing only; can be removed when deploying to a server.

System.setProperty("hadoop.home.dir", "C:\\app2\\hadoop-2.7.6")

val spark: SparkSession = SparkSession.builder().appName("test-LightGbm").config("spark.debug.maxToStringFields", "1000").master("local[*]").getOrCreate()

spark.sparkContext.setLogLevel("warn")

val model: LightGBMClassificationModel = com.microsoft.ml.spark.lightgbm.LightGBMClassificationModel.loadNativeModelFromFile(modelPath_batch_1)

log.info("Model loaded successfully.")

val y: org.apache.spark.ml.linalg.Vector = new org.apache.spark.ml.linalg.DenseVector(Array(0, 6891.386295979998, 6119.465462649059, 7440.010157031022, Double.NaN, Double.NaN, Double.NaN))

log.info("Predicting.")

val sim: Double = model.getModel.score(y, raw = true, classification = true).max

println(sim)
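// Sketch (an assumption, not from the original post): score(..., raw = true, classification = true)
// returns one raw score per class, so the index of the largest score can be read as the predicted class.
val rawScores: Array[Double] = model.getModel.score(y, raw = true, classification = true)

val predictedClass: Int = rawScores.indexOf(rawScores.max)

println(s"predicted class: $predictedClass, max raw score: ${rawScores.max}")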

// Read the CSV file; with only header = true every column comes in as a string, hence the cast to Double below

val df: DataFrame = spark.sqlContext.read

.format("csv")

.option("header", value = true)

.load("A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\data_10.csv")
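// Alternative (a sketch, not from the original post): let Spark infer numeric column types at read
// time instead of casting every column afterwards.
val dfTyped: DataFrame = spark.read.format("csv").option("header", value = true).option("inferSchema", value = true).load("A:\\mnt\\disk1\\data\\work\\wuyuan_v3_data_4_t4_lb\\s6_unified_model_result\\jinku\\model_0514\\data_10.csv")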

// Convert all columns to Double type

//val colNames: Array[String] = df.columns

//import org.apache.spark.sql.functions._

//val cols: Array[Column] = colNames.map(f => col(f).cast(DoubleType))

//val data1: DataFrame = df.select(cols: _*)

//val colNames = df.columns

//var df1 = df

//import org.apache.spark.sql.functions._

//for (colName <- colNames) { df1 = df1.withColumn(colName, col(colName).cast(DoubleType)) }

//val r = data1.rdd.map(row => new org.apache.spark.ml.linalg.DenseVector(Array(row))).map(row => model.getModel.score(row, raw = true, classification = true)).collect()

//println(r)

//val array: Array[Seq[Any]] = data.collect().map(_.toSeq)

//println(array)

//import scala.collection.

//data.collect().foreach((row: Row) => {

// val x = row(2).asInstanceOf[Double]

// println(row(2))

// //val value: Any = row(1).asInstanceOf[Double]

// //println(value)

//})

import org.apache.spark.sql.DataFrame

import org.apache.spark.sql.functions.col

val columns: Array[String] = df.columns

val colNames = df.columns

var df1 = df

import org.apache.spark.sql.types._

for (colName <- colNames) { df1 = df1.withColumn(colName, col(colName).cast(DoubleType)) }

val df2 = columns.foldLeft(df) { (currentDF, column) => currentDF.withColumn(column, col(column).cast("Double")) }

//df2.printSchema()

//println(df("10.33.24.100_10.33.24.55"))
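// Sketch (an assumption, not in the original post): with every column cast to Double, each row of
// df2 can be packed into a DenseVector and scored with the loaded booster, mirroring the
// single-vector call above. NaN is kept for null cells, as in the hand-built vector y.
val rowScores: Array[Double] = df2.collect().map { row =>

  val features = new org.apache.spark.ml.linalg.DenseVector(
    (0 until row.length).map(i => if (row.isNullAt(i)) Double.NaN else row.getDouble(i)).toArray)

  model.getModel.score(features, raw = true, classification = true).max
}

rowScores.foreach(println)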

//df.foreach(r => {

// val lst: mutable.ListBuffer[Any] = new scala.collection.mutable.ListBuffer[Any]()

// for (i

//})

}

}
