使用java调用python训练出的pmml模型

作为一个2,3年没有用过java的数据挖掘工程师,突然要用java来调用pmml模型,真的好烦啊。

在网上找了一段代码,自己试了1个礼拜才运行成功,记录下自己的过程,以后可以随时用,如果能帮到大家就更好了。

从安装软件说起,嫌麻烦的就别看了。

一、下载工具(俗话说得好,预先善其事必先利其器!哈哈)

我刚开始安装的是eclipse,但有诸多麻烦不能解决,就用了IDEA,和Pycharm一个公司发行的。

首先进入官网: http://www.jetbrains.com/products.html#lang=java

选择IDEA下载:

使用java调用python训练出的pmml模型_第1张图片

 

由于社区版的功能太少,我下载的是企业版的,后边会告诉破解方法。

IDEA的安装教程网上都有,正常安装就好。

企业版的激活码大家可以关注一个公众号,我也是在网上找到的。

http://idea.medeming.com/

关注公众号后粘贴就行了。

二、Java环境安装

参考教程:https://blog.csdn.net/weixin_38381149/article/details/89668578

写博客时想找当时看的博客,但发现了这个很全的,jdk,maven,tomcat都有。

想当初我为了装一个maven花了好久。。。

三、新建Maven项目(我也不知道该怎么说。。。)

  File ==》New==》Project==》Maven

四、接下来在IDEA中配置Maven,这是当时参考的博客:https://www.cnblogs.com/jiangzhaowei/p/9534393.html

五、添加依赖

  由于我只是为了调用模型,没有太多依赖,只添加了这么几个

    

        
            org.jpmml
            pmml-evaluator
            1.4.1
        
        
            org.jpmml
            pmml-evaluator-extension
            1.4.1
        

        
            javax.xml.bind
            jaxb-api
            2.3.0
        
        
            com.sun.xml.bind
            jaxb-core
            2.3.0
        
        
            com.sun.xml.bind
            jaxb-impl
            2.3.0
        

    

六、java调用Python训练出的pmml模型的代码

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;

import javax.xml.bind.JAXBException;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ClassificationModel {
    private Evaluator modelEvaluator;

    /**
     * 通过传入 PMML 文件路径来生成机器学习模型
     *
     * @param pmmlFileName pmml 文件路径
     */
    public ClassificationModel(String pmmlFileName) {
        PMML pmml = null;

        try {
            if (pmmlFileName != null) {
                InputStream is = new FileInputStream(pmmlFileName);
                pmml = PMMLUtil.unmarshal(is);
                try {
                    is.close();
                } catch (IOException e) {
                    System.out.println("InputStream close error!");
                }

                ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();

                this.modelEvaluator = (Evaluator) modelEvaluatorFactory.newModelEvaluator(pmml);
                modelEvaluator.verify();
                System.out.println("加载模型成功!");
            }
        } catch (SAXException e) {
            e.printStackTrace();
        } catch (JAXBException e) {
            e.printStackTrace();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }

    }

    // 获取模型需要的特征名称
    public List getFeatureNames() {
        List featureNames = new ArrayList();

        List inputFields = modelEvaluator.getInputFields();

        for (InputField inputField : inputFields) {
            featureNames.add(inputField.getName().toString());
        }
        return featureNames;
    }

    // 获取目标字段名称
    public String getTargetName() {
        return modelEvaluator.getTargetFields().get(0).getName().toString();
    }

    // 使用模型生成概率分布
    private ProbabilityDistribution getProbabilityDistribution(Map arguments) {
        Map evaluateResult = modelEvaluator.evaluate(arguments);

        FieldName fieldName = new FieldName(getTargetName());

        return (ProbabilityDistribution) evaluateResult.get(fieldName);

    }

    // 预测不同分类的概率
    public ValueMap predictProba(Map arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);
        return probabilityDistribution.getValues();
    }

    // 预测结果分类
    public Object predict(Map arguments) {
        ProbabilityDistribution probabilityDistribution = getProbabilityDistribution(arguments);

        return probabilityDistribution.getPrediction();
    }

    public static void main(String[] args) {
        ClassificationModel clf = new ClassificationModel("D:/JupyterSpace/RandomForestClassifier_Iris.pmml"); //这里模型地址

        List featureNames = clf.getFeatureNames();
        System.out.println("feature: " + featureNames);

        // 构建待预测数据
        Map waitPreSample = new HashMap<>();
        waitPreSample.put(new FieldName("sepal length (cm)"), 10);
        waitPreSample.put(new FieldName("sepal width (cm)"), 1);
        waitPreSample.put(new FieldName("petal length (cm)"), 3);
        waitPreSample.put(new FieldName("petal width (cm)"), 2);

        System.out.println("waitPreSample predict result: " + clf.predict(waitPreSample).toString());
        System.out.println("waitPreSample predictProba result: " + clf.predictProba(waitPreSample).toString());

    }

}

注意事项:

1、类名和文件名要一致

2、打开File  ==》Project Structure

看你的JDK版本和这里是否一致

使用java调用python训练出的pmml模型_第2张图片

使用java调用python训练出的pmml模型_第3张图片

运行程序,查看是否报错。

这是我报的一个错:

NoClassDefFoundError: javax/activation/DataSource

  解决方法是下载:activation.jar包。

  下载地址:

    链接:https://pan.baidu.com/s/14D8cQWIJp2d7h2iljAPZ2A
    提取码:6f37

应该没什么问题了。有问题请留言,一定回复。(有问题一定要告诉我,以后还要用呢。。。)

你可能感兴趣的:(使用java调用python训练出的pmml模型)