Spark -- 模型的导入和导出

  通过SparkML训练的各种模型,通过Pipeline训练的为PipelineModel,我们可以将此模型写出为pmml文件(跨平台)或者写入hdfs(spark可以重新加载)。

写入HDFS

介绍

  我们项目需要将Spark训练的模型保存至HDFS,待需要时再重新加载回来做后续的模型预测和评估的流程。因为Spark2.0后我们都是用Pipeline去训练模型的,我们以PipelineModel为例,来保存模型到HDFS。

保存模型到HDFS

  保存模型到HDFS上非常容易,直接下面一行代码搞定,save的参数就是HDFS的路径,最好使用绝对路径。

pipelineModel.write.overwrite().save(HdfsOperUtils.getCacheModelPath(cmptId, loadOrder))
从HDFS读取模型

  读取模型也很容易,直接一行代码搞定,其中load的参数是HDFS的路径,最好使用绝对路径。读取到之前保存的模型,就可以继续走下面的流程了。

val model: PipelineModel = PipelineModel.read.load(HdfsOperUtils.getCacheModelPath(cmptId, loadOrder))

PMML

介绍

  PMML是数据挖掘的一种通用的规范,它用统一的XML格式来描述我们生成的机器学习模型。这样无论你的模型是sklearn,R还是Spark MLlib生成的,我们都可以将其转化为标准的XML格式来存储。当我们需要将这个PMML的模型用于部署的时候,可以使用目标环境的解析PMML模型的库来加载模型,并做预测。

保存PMML

  保存PMML也很简单,引入pom依赖,然后两行代码搞定。

<dependency>
    <groupId>org.jpmmlgroupId>
    <artifactId>jpmml-sparkmlartifactId>
    <version>1.5.0version>
    <exclusions>
        <exclusion>
            <groupId>com.google.guavagroupId>
            <artifactId>guavaartifactId>
        exclusion>
    exclusions>
dependency>
//通过PMMLBuilder对模型生成pmml文件
val file = new PMMLBuilder(userTrans.schema, pipelineModel).buildFile(new File(filePath))
//将模型保存到HDFS指定路径
HdfsOperUtils.copyFileToHDFS(file,"/model/"+file.getName)

//其中HdfsOperUtils代码如下:
/**
     * 将本地文件复制到hdfs
     *
     * @param file     上传的文件,如“D:/spool.sql”
     * @param hdfsPath 保存到hdfs地址,如“/user/spool.sql”
     */
    public static void copyFileToHDFS(File file, String hdfsPath) {
        FileSystem fs = getFileSystem();
        try {
            FileInputStream fis = new FileInputStream(file);//读取本地文件
            OutputStream os = fs.create(new Path(hdfsPath));
            //复制文件,true是否关闭数据流,如果是false,就在finally里关掉
            IOUtils.copyBytes(fis, os, 4096, true);
        } catch (Exception e) {
            LOGGER.error("写文件异常", e);
        } finally {
            try {
                fs.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
Spark读取PMML预测模型

  因为我没有实际操作过这个步骤,所以就先引用一下别人的吧:https://blog.csdn.net/fansy1990/article/details/53293024

Java读取PMML预测模型

  java读取也很简单,也是引用一下pom,然后简单几行就搞定了。

<dependency>
    <groupId>org.jpmmlgroupId>
    <artifactId>pmml-evaluatorartifactId>
    <version>1.4.7version>
dependency>
<dependency>
    <groupId>org.jpmmlgroupId>
    <artifactId>pmml-evaluator-extensionartifactId>
    <version>1.4.7version>
dependency>
public Evaluator evaluator;

    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();

        Evaluator evaluator = loadPmml("D:/Work/Project/06-Metis-AI-文件夹智能推荐/metis_round2.pmml");
        long endTime1 = System.currentTimeMillis();
        System.out.println("加载模型耗时:" + (endTime1 - startTime));
        Map<String, Map<String, Double>> map = getEntityMapFromFile("D:/Work/Project/06-Metis-AI-文件夹智能推荐/metis_round2_test.csv", 40);
        long endTime2 = System.currentTimeMillis();
        System.out.println("加载一条数据耗时:" + (endTime2 - endTime1));
        Map<String, Double> valueMap = predict(evaluator, map);

        long endTime3 = System.currentTimeMillis();
        System.out.println("预测模型耗时:" + (endTime3 - endTime2));
        valueMap.forEach((key, value) -> System.out.println(key + " -> " + value));
    }

    /**
     * @Author: TheBigBlue
     * @Description: 加载pmml,返回evaluator
     * @Date: 2019/3/7
     * @Return:
     **/
    public static Evaluator loadPmml(String pmmlPath) {
        PMML pmml = new PMML();
        InputStream inputStream = null;
        try {
            inputStream = new FileInputStream(pmmlPath);
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            //关闭输入流
            try {
                inputStream.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        Evaluator evaluator = null;
        if (pmml != null) {
            ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
            evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
        }
        return evaluator;
    }

    /**
     * @Author: TheBigBlue
     * @Description: 模型预测
     * @Date: 2019/3/7
     * @Return:
     **/
    public static Map<String, Double> predict(Evaluator evaluator, Map<String, Map<String, Double>> map) {
        List<InputField> inputFields = evaluator.getInputFields();
        Map<String, Double> valueMap = new LinkedHashMap<>();
        for (Map.Entry<String, Map<String, Double>> entry : map.entrySet()) {
            Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
            for (InputField inputField : inputFields) {
                FieldName inputFieldName = inputField.getFieldName();
                Double rowValue = entry.getValue().get(inputFieldName.getValue());
                FieldValue inputFieldValue = inputField.prepare(rowValue);
                arguments.put(inputFieldName, inputFieldValue);
            }
            Map<FieldName, ?> results = evaluator.evaluate(arguments);
            FieldName targetFieldName = evaluator.getTargetFields().get(0).getFieldName();
            Computable computable = (Computable) results.get(targetFieldName);
            valueMap.put(entry.getKey(), (double) computable.getResult());
        }
        return valueMap;
    }

你可能感兴趣的:(Spark)