springboot 实现机器学习_机器学习模型部署:pmml+springboot

工作中,大部分都是使用Python和机器学习库进行建模,但是线上环境基本上都是Java开发的,所以如何将我们训练好的模型部署到线上始终是一个问题。PMML就是针对这一问题的解决办法。

1.PMML概述

PMML全称预测模型标记语言(Predictive Model Markup Language),利用XML描述和存储数据挖掘模型,是一个已经被W3C所接受的标准。MML是一种基于XML的语言,用来定义预测模型。通过使用标准的XML解析器对PMML进行解析,应用程序能够决定模型输入和输出的数据类型,模型详细的格式,并且按照标准的数据挖掘术语来解释模型的结果。PMML提供了一个灵活机制来定义预测模型的模式,同时支持涉及多个预测模型的模型选择和模型平衡(model averaging)。PMML既可以呈现用于从数据中了解模型的统计技术(如人工神经网络和决策树),也可以呈现原始输入数据的预处理以及模型输出的后处理。

PMML文件的结构遵从了用于构建预测解决方案的常用步骤,包括:

1.数据词典,可以识别和定义哪些输入数据字段对于解决眼前的问题是最有用的,包括数值、顺序和分类字段。

2.挖掘架构,定义了处理缺少值和离群值的策略。这非常有用,因为通常情况,当将模型应用于实践时,所需的输入数据字段可能为空或者被误呈现。

3.数据转换,定义了将原始输入数据预处理至派生字段所需的计算。派生字段(有时也称为特征检测器)对输入字段进行合并或修改,以获取更多相关信息。例如,为了预测停车所需的制动压力,一个预测模型可能将室外温度和水的存在(是否在下雨?)作为原始数据。派生字段可能会将这两个字段结合起来,以探测路上是否结冰。然后结冰字段被作为模型的直接输入来预测停车所需的制动压力。

4.模型定义,定义了用于构建模型的结构和参数。PMML涵盖了多种统计技术。例如,为了呈现一个神经网络,它定义了所有的神经层和神经元之间的连接权重。对于一个决策树来说,它定义了所有树节点及简单和复合谓语。

5.输出,定义了预期模型输出。对于一个分类任务来说,输出可以包括预测类及与所有可能类相关的概率。

6.目标,定义了应用于模型输出的后处理步骤。对于一个回归任务来说,此步骤支持将输出转变为人们很容易就可以理解的分数(预测结果)。

7.模型解释,定义了将测试数据传递至模型时获得的性能度量标准。这些度量标准包括字段相关性、混淆矩阵、增益图及接收者操作特征(ROC)曲线图。

8.模型验证,定义了一个包含输入数据记录和预期模型输出的示例集。这是非常重要的一个步骤,因为在应用程序之间移动模型时,该模型需要通过匹配测试。这样就可以确保,在呈现相同的输入时,新系统可以生成与旧系统同样的输出。如果实际情况是这样的话,一个模型将被认为经过了验证,且随时可用于实践。

2.PMML深度解析

如上所述,PMML的结构反映了常用于创建预测解决方案的八大步骤,从在“数据词典”步骤中定义原始输入数据字段到在“模型验证”步骤中验证模型是否得到正确部署。

表1展示了一个含有三个字段的解决方案中PMML元素DataDictionary的定义,这三个字段是:数值型输入字段Value、分类输入字段Element和数值型输出字段Risk。

表1 DataDictionary元素

对于字段Value,范围从负无穷大到60的值是有效值。高于60的值被定义为无效值。考虑到字段Element是分类的,有效值被明确地列出。如果该特定字段的数据提要包含元素Iron,将该元素作为无效值处理。

上图展示了神经网络模型的图形表示,其中输入层包含3个神经元,隐藏层包含2个神经元,输出层包含1个神经元。如您所期望的,PMML可以完全呈现这样一个结构。

表2展示了隐藏层及其神经元以及输入层(0、1和2)和隐藏层(3和4)中神经元的连接权重的定义。

表2 在PMML中定义神经层及其神经元

PMML不是一件艰难的事。其复杂程度反映了其呈现的建模技术的复杂程度。事实上,它揭开了许多人感到神秘的预测分析的秘密和黑匣子。利用PMML,任何预测解决方案都可以采用同样的顺序用同一种语言元素呈现。

3.SAX解析xml文件

SAX(simple API for XML)是一种XML解析的替代方法。相比于DOM,SAX是一种速度更快,更有效的方法。它逐行扫描文档,一边扫描一边解析。而且相比于DOM,SAX可以在解析文档的任意时刻停止解析,但任何事物都有其相反的一面,对于SAX来说就是操作复杂。

SAX是事件驱动型XML解析的一个标准接口,对文档进行顺序扫描,当扫描到文档(document)开始与结束、元素(element)开始与结束、文档(document)结束等地方时通知事件处理函数,由事件处理函数做相应动作,然后继续同样的扫描,直至文档结束。

大多数SAX都会产生以下类型的事件:

1.在文档的开始和结束时触发文档处理事件。

2.在文档内每一XML元素接受解析的前后触发元素事件。

3.任何元数据通常由单独的事件处理。

4.在处理文档的DTD或Schema时产生DTD或Schema事件。

5.产生错误事件用来通知主机应用程序解析错误。

JAXB可以把xml对象转化为java对象,也可以把java对象转化为xml对象。这时候我们就得知道它的两个转化方法。一个是unmarshal(),一个是marshal()。marshal()是把java对象序列化为xml对象的一个过程。unmarshal()是把xml对象反序列化为我们需要的java对象的方法。

4.sklearn训练模型

这里训练一个很简单的分类模型,然后将模型导出为pmml文件。

安装sklearn2pmml:pip install sklearn2pmml。

# 导包

import pandas as pd

from sklearn.datasets import load_iris

from sklearn2pmml.pipeline import PMMLPipeline

from sklearn2pmml import sklearn2pmml

from sklearn.linear_model import LogisticRegression

# 加载数据

data = load_iris()

x = pd.DataFrame(data.data, columns=['slength','swidth','plength','pwidth'])

y = pd.DataFrame(data.target, columns=['y'])

# 创建pipeline和训练模型

lr = LogisticRegression(random_state=100)

pipeline = PMMLPipeline([ ('lr', lr) ])

pipeline.fit(X=x, y=y.values.ravel())

pipeline.verify(x.sample(n = 15, random_state=100))

# 导出pmml文件

sklearn2pmml(pipeline, r'testpmml.pmml', with_repr=True)

导出的pmml文件具体内容如下:通过DataDictionary可以看出,该分类问题为3分类问题,4个特征分别为slength、swidth、plength和pwidth。

5. idea创建springboot项目

5.1 创建新项目

5.2 在pom.xml中添加dependency

jpmml是java解析pmml的类库,fastjson是阿里巴巴开源的解析json是类库,org.glassfish.jaxb是jpmml依赖的一个类库,以前是标准类库,后来被移除了,要求手动添加。

第一次添加dependency之后,需要下载这些依赖库,点击刷新按钮,等依赖库下载完成即可。

org.jpmml

pmml-evaluator

1.4.5

org.jpmml

pmml-evaluator-extension

1.4.5

com.alibaba

fastjson

1.2.54

org.glassfish.jaxb

jaxb-runtime

2.3.0

com.jayway.jsonpath

json-path

5.3 新建一个PmmlPredict.java类

PmmlPredict.java类要做的只有两件事,1是在springboot启动时加载pmml并初始化模型,2是定义一个预测函数,方便http调用,然后返回预测值。

package pmmldemo.pmml;

import com.alibaba.fastjson.JSONObject;

import org.dmg.pmml.FieldName;

import org.jpmml.evaluator.*;

import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;

import java.io.File;

import java.io.IOException;

import java.util.Arrays;

import java.util.LinkedHashMap;

import java.util.List;

import java.util.Map;

public class PmmlPredict {

//将模型定义为全局变量,springboot启动时加载pmml并初始化

public static Evaluator evaluator;

//模型初始化方法,springboot启动时执行该方法,然后初始化上面的Evaluator

public static void initModel() throws IOException, SAXException, JAXBException {

File file = new File("/Users/a5210/testpmml.pmml");

evaluator = new LoadingModelEvaluatorBuilder().load(file).build();

evaluator.verify();

}

//定义一个实用函数,就是python中的print函数,没别的意思

public static void print(Object... args){

Arrays.stream(args).forEach(System.out::print);

System.out.println("");

}

// 定义预测函数,htt请求该函数,然后返回预测值

// 传入的参数是一个json,字段要求和模型的字段保持一致

public static Integer predict(JSONObject feature){

// 获取模型定义的特征

List extends InputField> inputFields = evaluator.getInputFields();

print("模型的特征是:", inputFields);

// 获取模型定义的目标名称

List extends TargetField> targetFields = evaluator.getTargetFields();

print("目标字段是:",targetFields);

// 示例传进来的json数据

// String json = "{\"slength\": 1.0, \"swidth\": 1.0, \"plength\": 1.0, \"pwidth\": 1.0}";

// JSONObject feature = JSONObject.parseObject(json);

// 将json转成evaluator要求的map格式,其实就是对key和value再做一层包装而已

Map arguments = new LinkedHashMap<>();

for(InputField inputField: inputFields){

FieldName inputName = inputField.getName();

String name = inputName.getValue();

Object rawValue = feature.getDoubleValue(name);

FieldValue inputValue = inputField.prepare(rawValue);

arguments.put(inputName, inputValue);

}

// 得到特征数据后就是预测了

Map results = evaluator.evaluate(arguments);

Map resultRecord = EvaluatorUtil.decode(results);

Integer y = (Integer) resultRecord.get("y");

// 打印结果会更加了解其中的封装过程

print("预测结果:");

print(results);

print(resultRecord);

print(y);

return y;

}

}

5.4 新建一个InitializingModel.java类

InitializingModel.java.java类负责初始化模型,并对外提供接口路径。

package pmmldemo.pmml;

import com.alibaba.fastjson.JSONObject;

import org.springframework.web.bind.annotation.*;

@RestController

public class InitializingModel {

// 定义index页,也是为了测试网络是否通畅

@RequestMapping("/")

public String index() {

return "hello spring for test";

}

// 定义一个接口,从http中接受RequestBody中的字符串,这是一个json的字符串,用fastjson解析成json后直接调用预测函数PmmlPredict.predict进行预测

@RequestMapping(value="/predict", method=RequestMethod.POST, produces="application/json;charset=UTF-8")

public @ResponseBody

String getModel(@RequestBody String feature) {

// 将字符串解析成json

JSONObject json = JSONObject.parseObject(feature);

// 调用PmmlPredict.initModel()

try {

PmmlPredict.initModel();

} catch (Exception e) {

e.printStackTrace();

}

// 预测

double y = PmmlPredict.predict(json);

// 返回

return String.valueOf(y);

}

}

5.5 启动springboot

6.使用postman进行测试

查看预测函数的打印输出,如下所示,可以看到,对于分类模型,会输出各个类别的预测概率,然后返回概率最大的一个类作为预测的概率。

模型的特征是:[InputField{name=slength, displayName=null, dataType=double, opType=continuous}, InputField{name=swidth, displayName=null, dataType=double, opType=continuous}, InputField{name=plength, displayName=null, dataType=double, opType=continuous}, InputField{name=pwidth, displayName=null, dataType=double, opType=continuous}]

目标字段是:[TargetField{name=y, displayName=null, dataType=integer, opType=categorical}]

预测结果:

{y=ProbabilityDistribution{result=2, probability_entries=[0=0.20617333796640783, 1=0.24651265148432944, 2=0.5473140105492628]}, probability(0)=0.20617333796640783, probability(1)=0.24651265148432944, probability(2)=0.5473140105492628}

{y=2, probability(0)=0.20617333796640783, probability(1)=0.24651265148432944, probability(2)=0.5473140105492628}

2

也可以使用python代码发送数据,并返回结果。

import requests

url = 'http://127.0.0.1:8080/predict'

data = {"slength": 1.0, "swidth": 1.0,"plength": 1.0, "pwidth": 1.0}

r = requests.post(url, json=data)

print(r.text)

7.aiohttp并发测试

为解决RuntimeError: This event loop is already running问题,先安装nest_asyncio:pip install nest_asyncio。

import asyncio

from aiohttp import ClientSession

import time

import nest_asyncio

nest_asyncio.apply()

async def hello():

url = 'http://127.0.0.1:8080/predict'

data = {"slength": 1.0, "swidth": 1.0,"plength": 1.0, "pwidth": 1.0}

async with ClientSession() as session:

async with session.post(url=url, json=data) as response:

response = await response.read()

# print(response)

return response

# 设置并发数量

tasks = [asyncio.ensure_future(hello()) for _ in range(1000)]

start = time.clock()

loop = asyncio.get_event_loop()

loop.run_until_complete(asyncio.wait(tasks))

end = time.clock()

print('全部请求耗时:%.4f 秒'%(end-start))

8.PMML总结与思考

PMML的确是跨平台的利器,但也有如下缺点:

第一个就是PMML为了满足跨平台,牺牲了很多平台独有的优化,所以很多时候我们用算法库自己的保存模型的API得到的模型文件,要比生成的PMML模型文件小很多。同时PMML文件加载速度也比算法库自己独有格式的模型文件加载慢很多。

第二个就是PMML加载得到的模型和算法库自己独有的模型相比,预测会有一点点的偏差,当然这个偏差并不大。比如某一个样本,用sklearn的决策树模型预测为类别1,但是如果我们把这个决策树落盘为一个PMML文件,并用JAVA加载后,继续预测刚才这个样本,有较小的概率出现预测的结果不为类别1。

第三个就是对于超大模型,比如大规模的集成学习模型,比如xgboost,随机森林,或者tensorflow,生成的PMML文件很容易得到几个G,甚至上T,这时使用PMML文件加载预测速度会非常慢,此时推荐为模型建立一个专有的环境,就没有必要去考虑跨平台了。

此外,对于TensorFlow,不推荐使用PMML的方式来跨平台。可能的方法一是TensorFlow serving,自己搭建预测服务,但是会稍有些复杂。另一个方法就是将模型保存为TensorFlow的模型文件,并用TensorFlow独有的JAVA库加载来做预测。

9.补充参考

你可能感兴趣的:(springboot,实现机器学习)