Flink: Invoking a JPMML Machine-Learning Model

The requirement: take a model that the data-analysis team has already built in Python and run it on the Flink platform, so the model can be invoked in real time to process streaming data.

Table of Contents

  • Background
  • About JPMML
  • Environment Setup
    • Installation
  • Usage
    • Steps
    • Example: Decision-Tree Classification of the Iris Dataset
      • Train the Model and Export a PMML File
      • Calling the PMML Model from a Java Project
  • Summary

Background

Call an existing Python-built model from the Flink platform for real-time prediction.

  • Flink v1.11
  • Java 1.8
  • Python 3
  • JPMML

About JPMML

  • PMML: Predictive Model Markup Language.
  • An XML-based standard for describing and storing data-mining models; building on XML's hierarchical structure, it makes trained models portable across tools and platforms.
  • The J in JPMML stands for Java: a Java library for evaluating PMML models.
  • GitHub: https://github.com/jpmml/jpmml-evaluator
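To make the format concrete, here is a schematic, abridged sketch of what a PMML document for a decision tree looks like (element names follow the PMML schema; a real exported file contains more, e.g. a MiningSchema and the full tree):

```xml
<PMML xmlns="http://www.dmg.org/PMML-4_4" version="4.4">
  <DataDictionary>
    <DataField name="petal length" optype="continuous" dataType="double"/>
    <DataField name="label" optype="categorical" dataType="integer"/>
  </DataDictionary>
  <TreeModel functionName="classification">
    <Node>
      <True/>
      <Node score="0">
        <SimplePredicate field="petal length" operator="lessOrEqual" value="2.45"/>
      </Node>
      <!-- remaining splits omitted -->
    </Node>
  </TreeModel>
</PMML>
```

Because the model is just declarative XML like this, any runtime with a PMML evaluator (here, JPMML on the JVM) can score it without Python.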

Environment Setup

Installation

  • sklearn2pmml 0.14.0 or newer.
pip install sklearn2pmml -i https://pypi.tuna.tsinghua.edu.cn/simple/

Usage

The tests here are based on sklearn; for PMML converter packages for other frameworks, see the author's GitHub examples.

Author's examples: https://github.com/jpmml/sklearn2pmml

Steps

A typical workflow can be summarized as follows:

1.Create a PMMLPipeline object, and populate it with pipeline steps as usual. Class sklearn2pmml.pipeline.PMMLPipeline extends class sklearn.pipeline.Pipeline with the following functionality:

  • If the PMMLPipeline.fit(X, y) method is invoked with pandas.DataFrame or pandas.Series object as an X argument, then its column names are used as feature names. Otherwise, feature names default to “x1”, “x2”, …, “x{number_of_features}”.
  • If the PMMLPipeline.fit(X, y) method is invoked with pandas.Series object as an y argument, then its name is used as the target name (for supervised models). Otherwise, the target name defaults to “y”.

2.Fit and validate the pipeline as usual.

3.Optionally, compute and embed verification data into the PMMLPipeline object by invoking PMMLPipeline.verify(X) method with a small but representative subset of training data.

4.Convert the PMMLPipeline object to a PMML file in local filesystem by invoking utility method sklearn2pmml.sklearn2pmml(pipeline, pmml_destination_path).

  • The above is the author's original workflow; rather than translating it in full, the gist:
  • 1. Create a PMMLPipeline object and populate its pipeline steps.
  • 2. Fit and validate the pipeline as usual.
  • 3. [Optional] Pass a small but representative subset of the training data to PMMLPipeline.verify() to embed verification data (used to self-check and warm up the model).
  • 4. Convert the PMMLPipeline object to a PMML file.

Example: Decision-Tree Classification of the Iris Dataset

The author's GitHub has two examples, a decision tree and a logistic regression, both on the Iris dataset; only the decision tree is demonstrated here.

Train the Model and Export a PMML File

import pandas
from sklearn.datasets import load_iris

# The author's GitHub example reads from a CSV file; here we load the
# dataset from sklearn instead of reading a file:
# iris_df = pandas.read_csv("Iris.csv")
# iris_X = iris_df[iris_df.columns.difference(["Species"])]
# iris_y = iris_df["Species"]

# Load the Iris dataset bundled with sklearn
iris = load_iris()
# Build a DataFrame from feature_names
iris_df = pandas.DataFrame(iris.data, columns=iris.feature_names)
# Put the targets into a 'label' column of the DataFrame
iris_df['label'] = iris.target
# Rename the columns
iris_df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']
# Split features and target. Note: Index.difference returns the columns
# in sorted order, so the feature order becomes
# petal length, petal width, sepal length, sepal width.
iris_X = iris_df[iris_df.columns.difference(["label"])]
iris_y = iris_df["label"]

from sklearn.tree import DecisionTreeClassifier
from sklearn2pmml.pipeline import PMMLPipeline

# For a classification model the step name is "classifier". The pipeline
# takes (name, model) tuples, and the names are not arbitrary: each maps
# to a specific kind of transformer, e.g. "selector" for feature
# selection, "mapper" for feature preprocessing, "pca" for PCA,
# "classifier" for classifiers, "regressor" for regressors.
# See the GitHub README for details.
pipeline = PMMLPipeline([
    ("classifier", DecisionTreeClassifier())
])
# Fit
pipeline.fit(iris_X, iris_y)

from sklearn2pmml import sklearn2pmml

# Convert the model to a PMML file (raw string avoids backslash escapes)
sklearn2pmml(pipeline, r"D:\DecisionTreeIris.pmml", with_repr = True)

Note: warnings like the following may appear during execution and can be safely ignored:

D:\ITinstall\anaconda3\lib\subprocess.py:848: RuntimeWarning: line buffering (buffering=1) isn't supported in binary mode, the default buffer size will be used
  self.stdout = io.open(c2pread, 'rb', bufsize)
D:\ITinstall\anaconda3\lib\subprocess.py:853: RuntimeWarning: line buffering (buffering=1) isn't supported in binary mode, the default buffer size will be used
  self.stderr = io.open(errread, 'rb', bufsize)

Some further notes on constructing a PMMLPipeline:

  • Quite a few step names can be set, but how to build the (name, transformer) tuples is documented only through example snippets scattered across the project's README and issues; there is no unified reference, so expect some digging (or asking the author directly in an issue).

  • Preprocessing specific features requires a mapper:

  • # DataFrameMapper comes from sklearn_pandas; SelectorProxy from sklearn2pmml
    mapper = DataFrameMapper([
        (X.columns.to_list(), [ContinuousDomain(with_data = False), StandardScaler()]),
    ])
    classifier = RandomForestClassifier(**params)
    pipeline = PMMLPipeline([
        ("mapper", mapper),
        ("selector", SelectorProxy(VarianceThreshold())),
        ("classifier", classifier),
    ])
    
    
  • DataFrameMapper takes a list of tuples: the first element names the column(s) to transform (one or many), the second lists the processing steps. The example above applies standard scaling; one-hot encoding works the same way.

  • ContinuousDomain is one of this library's signature feature "decorators"; it decorates continuous features.

    The decorators' main job is handling invalid values, missing values, and outliers.
    There are also decorators for ordinal, categorical, and temporal features; see the official docs.
    One gotcha: the continuous-feature decorator learns the training data's value range and detects outliers, and at prediction time it treats out-of-range values as invalid, which can make the evaluator reject the input with an error. Setting with_data = False avoids this.
    with_data controls whether the training data is analyzed (for outliers) during fitting.
    This appears to be by design: the author's position is that a model should not predict on values outside its accepted range, forcing you to handle outliers explicitly.

  • To use a selector, wrap the feature_selection object in a SelectorProxy.

Calling the PMML Model from a Java Project

  • Maven dependencies (this is my full Flink project, so it carries a fairly complete set of Flink dependencies; take only what you need)

    
    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>

        <groupId>org.example</groupId>
        <artifactId>Flink_second</artifactId>
        <version>1.0-SNAPSHOT</version>

        <properties>
            <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
            <maven.compiler.source>1.8</maven.compiler.source>
            <maven.compiler.target>1.8</maven.compiler.target>
            <pmml.version>1.5.15</pmml.version>
        </properties>

        <dependencies>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-java</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-streaming-java_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-connector-kafka-0.11_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.bahir</groupId>
                <artifactId>flink-connector-redis_2.11</artifactId>
                <version>1.0</version>
            </dependency>
            <dependency>
                <groupId>mysql</groupId>
                <artifactId>mysql-connector-java</artifactId>
                <version>5.1.44</version>
            </dependency>

            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-connector-elasticsearch6_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-statebackend-rocksdb_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-table-planner_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-table-planner-blink_2.12</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-csv</artifactId>
                <version>1.11.1</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-clients_2.12</artifactId>
                <version>1.11.0</version>
            </dependency>
            <dependency>
                <groupId>org.apache.flink</groupId>
                <artifactId>flink-cep_2.11</artifactId>
                <version>1.11.1</version>
            </dependency>

            <!-- JPMML evaluator -->
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator</artifactId>
                <version>${pmml.version}</version>
            </dependency>
            <dependency>
                <groupId>org.jpmml</groupId>
                <artifactId>pmml-evaluator-extension</artifactId>
                <version>${pmml.version}</version>
            </dependency>

            <dependency>
                <groupId>com.alibaba</groupId>
                <artifactId>fastjson</artifactId>
                <version>1.2.78</version>
            </dependency>
        </dependencies>

        <build>
            <plugins>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <version>3.1</version>
                    <configuration>
                        <source>1.8</source>
                        <target>1.8</target>
                    </configuration>
                </plugin>
                <plugin>
                    <artifactId>maven-assembly-plugin</artifactId>
                    <configuration>
                        <descriptorRefs>
                            <descriptorRef>jar-with-dependencies</descriptorRef>
                        </descriptorRefs>
                    </configuration>
                </plugin>
            </plugins>
        </build>
    </project>
    
  • Put the generated DecisionTreeIris.pmml file into the Java project's resources directory

  • Model-loading utility class PMMLUtils

    package com.mym.jpmml.util;
    
    import org.jpmml.evaluator.Evaluator;
    import org.jpmml.evaluator.InputField;
    import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
    import org.xml.sax.SAXException;
    
    import javax.xml.bind.JAXBException;
    import java.io.IOException;
    import java.util.List;
    
    
    public class PMMLUtils {
        public static void main(String[] args) throws IOException, JAXBException, SAXException {
            Evaluator evaluator = loadEvaluator("/DecisionTreeIris.pmml");
            // Printing input (x1, x2, .., xn) fields
            List<? extends InputField> inputFields = evaluator.getInputFields();
            System.out.println(inputFields);
        }
    
        /**
     * Loads a PMML model from the classpath and warms it up.
         *
         * @param pmmlFileName
         * @return
         * @throws JAXBException
         * @throws SAXException
         * @throws IOException
         */
        public static Evaluator loadEvaluator(String pmmlFileName) throws JAXBException, SAXException, IOException {
            Evaluator evaluator = new LoadingModelEvaluatorBuilder()
                    .load(PMMLUtils.class.getResourceAsStream(pmmlFileName))
                    .build();
            // Self-check & warm up the model
            evaluator.verify();
            System.out.println("Evaluator self-check & warm-up done");
            return evaluator;
        }
    }
    
  • Build the evaluator: a decision-tree Iris predictor

    package com.mym.jpmml.predictor;
    
    import com.mym.jpmml.util.PMMLUtils;
    import org.dmg.pmml.FieldName;
    import org.jpmml.evaluator.Evaluator;
    import org.jpmml.evaluator.EvaluatorUtil;
    import org.jpmml.evaluator.FieldValue;
    import org.jpmml.evaluator.InputField;
    
    import java.io.Serializable;
    import java.util.LinkedHashMap;
    import java.util.Map;
    
    public class DecisionTreeLrisPredictor implements Serializable {
    
        /* Key under which the caller supplies the result label name;
           different models may use different label names, so it is left
           configurable by the caller */
        public String RESULT_LABEL_NAME = "labelName";

        private Evaluator evaluator;

        public DecisionTreeLrisPredictor() throws Exception {
            evaluator = PMMLUtils.loadEvaluator("/DecisionTreeIris.pmml");
        }

        public Object predict(Map<String, ?> inputRecord) {
            if (inputRecord == null) {
                throw new NullPointerException("Input features are null!");
            }
            // Wrap the arguments: convert raw features into values the model understands
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (InputField inputField : evaluator.getInputFields()) {
                FieldName inputName = inputField.getName();
                Object rawValue = inputRecord.get(inputName.getValue());
                FieldValue inputValue = inputField.prepare(rawValue);
                arguments.put(inputName, inputValue);
            }
            // Evaluate: run the prediction
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            // Decode the results into plain Java values
            Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
            // Fetch the classification result under the caller-supplied label name
            return resultRecord.get(inputRecord.get(RESULT_LABEL_NAME));
        }
    }
    
    
  • [Optional] Define a feature bean for the Flink job

    package com.mym.jpmml.bean;
    
    
    import java.io.Serializable;
    
    public class IrisModel implements Serializable {
        //petal length、petal width、sepal length、sepal width
        private double petalLength;
        private double petalWidth;
        private double sepalLength;
        private double sepalWidth;
    
        public IrisModel(double petalLength, double petalWidth, double sepalLength, double sepalWidth) {
            this.petalLength = petalLength;
            this.petalWidth = petalWidth;
            this.sepalLength = sepalLength;
            this.sepalWidth = sepalWidth;
        }
    
        public IrisModel() {
        }
    
        public double getPetalLength() {
            return petalLength;
        }
    
        public void setPetalLength(double petalLength) {
            this.petalLength = petalLength;
        }
    
        public double getPetalWidth() {
            return petalWidth;
        }
    
        public void setPetalWidth(double petalWidth) {
            this.petalWidth = petalWidth;
        }
    
        public double getSepalLength() {
            return sepalLength;
        }
    
        public void setSepalLength(double sepalLength) {
            this.sepalLength = sepalLength;
        }
    
        public double getSepalWidth() {
            return sepalWidth;
        }
    
        public void setSepalWidth(double sepalWidth) {
            this.sepalWidth = sepalWidth;
        }
    
        @Override
        public String toString() {
            return "IrisModel{" +
                    "petalLength=" + petalLength +
                    ", petalWidth=" + petalWidth +
                    ", sepalLength=" + sepalLength +
                    ", sepalWidth=" + sepalWidth +
                    '}';
        }
    }
    
    
  • Invoke the model from a Flink operator

    package com.mym.jpmml.flinkpredictor;

    import com.mym.jpmml.bean.IrisModel;
    import com.mym.jpmml.predictor.DecisionTreeLrisPredictor;
    import org.apache.flink.api.common.functions.FlatMapFunction;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.util.Collector;

    import java.util.HashMap;
    import java.util.Map;

    public class IrirsJPMMLFlinkInvokeTest {

        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            env.setParallelism(1);
            DataStream<String> inputStream = env.socketTextStream("localhost", 7777);
            DataStream<IrisModel> dataStream = inputStream.map(line -> {
                String[] fields = line.split(",");
                return new IrisModel(Double.parseDouble(fields[0]), Double.parseDouble(fields[1]),
                        Double.parseDouble(fields[2]), Double.parseDouble(fields[3]));
            });
            SingleOutputStreamOperator<Object> result = dataStream.flatMap(new PredictorFlatMapFunction());
            result.print("result");
            env.execute();
        }

        static class PredictorFlatMapFunction implements FlatMapFunction<IrisModel, Object> {

            private DecisionTreeLrisPredictor predictor;

            public PredictorFlatMapFunction() throws Exception {
                this.predictor = new DecisionTreeLrisPredictor();
            }

            @Override
            public void flatMap(IrisModel irisModel, Collector<Object> collector) throws Exception {
                Map<String, Object> inputRecord = new HashMap<>();
                // Input order: petal length, petal width, sepal length, sepal width
                // e.g. 1.4,0.2,5.1,3.5
                inputRecord.put("petal length", irisModel.getPetalLength());
                inputRecord.put("petal width", irisModel.getPetalWidth());
                inputRecord.put("sepal length", irisModel.getSepalLength());
                inputRecord.put("sepal width", irisModel.getSepalWidth());
                inputRecord.put(predictor.RESULT_LABEL_NAME, "label");
                Object predict = predictor.predict(inputRecord);
                collector.collect(predict);
            }
        }
    }
    

    Here I test through a netcat socket; any other source works just as well. The test data below can be cross-checked against sklearn's Iris dataset.

  • Test

    Test data and expected labels. In sklearn's Iris dataset the target encodes the species: 0 = setosa, 1 = versicolor, 2 = virginica. The columns follow the Flink job's input order: petal length, petal width, sepal length, sepal width.

    petal length, petal width, sepal length, sepal width | label
    1.4,0.2,5.1,3.5  0
    1.4,0.2,4.9,3.0  0
    1.3,0.2,4.7,3.2  0
    1.5,0.2,4.6,3.1  0
    1.4,0.2,5.0,3.6  0
    4.2,1.2,5.7,3.0  1
    4.2,1.3,5.7,2.9  1
    4.3,1.3,6.2,2.9  1
    3.0,1.1,5.1,2.5  1
    4.1,1.3,5.7,2.8  1
    

    Test input:

    C:\Users\mym>nc -l -p 7777
    1.4,0.2,5.1,3.5
    5.7,2.8,4.1,1.3
    4.1,1.3,5.7,2.8
    5.7,2.8,4.1,1.3
    4.1,1.3,5.7,2.8
    1.4,0.2,5.1,3.5
    

    Flink prediction output:

    result> 0
    result> 2
    result> 1
    result> 2
    result> 1
    result> 0
    

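As a sanity check, the first test row can be compared against the sklearn model directly (a sketch; note that sklearn's load_iris feature order is sepal length, sepal width, petal length, petal width, whereas the socket rows above are petal length, petal width, sepal length, sepal width):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier().fit(iris.data, iris.target)

# Socket row 1.4,0.2,5.1,3.5 reordered into sklearn's column order:
row = [[5.1, 3.5, 1.4, 0.2]]
print(int(clf.predict(row)[0]))  # 0, matching the first "result> 0" above
```

Field order only matters at the parsing stage; once the values are in the input map, the PMML evaluator looks features up by name.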
Summary

  • The model predicts one record at a time, a natural fit for Flink's per-event stream processing.
  • The generated .pmml file can be large; if the packaged jar ends up too big, consider compressing the model.
  • JPMML covers most mainstream ML stacks: sklearn, Spark ML, R, LightGBM, XGBoost, TensorFlow, and more.
