Using Models Trained with Spark MLlib in Java Web

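PMML (Predictive Model Markup Language) is a general-purpose, XML-based format for describing models. Because the file follows a standard, a machine learning model can be trained in Spark, exported as PMML, and then used on the web interface side. The most widely used approach at present is to load the model with JPMML inside a Java web application, which makes the machine learning application cross-platform.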


Training the model

First, train a model in Spark MLlib using logistic regression from the mllib package:

import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint

// `spark` is the SparkSession (for example, the one provided by spark-shell).
// Each line has the form "label,f0 f1 f2 f3".
val training = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1", "1,2 4 1 5", "0,7 8 3 6", "1,2 5 6 9").map(line => LabeledPoint.parse(line)))

// Run training algorithm to build the model
val model = new LogisticRegressionWithLBFGS()
  .setNumClasses(2)
  .run(training)

val test = spark.sparkContext
  .parallelize(Seq("0,1 2 3 1").map(line => LabeledPoint.parse(line)))

// Compute raw scores on the test set.
val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
  val prediction = model.predict(features)
  (prediction, label)
}

// Get evaluation metrics.
val metrics = new MulticlassMetrics(predictionAndLabels)
val accuracy = metrics.accuracy
println(s"Accuracy = $accuracy")

// Save and load model
// model.save(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")
// val sameModel = LogisticRegressionModel.load(spark.sparkContext, "target/tmp/scalaLogisticRegressionWithLBFGSModel")

// Export the model as PMML
model.toPMML(spark.sparkContext, "/tmp/xhl/data/test2")

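The trained model is saved to HDFS.

The PMML model file

Download the model to the local machine and rename it with an .xml extension.

As the file shows, the four features are named field_0, field_1, ... by default, and the target is named target.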

Timestamp in the PMML header: 2018-11-15T10:22:25
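To check the field names without reading the XML by eye, the data dictionary can be listed with the JDK's built-in DOM parser. The following is only a minimal sketch: the class name PmmlFieldLister and the SAMPLE_PMML string are made up for illustration, and the fragment is a heavily abbreviated stand-in, not the file exported above.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

final class PmmlFieldLister {

    // Hypothetical, abbreviated PMML fragment used only for illustration;
    // a real export also contains a header and the model itself.
    static final String SAMPLE_PMML =
        "<PMML version=\"4.2\" xmlns=\"http://www.dmg.org/PMML-4_2\">"
        + "<DataDictionary numberOfFields=\"5\">"
        + "<DataField name=\"field_0\" optype=\"continuous\" dataType=\"double\"/>"
        + "<DataField name=\"field_1\" optype=\"continuous\" dataType=\"double\"/>"
        + "<DataField name=\"field_2\" optype=\"continuous\" dataType=\"double\"/>"
        + "<DataField name=\"field_3\" optype=\"continuous\" dataType=\"double\"/>"
        + "<DataField name=\"target\" optype=\"categorical\" dataType=\"string\"/>"
        + "</DataDictionary>"
        + "</PMML>";

    /** Returns the name attribute of every DataField element, in document order. */
    static List<String> dataFieldNames(String pmmlXml) {
        try {
            DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
            factory.setNamespaceAware(true);
            Document doc = factory.newDocumentBuilder()
                .parse(new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8)));
            // "*" matches DataField elements regardless of the PMML namespace version
            NodeList nodes = doc.getElementsByTagNameNS("*", "DataField");
            List<String> names = new ArrayList<>();
            for (int i = 0; i < nodes.getLength(); i++) {
                names.add(((Element) nodes.item(i)).getAttribute("name"));
            }
            return names;
        } catch (Exception e) {
            throw new IllegalArgumentException("Could not parse PMML", e);
        }
    }

    public static void main(String[] args) {
        System.out.println(dataFieldNames(SAMPLE_PMML));
    }
}
```

The names returned this way are the keys that the prediction code has to supply as model input.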

Using the model in the interface

Add the Maven dependencies to the web project:

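<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.4.3</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.4.3</version>
</dependency>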

The interface code reads the PMML file directly and uses the model to make predictions:

package soundsystem;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PMMLDemo2 {

    // Parse the PMML file and build an evaluator from it
    private Evaluator loadPmml() {
        PMML pmml;
        try (InputStream inputStream = new FileInputStream("/Users/xingoo/Desktop/test2.xml")) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (Exception e) {
            // Fail fast: an evaluator cannot be built without a parsed model
            throw new IllegalStateException("Failed to load PMML model", e);
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        return modelEvaluatorFactory.newModelEvaluator(pmml);
    }

    private Object predict(Evaluator evaluator, int a, int b, int c, int d) {
        // Raw feature values keyed by the default field names in the PMML file
        Map<String, Integer> data = new HashMap<>();
        data.put("field_0", a);
        data.put("field_1", b);
        data.put("field_2", c);
        data.put("field_3", d);

        List<InputField> inputFields = evaluator.getInputFields();

        // Walk through the model's raw input fields, fetch each value from the
        // (profile) data, and prepare it as model input
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : inputFields) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // The model has a single target field; read its predicted distribution
        List<TargetField> targetFields = evaluator.getTargetFields();
        TargetField targetField = targetFields.get(0);
        FieldName targetFieldName = targetField.getName();
        ProbabilityDistribution target = (ProbabilityDistribution) results.get(targetFieldName);

        System.out.println(a + " " + b + " " + c + " " + d + ":" + target);
        return target;
    }

    public static void main(String[] args) {
        PMMLDemo2 demo = new PMMLDemo2();
        Evaluator model = demo.loadPmml();
        demo.predict(model, 2, 5, 6, 8);
        demo.predict(model, 7, 9, 3, 6);
        demo.predict(model, 1, 2, 3, 1);
        demo.predict(model, 2, 4, 1, 5);
    }
}
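In a web endpoint the four values would typically arrive as request parameters rather than as method arguments. As a rough sketch, assuming the client sends a plain query string keyed by the default field names (no URL-encoding handled), the feature map could be built as below. The class FeatureRequestParser and its parse method are hypothetical names and are not part of JPMML or of the demo above.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class FeatureRequestParser {

    /**
     * Parses a query string such as "field_0=2&field_1=5" into an ordered
     * map from field name to numeric value. Assumes plain, unencoded input.
     */
    static Map<String, Double> parse(String query) {
        Map<String, Double> features = new LinkedHashMap<>();
        for (String pair : query.split("&")) {
            int eq = pair.indexOf('=');
            if (eq <= 0) {
                throw new IllegalArgumentException("Malformed parameter: " + pair);
            }
            String name = pair.substring(0, eq);
            // Double.parseDouble throws NumberFormatException on non-numeric input
            double value = Double.parseDouble(pair.substring(eq + 1));
            features.put(name, value);
        }
        return features;
    }

    public static void main(String[] args) {
        System.out.println(parse("field_0=2&field_1=5&field_2=6&field_3=8"));
    }
}
```

A map built this way could stand in for the data map in predict, since the evaluator looks up each input value by field name.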

This produces the following output:

2 5 6 8:ProbabilityDistribution{result=1, probability_entries=[1=0.9999949538769296, 0=5.046123070395758E-6]}

7 9 3 6:ProbabilityDistribution{result=0, probability_entries=[1=1.1216598160542013E-9, 0=0.9999999988783402]}

1 2 3 1:ProbabilityDistribution{result=0, probability_entries=[1=2.363331367481431E-8, 0=0.9999999763666864]}

2 4 1 5:ProbabilityDistribution{result=1, probability_entries=[1=0.9999999831203591, 0=1.6879640907241367E-8]}

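Here, result is the final prediction of the LR (logistic regression) model, and the probability entries are the probabilities of the two classes in the binary classification.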
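The relationship between the two probabilities and the result can be illustrated with the standard binary logistic regression formula: the probability of class 1 is the sigmoid of the linear score, class 0 gets the remainder, and the result is the more probable class. The sketch below uses made-up weights and an intercept purely as an assumption for illustration; they are not the coefficients of the model trained above.

```java
final class BinaryLogisticSketch {

    /** Probability of class 1: sigmoid(intercept + w . x). */
    static double probabilityOfOne(double[] weights, double intercept, double[] x) {
        double margin = intercept;
        for (int i = 0; i < weights.length; i++) {
            margin += weights[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Predicted class: 1 when its probability exceeds 0.5, otherwise 0. */
    static int result(double probabilityOfOne) {
        return probabilityOfOne > 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        // Hypothetical coefficients, chosen only to demonstrate the calculation
        double[] weights = {0.5, -0.25, 0.1, 0.3};
        double intercept = -1.0;
        double[] x = {2, 5, 6, 8};

        double p1 = probabilityOfOne(weights, intercept, x);
        System.out.println("P(1) = " + p1 + ", P(0) = " + (1.0 - p1) + ", result = " + result(p1));
    }
}
```

The two probabilities always sum to 1, which matches the pairs of entries in the output above.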
