Installing and using sklearn2pmml

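Our company's codebase is in Java, but the algorithm part is built with Python's sklearn. The plan is to use sklearn2pmml to export the trained model as a PMML file and then load that file from Java, so the model can be used across platforms.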

  1. Install sklearn2pmml
pip install sklearn2pmml

Note the following:

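  • scikit-learn must be version 0.20.4 or earlier (at least with the sklearn2pmml 0.48.0 used here); with later versions the conversion fails with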
AttributeError: module 'sklearn.externals.joblib' has no attribute '__version__'

This is because sklearn.externals.joblib was deprecated in 0.21 and is scheduled for removal in 0.23:

DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+
  • Java must be version 1.7 or later
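The two version constraints above can be checked before attempting a conversion. The sketch below is a plain-Python helper; the function names are hypothetical, and the thresholds are the ones quoted in this article for sklearn2pmml 0.48.0, which newer sklearn2pmml releases may not share.

```python
# Sketch: check version strings against the limits quoted above.
# Assumption: the limits (scikit-learn <= 0.20.4, Java >= 1.7) are the ones
# reported for sklearn2pmml 0.48.0; newer releases may differ.
import re


def parse_version(text):
    """Turn '0.20.4' or '1.8.0_144' into a tuple of ints, e.g. (0, 20, 4)."""
    return tuple(int(part) for part in re.findall(r"\d+", text))


def sklearn_version_ok(version):
    """True if the scikit-learn version is within the quoted limit (<= 0.20.4)."""
    return parse_version(version)[:3] <= (0, 20, 4)


def java_version_ok(version):
    """Accept both the legacy '1.8.0_144' and the newer '11.0.2' numbering."""
    parts = parse_version(version)
    # Java 8 and earlier report themselves as 1.x, later versions as x.y.z
    major = parts[1] if parts[0] == 1 else parts[0]
    return major >= 7
```

For example, call `sklearn_version_ok(sklearn.__version__)` after `import sklearn`, or pass the version string shown by `java -version` to `java_version_ok`.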

My environment:

python: 3.6.8
sklearn: 0.20.4
sklearn.externals.joblib: 0.13.2
pandas: 0.24.1
sklearn_pandas: 1.8.0
sklearn2pmml: 0.48.0
java: 1.8.0_144
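To produce a listing like the one above on another machine, something like the following sketch can be used. It relies only on the standard library; `module_version`, `java_version` and `environment_report` are hypothetical helper names, and reading `__version__` assumes each package defines that attribute (anything missing or failing to import is reported as None).

```python
# Sketch: collect the same version information as the list above.
import importlib
import platform
import subprocess


def module_version(name):
    """Return module.__version__, or None if the import fails or has no version."""
    try:
        module = importlib.import_module(name)
    except Exception:  # missing or incompatible package
        return None
    return getattr(module, "__version__", None)


def java_version():
    """Return the first line of `java -version` (written to stderr), or None."""
    try:
        proc = subprocess.run(["java", "-version"], stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, universal_newlines=True)
    except OSError:  # java is not on the PATH
        return None
    lines = (proc.stderr or proc.stdout).strip().splitlines()
    return lines[0] if lines else None


def environment_report():
    report = {"python": platform.python_version()}
    for name in ("sklearn", "sklearn.externals.joblib", "pandas",
                 "sklearn_pandas", "sklearn2pmml"):
        report[name] = module_version(name)
    report["java"] = java_version()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print("{}: {}".format(key, value))
```

`universal_newlines=True` is used instead of `text=True` so the sketch also runs on Python 3.6, the version listed above.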

  2. Test the Python code
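from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn2pmml import PMMLPipeline, sklearn2pmml

# Load the iris data set (4 numeric features, 3 classes)
iris = load_iris()

# Hold out 20% of the samples; only the training part is used below
train, test, train_labels, test_labels = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)

# Wrap the estimator in a PMMLPipeline so that sklearn2pmml can convert it
pipeline = PMMLPipeline([
    ("classifier", tree.DecisionTreeClassifier(random_state=9))
])

pipeline.fit(train, train_labels)

# Export to PMML; with_repr=True stores the pipeline's text representation in the file,
# debug=True prints extra diagnostic output, which helps when the conversion fails
sklearn2pmml(pipeline, 'result.pmml', with_repr=True, debug=True)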
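Because the pipeline above is fitted on plain NumPy arrays, there are no column names to carry into the PMML file. Judging from the keys the Java code in this article uses (x1 to x4) and its printed target (y), sklearn2pmml then falls back to the names x1...xN for the features and y for the target. The hypothetical helper below builds a matching name-to-value map for one sample, mirroring the map the Java program fills in by hand.

```python
# Sketch (hypothetical helpers): map one sample to the default PMML field names.
# Assumption: with NumPy input, sklearn2pmml names the features x1..xN,
# which is what the Java code in this article relies on.


def default_feature_names(n_features):
    """Return ['x1', ..., 'xN']."""
    return ["x{}".format(i) for i in range(1, n_features + 1)]


def to_pmml_arguments(sample):
    """Pair each feature value with its default PMML field name."""
    values = [float(v) for v in sample]
    return dict(zip(default_feature_names(len(values)), values))
```

For example, `to_pmml_arguments(iris.data[0])` gives `{'x1': 5.1, 'x2': 3.5, 'x3': 1.4, 'x4': 0.2}`, the same values the Java `main` method uses. Fitting on a pandas DataFrame instead should make the DataFrame's column names appear in the PMML file.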

The generated PMML file is shown below:

[Figure: contents of the generated result.pmml]
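To see which fields the generated file declares without opening it in an XML viewer, the DataDictionary can be listed with the standard library's ElementTree. This is a minimal sketch; the helper names are hypothetical. For the iris pipeline above one would expect x1 to x4 plus the target y.

```python
# Sketch: list the fields declared in a PMML file's DataDictionary.
# Tags are matched by local name, so the PMML version namespace
# (e.g. http://www.dmg.org/PMML-4_3) does not matter.
import xml.etree.ElementTree as ET


def local_name(tag):
    """Strip a '{namespace}' prefix from an ElementTree tag."""
    return tag.rsplit("}", 1)[-1]


def data_fields(root):
    """Return (name, optype, dataType) for every DataField under root."""
    return [
        (el.get("name"), el.get("optype"), el.get("dataType"))
        for el in root.iter()
        if local_name(el.tag) == "DataField"
    ]


def data_fields_from_file(path):
    """Parse a PMML file such as 'result.pmml' and list its DataFields."""
    return data_fields(ET.parse(path).getroot())
```

Usage: `for field in data_fields_from_file("result.pmml"): print(field)`.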

When running your own code, you may get the following error:

RuntimeError: The JPMML-SkLearn conversion application has failed. The Java executable should have printed more information about the failure into its standard output and/or standard error streams

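When this happens, check the column names of train and train_labels (when they are pandas objects): the names must not contain duplicates and must be correctly formatted. As the error text says, the Java messages printed before the exception contain more detail.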
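A quick pre-flight check of the names can be sketched as follows. `name_problems` is a hypothetical helper, and treating "correctly formatted" as "non-empty strings without surrounding whitespace" is an assumption of this sketch, not a rule taken from sklearn2pmml.

```python
# Sketch (hypothetical helper): report problems with feature/label names
# before calling sklearn2pmml. Assumption: valid names are unique,
# non-empty strings without leading or trailing whitespace.
from collections import Counter


def name_problems(names):
    """Return a list of human-readable problems found in the given names."""
    problems = []
    counts = Counter(names)
    for name, count in counts.items():
        if count > 1:
            problems.append("duplicate name: {!r} ({} times)".format(name, count))
    for name in counts:
        if not isinstance(name, str):
            problems.append("non-string name: {!r}".format(name))
        elif not name or name != name.strip():
            problems.append("empty or padded name: {!r}".format(name))
    return problems
```

With a pandas DataFrame `train`, call `name_problems(list(train.columns))`; an empty list means no problem was found.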

  3. Test the Java code
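    Download jpmml-sklearn-executable-1.5.7.jar and pmml-evaluator-1.4.3.jar, and create a new project that references both jars.
    I have verified that these jar versions work without errors; other versions may fail with errors such as:
    [Figure: error message produced with a different jar version]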

    The Java code is as follows:
package javaTopython;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.PMMLUtil;

public class PmmlFile {
    public static void main(String[] args) throws Exception {
        // Must match the file written by the Python script
        String pathxml = "result.pmml";
        // x1..x4 are the field names sklearn2pmml uses for NumPy input
        Map<String, Object> irisMap = new HashMap<>();
        irisMap.put("x1", 5.1);
        irisMap.put("x2", 3.5);
        irisMap.put("x3", 1.4);
        irisMap.put("x4", 0.2);
        predictIris(irisMap, pathxml);
    }

    public static void predictIris(Map<String, ?> irisMap, String pathxml) throws Exception {
        PMML pmml;
        // Load the model
        try (InputStream is = new FileInputStream(pathxml)) {
            pmml = PMMLUtil.unmarshal(is);
        }

        Evaluator evaluator = ModelEvaluatorFactory.newInstance()
                .newModelEvaluator(pmml);

        // Iterate over the model's raw input features, take each value
        // from the input map and prepare it as model input
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = irisMap.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);
        // Classification models and the like may have several outputs
        for (TargetField targetField : evaluator.getTargetFields()) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.out.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }
    }
}

The output is:

target y value: ProbabilityDistribution{result=0, probability_entries=[0=0.8876504283659372, 1=0.11232695495162393, 2=2.2616682438804697E-5]}
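To compare the Java result with the Python side (for example with `pipeline.predict_proba([[5.1, 3.5, 1.4, 0.2]])`), the printed text can be parsed back into numbers. This sketch assumes the `ProbabilityDistribution` text keeps the layout shown above; that layout is a `toString()` output, not a documented format, and the function name is hypothetical.

```python
# Sketch: parse the ProbabilityDistribution text printed by the Java program.
# Assumption: the text keeps the layout shown above (not a stable format).
import re


def parse_probability_distribution(text):
    """Return (result, {label: probability}); labels are kept as strings."""
    result_match = re.search(r"result=([^,}]+)", text)
    entries_match = re.search(r"probability_entries=\[([^\]]*)\]", text)
    if result_match is None or entries_match is None:
        raise ValueError("not a ProbabilityDistribution string: " + text)
    probabilities = {}
    for entry in entries_match.group(1).split(","):
        if entry.strip():
            label, value = entry.strip().split("=", 1)
            probabilities[label] = float(value)
    return result_match.group(1).strip(), probabilities
```

Applied to the output line above, it returns the predicted class "0" and the three class probabilities, which sum to 1 up to rounding.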

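Watch out for simplified models: the PMML file may omit columns whose coefficients are zero. The Java loop above only asks for the fields the model declares, so a dropped column is silently ignored; it is therefore best to check the model's input fields against the features you expect.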
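One way to do that check is to compare the expected feature names with the active fields in the PMML MiningSchema. This is a minimal sketch with hypothetical helper names; it treats a MiningField without a `usageType` attribute as active, which is the PMML default.

```python
# Sketch: find features that are missing from (or extra in) the PMML model.
import xml.etree.ElementTree as ET


def active_fields(root):
    """Names of MiningFields whose usageType is 'active' (the PMML default)."""
    names = []
    for el in root.iter():
        is_mining_field = el.tag.rsplit("}", 1)[-1] == "MiningField"
        if is_mining_field and el.get("usageType", "active") == "active":
            if el.get("name") not in names:  # nested models may repeat fields
                names.append(el.get("name"))
    return names


def compare_features(expected, root):
    """Return (missing_from_pmml, unexpected_in_pmml)."""
    declared = active_fields(root)
    missing = [name for name in expected if name not in declared]
    unexpected = [name for name in declared if name not in expected]
    return missing, unexpected
```

Usage: `compare_features(["x1", "x2", "x3", "x4"], ET.parse("result.pmml").getroot())`; a non-empty first list means the model no longer uses those columns.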
