使用java自带的Runtime.getRuntime().exec(args)方法直接调用python脚本
1、首先将python代码训练的模型保存为pmml格式,代码如下
# Wrap an XGBoost classifier in a PMML pipeline, fit it, and export it to "xgb.pmml".
from sklearn2pmml import PMMLPipeline, sklearn2pmml

model = xgb.XGBClassifier()
pipeline = PMMLPipeline([("classifier", model)])
pipeline.fit(X_train, y_train)
sklearn2pmml(pipeline, "xgb.pmml", with_repr=True)
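然后使用java读取pmml文件对数据进行预测,代码如下(ModelInvoker负责加载模型,ModelCalc负责调用模型进行预测)。
后来改为在java中通过Runtime.getRuntime().exec(args)启动python进程的方式直接运行python脚本(见后文的PythonDemo)。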
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;
/**
 * Loads a PMML document and exposes the resulting model evaluator.
 *
 * <p>Both constructors parse the PMML, build an evaluator, verify it and
 * close the input stream. Any failure is reported immediately with an
 * unchecked exception that carries the original cause, instead of leaving
 * the instance half-initialised.
 *
 * @author liaotuo
 */
public class ModelInvoker {
    /** Evaluator built from the parsed PMML; non-null once a constructor returns. */
    private final ModelEvaluator modelEvaluator;

    /**
     * Loads the model from a classpath resource.
     *
     * @param pmmlFileName resource name resolved through this class's class loader
     * @throws IllegalArgumentException if {@code pmmlFileName} is null or the
     *         resource cannot be found
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(String pmmlFileName) {
        if (pmmlFileName == null) {
            throw new IllegalArgumentException("pmmlFileName must not be null");
        }
        // getResourceAsStream returns null (it does not throw) when the resource is missing.
        InputStream is = ModelInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName);
        if (is == null) {
            throw new IllegalArgumentException("PMML resource not found on classpath: " + pmmlFileName);
        }
        this.modelEvaluator = loadEvaluator(is);
        System.out.println("模型读取成功");
    }

    /**
     * Loads the model from an input stream. The stream is always closed by
     * this constructor, whether loading succeeds or fails.
     *
     * @param is stream positioned at the start of a PMML document
     * @throws IllegalArgumentException if {@code is} is null
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(InputStream is) {
        if (is == null) {
            throw new IllegalArgumentException("PMML input stream must not be null");
        }
        this.modelEvaluator = loadEvaluator(is);
    }

    /**
     * Evaluates the model for one record.
     *
     * @param paramsMap model arguments, passed unchanged to the evaluator
     * @return the evaluator's result map
     */
    public Map invoke(Map paramsMap) {
        return this.modelEvaluator.evaluate(paramsMap);
    }

    /** Parses the PMML from {@code is}, builds and verifies an evaluator, then closes the stream. */
    private static ModelEvaluator loadEvaluator(InputStream is) {
        try {
            PMML pmml = PMMLUtil.unmarshal(is);
            ModelEvaluator evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
            evaluator.verify();
            return evaluator;
        } catch (SAXException | JAXBException e) {
            // Surface the real parse failure rather than a later NullPointerException.
            throw new IllegalStateException("Failed to parse PMML model", e);
        } finally {
            closeQuietly(is);
        }
    }

    /** Closes {@code is} once; a close failure on a read-only stream is deliberately ignored. */
    private static void closeQuietly(InputStream is) {
        try {
            is.close();
        } catch (IOException ignored) {
            // Best effort: the model has already been read (or loading has already failed).
        }
    }
}
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.FieldName;
/**
* 使用模型
*
* @author gs
*
*/
public class ModelCalc {
static String pmmlPath = "E:\\workspace\\python\\tydic\\model\\xgb.pmml";
/**
 * Entry point: runs a prediction over the hard-coded sample input file.
 *
 * @param args command-line arguments (unused)
 * @throws IOException propagated from {@code predictFromFile}
 */
public static void main(String[] args) throws IOException {
    predictFromFile("E:\\workspace\\python\\tydic\\model\\test\\X_val");
}
/**
* 以文件名的方式读取输入数据进行预测
* @param modelArgsFilePath
* @throws FileNotFoundException
* @throws IOException
*/
public static List predictFromFile(String modelArgsFilePath) throws FileNotFoundException, IOException {
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(pmmlPath));
ModelInvoker invoker = new ModelInvoker(bis);
List
使用Runtime.getRuntime().exec(args)直接调用python脚本的方式,主要工作是java端代码的书写,代码如下:
/**
 * Runs a Python training script as a child process and echoes its output.
 *
 * <p>The child's stdout and stderr are each drained on their own thread so
 * the child cannot stall on a full pipe while this process waits for it.
 */
public class PythonDemo {
    /**
     * Launches {@code C:\model_train.py} with hard-coded arguments, prints
     * everything the script writes, and reports a non-zero exit code.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        try {
            // Arguments forwarded to the Python script.
            String host = "localhost";
            String port = "3306";
            String user = "root";
            // NOTE(review): the password is passed on the command line, where other
            // local users may be able to see it; acceptable for a demo only.
            String passwd = "123456";
            String path = "C:/";
            String database = "dic_coll_consume";
            String startDate = "2017-08-01";
            String endDate = "2017-09-01";
            String[] command = new String[] { "python", "C:\\model_train.py", host, port, user, passwd, path,
                    database, startDate, endDate };
            Process pr = Runtime.getRuntime().exec(command);
            // Start both drainers BEFORE waiting, otherwise a chatty script can block forever.
            Thread stdoutPump = print(pr.getInputStream());
            Thread stderrPump = print(pr.getErrorStream());
            int exitCode = pr.waitFor();
            stdoutPump.join();
            stderrPump.join();
            if (exitCode != 0) {
                System.err.println("python script exited with code " + exitCode);
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt status for whoever is running this thread.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Starts a thread that copies {@code stream} line by line to standard
     * output, then prints {@code end} when the stream is exhausted.
     *
     * <p>The bytes are decoded as UTF-8 directly. Decoding with the platform
     * charset and re-decoding each line afterwards is lossy on platforms
     * whose default charset is not UTF-8.
     *
     * @param stream child-process output to drain; closed when draining finishes
     * @return the started thread, so the caller can join it
     */
    private static Thread print(final InputStream stream) {
        Thread pump = new Thread(new Runnable() {
            @Override
            public void run() {
                try (BufferedReader in = new BufferedReader(
                        new InputStreamReader(stream, java.nio.charset.StandardCharsets.UTF_8))) {
                    String line;
                    while ((line = in.readLine()) != null) {
                        System.out.println(line);
                    }
                    System.out.println("end");
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        });
        pump.start();
        return pump;
    }
}
这个地方主要遇到的坑是:python脚本运行时会输出很多信息,而子进程标准输出/错误输出的管道缓冲区很小,如果只调用pr.waitFor()而不去读取这些输出,缓冲区写满后子进程就会阻塞,waitFor()也就一直不返回。后来采用多线程的方式分别读取并打印这两个输出流,立马就解决了问题。