java调用python数据分析模型_pmml:java调用Python训练的模型之xgboost

使用sklearn2pmml 保存Python的模型

第一步：在 Python 端安装 sklearn2pmml。这里安装的版本导出的是最新的 PMML 4.4 格式，PMML 4.4 与 Java 端 1.5.x 版本的 jar 包相对应。

pip install sklearn2pmml

第二步：修改 Python 端代码，用 PMMLPipeline 包装分类器，训练后通过 sklearn2pmml 导出为 PMML 文件：

# Minimal pattern: wrap an already-configured estimator (`clf`) in a PMMLPipeline,
# fit it, then export the fitted pipeline to a PMML file.
pipeline = PMMLPipeline([('classifier', clf)])

pipeline.fit(X_train, Y_train)

sklearn2pmml(pipeline, 'output/XGboost1.pmml', with_repr=True, debug=True)

# Full example below.
from xgboost.sklearn import XGBClassifier

from sklearn.metrics import accuracy_score

from sklearn.metrics import recall_score

from sklearn.metrics import precision_score

from sklearn.metrics import f1_score

from sklearn.metrics import confusion_matrix

from sklearn2pmml import PMMLPipeline, sklearn2pmml

def trian_xgboost(df, fwmodel):
    """Train an XGBoost classifier, report test-set metrics and export it as PMML.

    :param df: pandas DataFrame whose ``'lable'`` column holds the target and whose
        remaining columns are the features. The DataFrame is not modified.
    :param fwmodel: path of the PMML file the fitted pipeline is written to. If it is
        falsy, the historical default ``'output/XGboost1.pmml'`` is used.
    :return: the fitted ``PMMLPipeline``.
    """
    # Both names were used by the original code but never imported at file level.
    from sklearn.metrics import classification_report
    from sklearn.model_selection import train_test_split

    y = df['lable'].values
    # Build the feature frame without mutating the caller's DataFrame.
    features = df.drop(columns=['lable'])
    X = features.values
    print(features.columns)

    print('=============xgboost=============')

    X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.25, random_state=0)

    clf = XGBClassifier(
        verbosity=1,  # replaces the deprecated `silent=0`: print messages while training
        learning_rate=0.07,  # shrinkage applied to each tree's contribution
        # Minimum sum of instance hessian (second derivative) required in a leaf.
        # For imbalanced 0/1 problems with h around 0.01, a value of 1 means a leaf
        # needs roughly 100 samples. Smaller values make overfitting easier.
        min_child_weight=3,
        max_depth=12,  # tree depth; deeper trees overfit more easily
        gamma=0,  # min loss reduction needed to split a leaf; larger is more conservative
        subsample=1,  # fraction of training rows sampled per tree
        max_delta_step=0,  # maximum delta step allowed for each tree's weight estimate
        reg_lambda=1,  # L2 regularization on leaf weights; larger resists overfitting
        n_estimators=100,  # number of trees
        random_state=1000,  # replaces the deprecated `seed`
    )

    pipeline = PMMLPipeline([('classifier', clf)])
    pipeline.fit(X_train, Y_train)

    y_pred = pipeline.predict(X_test)

    print("================================")
    print(y_pred)
    print('=================================')
    print(Y_test)
    print('=========================')

    # predict() already returns class labels, so they are compared directly (no rounding).
    accuracy = accuracy_score(Y_test, y_pred)
    print("Accuracy: %.2f%%" % (accuracy * 100.0))

    precision = precision_score(Y_test, y_pred, average='macro')
    recall = recall_score(Y_test, y_pred, average='macro')
    print("精确率: %.2f%%" % (precision * 100.0))
    print("召回率: %.2f%%" % (recall * 100.0))

    # Macro-averaged for consistency with precision/recall above.
    f1 = f1_score(Y_test, y_pred, average='macro')
    print("f1 %.2f%%" % (f1 * 100.0))

    # Confusion matrix
    print("混淆矩阵", confusion_matrix(Y_test, y_pred))

    # Per-class precision / recall / f1-score / support. Class names are derived from the
    # labels actually present instead of assuming a binary {0, 1} problem.
    labels = sorted(set(Y_test) | set(y_pred))
    target_names = ['class %s' % label for label in labels]
    print(classification_report(Y_test, y_pred, labels=labels, target_names=target_names))

    pmml_path = fwmodel if fwmodel else 'output/XGboost1.pmml'
    sklearn2pmml(pipeline, pmml_path, with_repr=True, debug=True)
    return pipeline

第三步：运行上述代码，生成 PMML 模型文件（例如 output/XGboost1.pmml）。

第四步：在 Java 端的 Maven 配置（pom.xml）中引入所需的 jar 包：

注意：这里 pmml 相关的包必须是 1.5.x 及以上版本，因为 Python 端生成的模型是 PMML 4.4 格式；否则会因版本不对应而报错：

java.lang.IllegalArgumentException

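<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator</artifactId>
    <version>1.5.11</version>
</dependency>
<dependency>
    <groupId>org.jpmml</groupId>
    <artifactId>pmml-evaluator-extension</artifactId>
    <version>1.5.11</version>
</dependency>
<dependency>
    <groupId>javax.xml.bind</groupId>
    <artifactId>jaxb-api</artifactId>
    <version>2.3.0</version>
</dependency>
<dependency>
    <groupId>com.sun.xml.bind</groupId>
    <artifactId>jaxb-core</artifactId>
    <version>2.3.0</version>
</dependency>
<dependency>
    <groupId>com.sun.xml.bind</groupId>
    <artifactId>jaxb-impl</artifactId>
    <version>2.3.0</version>
</dependency>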

第五步：编写 Java 端代码。因为导入的是 1.5.x 版本的 jar 包，所以代码写法与 1.4.x 版本不一样。

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ReportFactory;
import org.jpmml.evaluator.TargetField;

/**
 * Loads a PMML model exported from Python (sklearn2pmml) and scores a single sample
 * with the JPMML-Evaluator 1.5.x API.
 */
public class Rmd_prod {

    /**
     * Parses the PMML file at {@code pathxml}, evaluates it on the raw feature values in
     * {@code irismap}, and prints every target field's value to standard error.
     *
     * @param irismap raw feature values keyed by model input field name (e.g. "x1")
     * @param pathxml path of the .pmml file to load
     * @throws Exception if the file cannot be read or parsed, or evaluation fails
     */
    public static void predictLrHeart(Map<?, ?> irismap, String pathxml) throws Exception {
        PMML pmml;
        // The stream is only needed for unmarshalling; close it as soon as parsing is done.
        try (InputStream is = new FileInputStream(new File(pathxml))) {
            pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        }

        // 1.5.x API: build the evaluator directly from the PMML object.
        // (1.4.x used ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml) instead.)
        Evaluator evaluator = new ModelEvaluatorBuilder(pmml).build();

        // For each raw model input, look up its value in the caller's map by field name
        // and let the field convert it into an evaluator-ready FieldValue.
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = irismap.get(inputFieldName.getValue());
            arguments.put(inputFieldName, inputField.prepare(rawValue));
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // A model may declare several target fields; print each one.
        for (TargetField targetField : evaluator.getTargetFields()) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.err.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }
    }

    public static void main(String[] args) throws Exception {
        String pathxml = "/Users/wl/Documents/Pycharm/output/XGboost1.pmml";

        // Sample values for features x1..x54, in order. The names x1..xN are presumably the
        // defaults sklearn2pmml assigns to a model trained on a plain numpy array;
        // verify against the DataDictionary of the .pmml file.
        double[] values = {
                0.7, 0.2,                                // x1-x2
                1, 0, 1, 0, 0, 0, 0,                     // x3-x9
                0, 1, 1, 1, 0, 1, 0, 0, 0, 1,            // x10-x19
                1, 0, 0, 0, 0, 0, 0, 0, 0, 1,            // x20-x29
                1, 1, 0, 1, 0, 0, 1, 0, 0, 1,            // x30-x39
                0, 0, 0, 1, 1, 1, 0, 0, 0, 0,            // x40-x49
                0, 0, 0, 1, 1                            // x50-x54
        };

        Map<String, Object> map = new LinkedHashMap<>();
        for (int i = 0; i < values.length; i++) {
            map.put("x" + (i + 1), values[i]);
        }

        predictLrHeart(map, pathxml);
    }
}

你可能感兴趣的:(java调用python数据分析模型_pmml:java调用Python训练的模型之xgboost)