flink1.12以后感觉真的香的一批,实时计算个人的使用感觉是比spark 的structured streaming要舒服的多。但是奈何还没有spark完善,尤其是在1.8以后,flinkML基本就处于不可用的状态(不知道是不是我没有找对文档,怎么最新的版本只有框架,没有模型啊喂)。
不过最后终于找到了jpmml的这个解决方案。简单来讲,pmml就是一个机器学习模型的中间格式,文件里面是用xml描述的。然后我们可以使用官方提供的sklearn2pmml库,在python中使用skearn训练好模型保存成.pmml文件,然后再java中使用jpmml读取文件来进行预测。
预言模型标记语言(Predictive Model Markup Language,PMML)是一种利用XML描述和存储数据挖掘模型的标准语言,它依托XML本身特有的数据分层思想和应用模式,实现了数据挖掘中模型的可移植性。
jpmml作者的github
里面也有官方的使用案例,内容还挺丰富。
jpmml-sklearn
sklearn2pmml
作者提供的对依赖包的要求,pip嗯安装就行了。
自己使用的时候sklearn是用conda安装过了,另外两个好像只有pip安装,这里设置了清华镜像源,速度提升很客观。
好像安装sklearn2pmml时会自动安装sklearn-pandas?应该是有依赖关系吧。
scikit-learn
0.16.0 or newer.sklearn-pandas
0.0.10 or newer.sklearn2pmml
0.14.0 or newer.pip install sklearn-pandas -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install sklearn2pmml -i https://pypi.tuna.tsinghua.edu.cn/simple/
创建模型不是乱创建
创建模型要使用sklearn2pmml为我们提供的工作流(pipeline)
工作流内需要存入二元组,(名称,模型对象)
二元组的名称也不是乱指定的,每个名称都是对应特定功能的transformer的,像"selector"对应特征选择,“mapper”对应特征预处理,”pca“对应pca,”classifier“对应分类器,”regressor“对应回归器。
乱创建虽然在python可能可以运行,但是生成pmml文件时可能会出错。
吐槽:能设置的名称其实不少,但是关于怎么设置这些二元组,作者都是在github上使用示例代码给出的,挺多使用方法分散在项目的不同角落(主要是README),找起来还挺费劲(估计都是用到了才会仔细一点一点搜,要不就在issue直接问作者了),而且也没统一的文档什么的。(可能作者觉得自己写的那些使用说明很详细,大家都能在各种链接之间跳来跳去找到问题的答案)
使用方法大多在项目的README里面可以找到,下面只演示我之前做过的内容。
from sklearn2pmml.pipeline import PMMLPipeline
from sklearn2pmml.decoration import ContinuousDomain
from sklearn2pmml import sklearn2pmml,SelectorProxy
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold
mapper = mapper = DataFrameMapper([
(X.columns.to_list(), [ContinuousDomain(with_data = False),StandardScaler()]),
])
params = {"n_estimators":88,"random_state":420}
classifier = RandomForestClassifier(**params)
pipeline = PMMLPipeline([
("mapper",mapper),
("selector", SelectorProxy(VarianceThreshold())),
("classifier", classifier),
])
要对指定特征就行预处理需要用到mapper
DataFrameMapper中传入二元组列表,前面是指定的列名,可以是多个,后面是处理方式。上面演示的是标准缩放,也可以进行行独热编码。
ContinuousDomain是这个库特色的特征装饰器,这个是对连续型特征进行装饰
with_data = False
可以避免这个问题。
使用selector需要使用SelectorProxy对feature_selection下的对象进行包裹。
训练模型主要需要进行两个操作,一个是fit训练,另一个是verify验证
active_fields
这个字段无法被激活,进行verify会不通过。然后就是之前提到的,这个库对每个特征进行自动的分析的功能无法进行。
pipeline.fit(X,y)
pipeline.verify(X.sample(15))
from sklearn2pmml import sklearn2pmml
sklearn2pmml(pipeline, "StayAlertRFC.pmml", with_repr = True)
github上的回归任务演示代码
这里推荐使用maven项目,里面pmml的相关依赖有这些。
<dependency>
<groupId>org.jpmmlgroupId>
<artifactId>pmml-evaluatorartifactId>
<version>${pmml.version}version>
dependency>
<dependency>
<groupId>org.jpmmlgroupId>
<artifactId>pmml-evaluator-extensionartifactId>
<version>${pmml.version}version>
dependency>
生成的模型可以放在项目的资源路径下,方便之后打包。
package xyz.hyhy.stayalert.flink.utils;
import com.alibaba.fastjson.JSONObject;
import lombok.extern.slf4j.Slf4j;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
import org.xml.sax.SAXException;
import javax.xml.bind.JAXBException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
@Slf4j
public class PMMLUtils {
public static void main(String[] args) throws IOException, JAXBException, SAXException {
Evaluator evaluator = loadEvaluator("/StayAlertRFC.pmml");
// Printing input (x1, x2, .., xn) fields
List<? extends InputField> inputFields = evaluator.getInputFields();
Map<String, Object> obj2 = JSONObject.parseObject("{\"V11\":33.7824}");
Double d = (Double) obj2.get("V11");
inputFields.get(3).prepare(d);
}
/**
* 载入PMML模型的方法
*
* @param pmmlFileName
* @return
* @throws JAXBException
* @throws SAXException
* @throws IOException
*/
public static Evaluator loadEvaluator(String pmmlFileName) throws JAXBException, SAXException, IOException {
Evaluator evaluator = new LoadingModelEvaluatorBuilder()
.load(PMMLUtils.class.getResourceAsStream(pmmlFileName))
.build();
evaluator.verify(); //自校验——预热模型
log.info("StayAlert分类评估器自校验&预热完成");
return evaluator;
}
}
这里使用LoadingModelEvaluatorBuilder载入模型,注意load()方法可以传入File类型也可以传入InputStream类型,这里一定要使用PMMLUtils.class.getResourceAsStream(pmmlFileName)来获取文件流传入参数,使用getResource在idea上可能可以用,但是打包部署到linux上就不行啦。
读取文件Evaluator evaluator = loadEvaluator("/StayAlertRFC.pmml");
/
package xyz.hyhy.stayalert.flink.prediction;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.xml.sax.SAXException;
import xyz.hyhy.stayalert.flink.utils.PMMLUtils;
import javax.xml.bind.JAXBException;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
public class StayAlertPredictor {
private Evaluator evaluator;
private List<InputField> inputFields;
public StayAlertPredictor() throws IOException, JAXBException, SAXException {
evaluator = PMMLUtils.loadEvaluator("/LightStayAlertRFC.pmml");
inputFields = evaluator.getInputFields();
}
public Boolean predict(Map<String, ?> inputRecord) {
if (inputRecord == null) {
throw new NullPointerException("预测程序不能输入空的记录");
}
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
// 从数据源模式到PMML模式逐字段映射记录
for (InputField inputField : inputFields) {
FieldName inputName = inputField.getName();
Object rawValue = inputRecord.get(inputName.getValue());
Double doubleValue = Double.parseDouble(rawValue.toString());
// 将任意用户提供的值转换为已知的PMML值
FieldValue inputValue = inputField.prepare(doubleValue);
arguments.put(inputName, inputValue);
}
// 用已知的特征来评估模型
Map<FieldName, ?> results = evaluator.evaluate(arguments);
// 解耦结果来自jpmml-evaluator运行时环境
Map<String, ?> resultRecord = EvaluatorUtil.decodeAll(results);
//获取并返回预测结果
Integer isAlert = (Integer) resultRecord.get("IsAlert");
return isAlert == 1;
}
}
package xyz.hyhy.stayalert.flink.task;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.util.Collector;
import org.xml.sax.SAXException;
import xyz.hyhy.stayalert.flink.pojo.UserDataPOJO;
import xyz.hyhy.stayalert.flink.prediction.StayAlertPredictor;
import javax.xml.bind.JAXBException;
import java.io.IOException;
public class StayAlertPredictTask {
private static StayAlertPredictor predictor;
static {
try {
predictor = new StayAlertPredictor();
} catch (IOException e) {
e.printStackTrace();
} catch (JAXBException e) {
e.printStackTrace();
} catch (SAXException e) {
e.printStackTrace();
}
}
private StayAlertPredictTask() {
}
public static SingleOutputStreamOperator<UserDataPOJO> predict(DataStream<UserDataPOJO> ds) {
return ds.flatMap(new FlatMapFunction<UserDataPOJO, UserDataPOJO>() {
@Override
public void flatMap(UserDataPOJO userDataPOJO,
Collector<UserDataPOJO> collector) throws Exception {
try {
//判断是否分心
boolean isAlert = predictor.predict(userDataPOJO.getDeviceFeature());
userDataPOJO.setIsAlert(isAlert);
collector.collect(userDataPOJO);
userDataPOJO.setIsAlert(null);
} catch (Exception e) {
e.printStackTrace();
}
}
});
}
}
这里封装了一个创建计算流的方法,把预测模型弄成静态类,直接在算子里面使用就是了。
这次用下来,个人就有一些小小的感受(无关紧要哈,主要是使用前有疑惑所以想到的一些事项):
这个项目是之前做的一个flink大作业,博客里主要是和pmml使用相关的代码,完整代码有兴趣的可以到这里下载。