通过SparkML训练的各种模型,通过Pipeline训练的为PipelineModel,我们可以将此模型写出为pmml文件(跨平台)或者写入hdfs(spark可以重新加载)。
我们项目需要将Spark训练的模型保存至HDFS,待需要时再重新加载回来做后续的模型预测和评估的流程。因为Spark2.0后我们都是用Pipeline去训练模型的,我们以PipelineModel为例,来保存模型到HDFS。
保存模型到HDFS上非常容易,直接下面一行代码搞定,save的参数就是HDFS的路径,最好使用绝对路径。
pipelineModel.write.overwrite().save(HdfsOperUtils.getCacheModelPath(cmptId, loadOrder))
读取模型也很容易,直接一行代码搞定,其中load的参数是HDFS的路径,最好使用绝对路径。读取到之前保存的模型,就可以继续走下面的流程了。
val model: PipelineModel = PipelineModel.read.load(HdfsOperUtils.getCacheModelPath(cmptId, loadOrder))
PMML是数据挖掘的一种通用的规范,它用统一的XML格式来描述我们生成的机器学习模型。这样无论你的模型是sklearn,R还是Spark MLlib生成的,我们都可以将其转化为标准的XML格式来存储。当我们需要将这个PMML的模型用于部署的时候,可以使用目标环境的解析PMML模型的库来加载模型,并做预测。
保存PMML也很简单,引入pom依赖,然后两行代码搞定。
<dependency>
<groupId>org.jpmmlgroupId>
<artifactId>jpmml-sparkmlartifactId>
<version>1.5.0version>
<exclusions>
<exclusion>
<groupId>com.google.guavagroupId>
<artifactId>guavaartifactId>
exclusion>
exclusions>
dependency>
//通过PMMLBuilder对模型生成pmml文件
val file = new PMMLBuilder(userTrans.schema, pipelineModel).buildFile(new File(filePath))
//将模型保存到HDFS指定路径
HdfsOperUtils.copyFileToHDFS(file,"/model/"+file.getName)
//其中HdfsOperUtils代码如下:
/**
* 将本地文件复制到hdfs
*
* @param file 上传的文件,如“D:/spool.sql”
* @param hdfsPath 保存到hdfs地址,如“/user/spool.sql”
*/
public static void copyFileToHDFS(File file, String hdfsPath) {
FileSystem fs = getFileSystem();
try {
FileInputStream fis = new FileInputStream(file);//读取本地文件
OutputStream os = fs.create(new Path(hdfsPath));
//复制文件,true是否关闭数据流,如果是false,就在finally里关掉
IOUtils.copyBytes(fis, os, 4096, true);
} catch (Exception e) {
LOGGER.error("写文件异常", e);
} finally {
try {
fs.close();
} catch (Exception e) {
e.printStackTrace();
}
}
}
因为我没有实际操作过这个步骤,所以就先引用一下别人的吧:https://blog.csdn.net/fansy1990/article/details/53293024
java读取也很简单,也是引用一下pom,然后简单几行就搞定了。
<dependency>
<groupId>org.jpmmlgroupId>
<artifactId>pmml-evaluatorartifactId>
<version>1.4.7version>
dependency>
<dependency>
<groupId>org.jpmmlgroupId>
<artifactId>pmml-evaluator-extensionartifactId>
<version>1.4.7version>
dependency>
public Evaluator evaluator;
public static void main(String[] args) {
long startTime = System.currentTimeMillis();
Evaluator evaluator = loadPmml("D:/Work/Project/06-Metis-AI-文件夹智能推荐/metis_round2.pmml");
long endTime1 = System.currentTimeMillis();
System.out.println("加载模型耗时:" + (endTime1 - startTime));
Map<String, Map<String, Double>> map = getEntityMapFromFile("D:/Work/Project/06-Metis-AI-文件夹智能推荐/metis_round2_test.csv", 40);
long endTime2 = System.currentTimeMillis();
System.out.println("加载一条数据耗时:" + (endTime2 - endTime1));
Map<String, Double> valueMap = predict(evaluator, map);
long endTime3 = System.currentTimeMillis();
System.out.println("预测模型耗时:" + (endTime3 - endTime2));
valueMap.forEach((key, value) -> System.out.println(key + " -> " + value));
}
/**
* @Author: TheBigBlue
* @Description: 加载pmml,返回evaluator
* @Date: 2019/3/7
* @Return:
**/
public static Evaluator loadPmml(String pmmlPath) {
PMML pmml = new PMML();
InputStream inputStream = null;
try {
inputStream = new FileInputStream(pmmlPath);
pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
} catch (Exception e) {
e.printStackTrace();
} finally {
//关闭输入流
try {
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
}
}
Evaluator evaluator = null;
if (pmml != null) {
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
}
return evaluator;
}
/**
* @Author: TheBigBlue
* @Description: 模型预测
* @Date: 2019/3/7
* @Return:
**/
public static Map<String, Double> predict(Evaluator evaluator, Map<String, Map<String, Double>> map) {
List<InputField> inputFields = evaluator.getInputFields();
Map<String, Double> valueMap = new LinkedHashMap<>();
for (Map.Entry<String, Map<String, Double>> entry : map.entrySet()) {
Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getFieldName();
Double rowValue = entry.getValue().get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(rowValue);
arguments.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(arguments);
FieldName targetFieldName = evaluator.getTargetFields().get(0).getFieldName();
Computable computable = (Computable) results.get(targetFieldName);
valueMap.put(entry.getKey(), (double) computable.getResult());
}
return valueMap;
}