The requirement: take a model that the data-analytics team has already built in Python, and get it running on the Flink platform so that streaming data can be scored against it in real time.
Table of Contents
- Background
- Introduction to JPMML
- Environment Setup
- Installation
- Usage
- Steps
- Example: Decision-Tree Classification of the Iris Dataset
- Training the Model and Producing the PMML File
- Calling the PMML Model from a Java Project
- Summary
On the Flink platform, call a model already implemented in Python to do real-time prediction.
Install sklearn2pmml (0.14.0 or newer):

pip install sklearn2pmml -i https://pypi.tuna.tsinghua.edu.cn/simple/

The tests here are based on sklearn; for PMML packages targeting other frameworks, consult the author's GitHub examples.
Author's examples: https://github.com/jpmml/sklearn2pmml
A typical workflow can be summarized as follows:
1. Create a `PMMLPipeline` object, and populate it with pipeline steps as usual. Class `sklearn2pmml.pipeline.PMMLPipeline` extends class `sklearn.pipeline.Pipeline` with the following functionality:
- If the `PMMLPipeline.fit(X, y)` method is invoked with a `pandas.DataFrame` or `pandas.Series` object as the `X` argument, then its column names are used as feature names. Otherwise, feature names default to "x1", "x2", …, "x{number_of_features}".
- If the `PMMLPipeline.fit(X, y)` method is invoked with a `pandas.Series` object as the `y` argument, then its name is used as the target name (for supervised models). Otherwise, the target name defaults to "y".
2. Fit and validate the pipeline as usual.
3. Optionally, compute and embed verification data into the `PMMLPipeline` object by invoking the `PMMLPipeline.verify(X)` method with a small but representative subset of training data.
4. Convert the `PMMLPipeline` object to a PMML file in the local filesystem by invoking the utility method `sklearn2pmml.sklearn2pmml(pipeline, pmml_destination_path)`.
In short:
1. Create a `PMMLPipeline` object and set up its pipeline steps.
2. Fit the `PMMLPipeline` object.
3. Verify (warm up) the model.
4. Convert the `PMMLPipeline` object into a PMML file.
The author has two examples on GitHub: a decision tree classifying the Iris dataset and a logistic regression classifying the same dataset. I only demonstrate the decision-tree example here.
import pandas
from sklearn.datasets import load_iris
# The author's example on GitHub reads from a CSV file; here we load the
# dataset straight from sklearn instead:
# iris_df = pandas.read_csv("Iris.csv")
# iris_X = iris_df[iris_df.columns.difference(["Species"])]
# iris_y = iris_df["Species"]
# Load the Iris dataset bundled with sklearn
iris = load_iris()
# Build a DataFrame from feature_names
iris_df = pandas.DataFrame(iris.data, columns=iris.feature_names)
# Put the Iris targets into the DataFrame's label column
iris_df['label'] = iris.target
# Rename the DataFrame's columns
iris_df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# Split features and target
iris_X = iris_df[iris_df.columns.difference(["label"])]
iris_y = iris_df["label"]
from sklearn.tree import DecisionTreeClassifier
from sklearn2pmml.pipeline import PMMLPipeline
# For a classification model the step name is "classifier"; the author has
# defined fixed step names for the different model types. Each pipeline step
# is a (name, model object) tuple, and the names are not arbitrary: each name
# maps to a specific kind of transformer, e.g. "selector" for feature
# selection, "mapper" for feature preprocessing, "pca" for PCA, "classifier"
# for classifiers, "regressor" for regressors.
# See the GitHub README for details.
pipeline = PMMLPipeline([
    ("classifier", DecisionTreeClassifier())
])
# Train
pipeline.fit(iris_X, iris_y)
from sklearn2pmml import sklearn2pmml
# Export the model to a PMML file (raw string so the backslash is not
# treated as an escape sequence)
sklearn2pmml(pipeline, r"D:\DecisionTreeIris.pmml", with_repr = True)
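One detail that matters for the Java side later: pandas' `Index.difference()` returns the remaining columns sorted alphabetically, not in their original order, so the model's input fields end up ordered petal length, petal width, sepal length, sepal width. A minimal pure-Python sketch of that ordering (using `sorted()` to stand in for pandas' behavior):

```python
# pandas' Index.difference() sorts the result alphabetically by default, so
# the feature order the model sees is NOT the original column order.
# sorted() here stands in for that pandas behavior.
cols = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
feature_cols = sorted(c for c in cols if c != 'label')
print(feature_cols)
# → ['petal length', 'petal width', 'sepal length', 'sepal width']
```

This is why the test rows later in this post list petal length first.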
Note: the following warnings appear during execution and can safely be ignored:
D:\ITinstall\anaconda3\lib\subprocess.py:848: RuntimeWarning: line buffering (buffering=1) isn't supported in binary mode, the default buffer size will be used
self.stdout = io.open(c2pread, 'rb', bufsize)
D:\ITinstall\anaconda3\lib\subprocess.py:853: RuntimeWarning: line buffering (buffering=1) isn't supported in binary mode, the default buffer size will be used
self.stderr = io.open(errread, 'rb', bufsize)
Below are some further notes on constructing a PMMLPipeline.
Quite a few step names can be used, but how to build these (name, object) tuples is only shown through example code on GitHub, scattered across different corners of the project (mostly the README). Finding things is tedious (presumably you search bit by bit as each need arises, or simply ask the author in an issue), and there is no unified documentation. (The author may well feel the scattered usage notes are detailed enough and that everyone can hop between links to find their answers.)
To preprocess specific features you need a mapper (imports added here for completeness; X is the training DataFrame and params the RandomForest hyperparameters from the surrounding context):

from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import VarianceThreshold
from sklearn.preprocessing import StandardScaler
from sklearn2pmml.decoration import ContinuousDomain
from sklearn2pmml.feature_selection import SelectorProxy
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn_pandas import DataFrameMapper

mapper = DataFrameMapper([
    (X.columns.to_list(), [ContinuousDomain(with_data = False), StandardScaler()]),
])
classifier = RandomForestClassifier(**params)
pipeline = PMMLPipeline([
    ("mapper", mapper),
    ("selector", SelectorProxy(VarianceThreshold())),
    ("classifier", classifier),
])
DataFrameMapper takes a list of tuples: the first element names the column(s) to transform (one or several), and the second gives the processing steps. The example above applies standard scaling; one-hot encoding works the same way.
ContinuousDomain is one of this library's distinctive feature "decorators"; it decorates continuous features.
A decorator's main job is to handle invalid values, missing values, and outliers.
There are also decorators for ordinal features, categorical features, and temporal features; see the official documentation for details.
One real gotcha: the continuous-feature decorator learns from the training data and analyzes outliers, and at prediction time it forcibly treats out-of-range values as invalid, so prediction can fail with an error rejecting the feature. Setting with_data = False avoids this problem.
with_data controls whether the data are analyzed (for outliers) at training time.
The author seems to have designed it this way on purpose: in his view a model should not predict on values outside its accepted domain, so you are forced to deal with outliers and the like. To use a selector, wrap the feature_selection object in SelectorProxy.
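To make the with_data behavior concrete, here is a purely illustrative Python sketch of what such a domain decorator does conceptually (this is not sklearn2pmml's actual implementation, and the class name DomainSketch is made up):

```python
class DomainSketch:
    """Illustrative only: mimics the idea behind ContinuousDomain's check."""

    def __init__(self, with_data=True):
        self.with_data = with_data   # analyze training data (learn the domain)?
        self.low = self.high = None

    def fit(self, values):
        if self.with_data:
            self.low, self.high = min(values), max(values)
        return self

    def check(self, value):
        # With with_data=True, anything outside the trained range is rejected,
        # which is why predictions can fail on unseen outliers.
        if self.with_data and not (self.low <= value <= self.high):
            raise ValueError("value %s outside trained domain [%s, %s]"
                             % (value, self.low, self.high))
        return value

domain = DomainSketch(with_data=True).fit([4.3, 5.1, 7.9])
domain.check(5.0)                                    # inside the range: accepted
# domain.check(9.5) would raise ValueError
DomainSketch(with_data=False).fit([4.3, 5.1, 7.9]).check(9.5)  # no check: accepted
```

With with_data=True the sketch rejects 9.5 at "prediction" time, mirroring the rejected-feature errors described above.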
Maven dependencies (this is my complete Flink project, hence the fairly full set of Flink dependencies; take only what you need):
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.example</groupId>
    <artifactId>Flink_second</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <pmml.version>1.5.15</pmml.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-java</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-streaming-java_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-kafka-0.11_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.bahir</groupId>
            <artifactId>flink-connector-redis_2.11</artifactId>
            <version>1.0</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>5.1.44</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-connector-elasticsearch6_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-statebackend-rocksdb_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-planner_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-table-planner-blink_2.12</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-csv</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-clients_2.12</artifactId>
            <version>1.11.0</version>
        </dependency>
        <dependency>
            <groupId>org.apache.flink</groupId>
            <artifactId>flink-cep_2.11</artifactId>
            <version>1.11.1</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>${pmml.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator-extension</artifactId>
            <version>${pmml.version}</version>
        </dependency>
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.78</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <artifactId>maven-assembly-plugin</artifactId>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
Put the generated DecisionTreeIris.pmml file under the Java project's resources directory.
Model-loading utility class PMMLUtils:
package com.mym.jpmml.util;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.IOException;
import java.util.List;
public class PMMLUtils {

    public static void main(String[] args) throws IOException, JAXBException, SAXException {
        Evaluator evaluator = loadEvaluator("/DecisionTreeIris.pmml");
        // Printing input (x1, x2, .., xn) fields
        List<? extends InputField> inputFields = evaluator.getInputFields();
        System.out.println(inputFields);
    }

    /**
     * Loads a PMML model and builds an evaluator for it.
     *
     * @param pmmlFileName classpath resource name of the PMML file
     * @return the model evaluator
     * @throws JAXBException
     * @throws SAXException
     * @throws IOException
     */
    public static Evaluator loadEvaluator(String pmmlFileName) throws JAXBException, SAXException, IOException {
        Evaluator evaluator = new LoadingModelEvaluatorBuilder()
                .load(PMMLUtils.class.getResourceAsStream(pmmlFileName))
                .build();
        // Self-check & warm up the model
        evaluator.verify();
        System.out.println("Evaluator self-check & warm-up complete");
        return evaluator;
    }
}
Building the evaluator: a decision-tree Iris predictor
package com.mym.jpmml.predictor;
import com.mym.jpmml.util.PMMLUtils;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.Map;
public class DecisionTreeLrisPredictor implements Serializable {

    /* Key under which the result label name is passed in. Different models may
       use different label names, so this is left to the caller to set. */
    public String RESULT_LABEL_NAME = "labelName";

    private Evaluator evaluator;

    public DecisionTreeLrisPredictor() throws Exception {
        evaluator = PMMLUtils.loadEvaluator("/DecisionTreeIris.pmml");
    }

    public Object predict(Map<String, ?> inputRecord) {
        if (inputRecord == null) {
            throw new NullPointerException("Input features are null!");
        }
        // Wrap the arguments: convert the features into model-recognizable values
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputName = inputField.getName();
            Object rawValue = inputRecord.get(inputName.getValue());
            FieldValue inputValue = inputField.prepare(rawValue);
            arguments.put(inputName, inputValue);
        }
        // Evaluate: run the prediction
        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        // Decode the results
        Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
        // Look up the class label under the caller-supplied label name
        return resultRecord.get(inputRecord.get(RESULT_LABEL_NAME));
    }
}
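The last line of predict() is a double lookup that can be confusing: the caller stores the model's target field name under the key RESULT_LABEL_NAME, and that name is then used to pick the class label out of the decoded results. A small Python analogue of just that lookup logic (the helper name pick_label and the sample result values are made up for illustration):

```python
# Illustrative only (hypothetical helper name): mirrors the double lookup at
# the end of predict(). The caller stores the model's target field name under
# the key "labelName"; the predictor then uses that name to select the class
# label from the decoded results.
def pick_label(input_record, result_record, result_label_key="labelName"):
    label_name = input_record[result_label_key]  # e.g. "label"
    return result_record[label_name]

input_record = {"petal length": 1.4, "labelName": "label"}
result_record = {"label": 0, "probability(0)": 1.0}  # shape of decoded output; example values
print(pick_label(input_record, result_record))
# → 0
```

This indirection is what lets the same predictor class serve models whose target fields are named differently.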
[Optional] Define a feature data object for Flink processing
package com.mym.jpmml.bean;
import java.io.Serializable;
public class IrisModel implements Serializable {

    // petal length, petal width, sepal length, sepal width
    private double petalLength;
    private double petalWidth;
    private double sepalLength;
    private double sepalWidth;

    public IrisModel(double petalLength, double petalWidth, double sepalLength, double sepalWidth) {
        this.petalLength = petalLength;
        this.petalWidth = petalWidth;
        this.sepalLength = sepalLength;
        this.sepalWidth = sepalWidth;
    }

    public IrisModel() {
    }

    public double getPetalLength() {
        return petalLength;
    }

    public void setPetalLength(double petalLength) {
        this.petalLength = petalLength;
    }

    public double getPetalWidth() {
        return petalWidth;
    }

    public void setPetalWidth(double petalWidth) {
        this.petalWidth = petalWidth;
    }

    public double getSepalLength() {
        return sepalLength;
    }

    public void setSepalLength(double sepalLength) {
        this.sepalLength = sepalLength;
    }

    public double getSepalWidth() {
        return sepalWidth;
    }

    public void setSepalWidth(double sepalWidth) {
        this.sepalWidth = sepalWidth;
    }

    @Override
    public String toString() {
        return "IrisModel{" +
                "petalLength=" + petalLength +
                ", petalWidth=" + petalWidth +
                ", sepalLength=" + sepalLength +
                ", sepalWidth=" + sepalWidth +
                '}';
    }
}
Calling the model from a Flink operator
package com.mym.jpmml.flinkpredictor;

import com.mym.jpmml.bean.IrisModel;
import com.mym.jpmml.predictor.DecisionTreeLrisPredictor;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.Collector;

import java.util.HashMap;
import java.util.Map;

public class IrirsJPMMLFlinkInvokeTest {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        DataStream<String> inputStream = env.socketTextStream("localhost", 7777);
        DataStream<IrisModel> dataStream = inputStream.map(line -> {
            String[] fields = line.split(",");
            return new IrisModel(new Double(fields[0]), new Double(fields[1]),
                    new Double(fields[2]), new Double(fields[3]));
        });
        SingleOutputStreamOperator<Object> result = dataStream.flatMap(new PredictorFlatMapFunction());
        result.print("result");
        env.execute();
    }

    static class PredictorFlatMapFunction implements FlatMapFunction<IrisModel, Object> {

        private DecisionTreeLrisPredictor predictor;

        public PredictorFlatMapFunction() throws Exception {
            this.predictor = new DecisionTreeLrisPredictor();
        }

        @Override
        public void flatMap(IrisModel irisModel, Collector<Object> collector) throws Exception {
            Map<String, Object> inputRecord = new HashMap<>();
            // petal length, petal width, sepal length, sepal width
            // e.g. 1.4,0.2,5.1,3.5
            inputRecord.put("petal length", irisModel.getPetalLength());
            inputRecord.put("petal width", irisModel.getPetalWidth());
            inputRecord.put("sepal length", irisModel.getSepalLength());
            inputRecord.put("sepal width", irisModel.getSepalWidth());
            inputRecord.put(predictor.RESULT_LABEL_NAME, "label");
            Object predict = predictor.predict(inputRecord);
            collector.collect(predict);
        }
    }
}
I test here with netcat over a network socket; you can swap in any other input method. The test data are below (you can inspect sklearn's Iris dataset yourself).
Test
Test data and the corresponding labels: label is the classification result; what 0 and 1 stand for among the iris species is left for you to look up in the Iris dataset.
petal length | petal width | sepal length | sepal width | label
1.4,0.2,5.1,3.5  0
1.4,0.2,4.9,3.0  0
1.3,0.2,4.7,3.2  0
1.5,0.2,4.6,3.1  0
1.4,0.2,5.0,3.6  0
4.2,1.2,5.7,3.0  1
4.2,1.3,5.7,2.9  1
4.3,1.3,6.2,2.9  1
3.0,1.1,5.1,2.5  1
4.1,1.3,5.7,2.8  1
(Note the rows are in the model's alphabetical feature order: petal length, petal width, sepal length, sepal width.)
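To clarify how one such test line becomes the evaluator's input, here is a hedged Python mirror of what the Flink flatMap builds per line (the function name to_input_record is made up for illustration):

```python
# Python mirror of the Flink flatMap: split one CSV line (in the model's
# alphabetical feature order) into the feature map handed to the predictor.
FEATURES = ["petal length", "petal width", "sepal length", "sepal width"]

def to_input_record(line):
    values = [float(v) for v in line.strip().split(",")]
    record = dict(zip(FEATURES, values))
    record["labelName"] = "label"  # tells the predictor which result field holds the class
    return record

record = to_input_record("1.4,0.2,5.1,3.5")
# record["petal length"] == 1.4, record["sepal width"] == 3.5
```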
Test input
C:\Users\mym>nc -l -p 7777
1.4,0.2,5.1,3.5
5.7,2.8,4.1,1.3
4.1,1.3,5.7,2.8
5.7,2.8,4.1,1.3
4.1,1.3,5.7,2.8
1.4,0.2,5.1,3.5
Flink prediction output
result> 0
result> 2
result> 1
result> 2
result> 1
result> 0