python训练模型上线问题总结

java调用python模型

  1. PMML格式
  2. 使用 Java 自带的 Runtime.getRuntime().exec(args) 方法直接调用 Python 脚本

    PMML格式

    1、首先将python代码训练的模型保存为pmml格式,代码如下

# Wrap an XGBoost classifier in a PMML-aware pipeline, train it, and export
# the fitted pipeline to "xgb.pmml".
# NOTE(review): `xgb`, `X_train` and `y_train` are not defined in this snippet;
# presumably they come from the surrounding training script — confirm.
from sklearn2pmml import PMMLPipeline, sklearn2pmml

model = xgb.XGBClassifier()
pipeline = PMMLPipeline([("classifier", model)])
pipeline.fit(X_train, y_train)

# with_repr=True is passed through to sklearn2pmml unchanged.
sklearn2pmml(pipeline, "xgb.pmml", with_repr=True)

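然后使用 Java 读取 PMML 文件并对数据进行预测,代码如下(ModelInvoker 负责加载模型,ModelCalc 负责读取输入数据并调用模型)。
后来改为通过 Java 运行时(Runtime.getRuntime().exec)启动 Python 解释器、直接运行 Python 脚本的方式,见下文。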

import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.PMMLUtil;
import org.xml.sax.SAXException;

/**
 * Loads a PMML document and exposes the resulting model evaluator.
 *
 * <p>Both constructors fully consume and close the stream they read from.
 * Loading failures are reported immediately with the underlying cause instead
 * of surfacing later as a {@code NullPointerException}.
 *
 * @author liaotuo
 */
public class ModelInvoker {
    private final ModelEvaluator<?> modelEvaluator;

    /**
     * Loads the model from a classpath resource.
     *
     * @param pmmlFileName resource name resolved through this class's class loader
     * @throws IllegalArgumentException if the name is {@code null} or the resource does not exist
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(String pmmlFileName) {
        this(openResource(pmmlFileName));
        System.out.println("模型读取成功");
    }

    /**
     * Loads the model from an input stream. The stream is always closed before
     * this constructor returns, whether loading succeeds or fails.
     *
     * @param is stream positioned at the start of a PMML document
     * @throws IllegalArgumentException if {@code is} is {@code null}
     * @throws IllegalStateException if the PMML document cannot be parsed
     */
    public ModelInvoker(InputStream is) {
        if (is == null) {
            throw new IllegalArgumentException("PMML input stream must not be null");
        }
        this.modelEvaluator = loadEvaluator(is);
    }

    /**
     * Evaluates the model against one set of input fields.
     *
     * @param paramsMap input values keyed by field name
     * @return whatever the underlying evaluator returns for these arguments
     */
    public Map<FieldName, ?> invoke(Map<FieldName, ?> paramsMap) {
        return this.modelEvaluator.evaluate(paramsMap);
    }

    /** Opens the named classpath resource, failing fast when it cannot be found. */
    private static InputStream openResource(String pmmlFileName) {
        if (pmmlFileName == null) {
            throw new IllegalArgumentException("PMML resource name must not be null");
        }
        InputStream is = ModelInvoker.class.getClassLoader().getResourceAsStream(pmmlFileName);
        if (is == null) {
            throw new IllegalArgumentException("PMML resource not found on classpath: " + pmmlFileName);
        }
        return is;
    }

    /** Unmarshals the PMML document, builds and verifies the evaluator, then closes the stream. */
    private static ModelEvaluator<?> loadEvaluator(InputStream is) {
        try {
            PMML pmml = PMMLUtil.unmarshal(is);
            ModelEvaluator<?> evaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(pmml);
            evaluator.verify();
            return evaluator;
        } catch (SAXException | JAXBException e) {
            // Keep the cause: the original swallowed it and failed later with an NPE.
            throw new IllegalStateException("Failed to load PMML model", e);
        } finally {
            try {
                is.close();
            } catch (IOException ignored) {
                // Best-effort close; the model has already been read (or loading already failed).
            }
        }
    }
}
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.dmg.pmml.FieldName;

/**
 * 使用模型
 * 
 * @author gs
 *
 */
public class ModelCalc {


    static String pmmlPath = "E:\\workspace\\python\\tydic\\model\\xgb.pmml";
    public static void main(String[] args) throws IOException {
        String modelArgsFilePath = "E:\\workspace\\python\\tydic\\model\\test\\X_val";

        predictFromFile(modelArgsFilePath);
    }
    /**
     * 以文件名的方式读取输入数据进行预测
     * @param modelArgsFilePath
     * @throws FileNotFoundException
     * @throws IOException
     */
    public static List predictFromFile(String modelArgsFilePath) throws FileNotFoundException, IOException {

        BufferedInputStream bis = new BufferedInputStream(new FileInputStream(pmmlPath));

        ModelInvoker invoker = new ModelInvoker(bis);
        List> paramList = getDataFromFile(modelArgsFilePath);
        List predictResult = new ArrayList();
        int lineNum = 0; // 当前处理行数
        for (Map param : paramList) {

            lineNum++;
            System.out.println("======当前行: " + lineNum + "=======");
            Map result = invoker.invoke(param);
            Set keySet = result.keySet(); // 获取结果的keySet
            int i = 0;
            for (FieldName fn : keySet) {
                String probility1 = result.get(fn).toString(); //预测为1的概率
                System.out.println(probility1);
//              i++;
//              if(i%3==0){
//                  predictResult.add(probility1);
//              }

            }
        }
        return predictResult;
    }

    /**
     * 读取参数文件
     * 
     * @param filePath
     * @return
     * @throws IOException
     */
    private static List> getDataFromFile(String filePath) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader(filePath));
        String[] nameArr = br.readLine().split(" "); // 读取表头的名字
        List> list = new ArrayList();
        String paramLine = null; // 一行参数
        // 循环读取 每次读取一行数据
        while ((paramLine = br.readLine()) != null) {

            Map map = new HashMap();
            String[] paramLineArr = paramLine.split(" ");
            // 一次循环处理一行数据
            for (int i = 0; i < paramLineArr.length; i++) {
                map.put(new FieldName(nameArr[i]), paramLineArr[i]); // 将表头和值组成map
            }
            list.add(map);                                              // 加入list中

        }
        return list;
    }
}

使用Runtime.getRuntime().exec(args)

这种方式主要是编写 Java 端的调用代码,示例如下:

/**
 * Runs a Python training script as a child process and echoes its standard
 * output and standard error to this process's standard output.
 *
 * <p>Each stream is drained on its own thread so the child never blocks on a
 * full output pipe while the parent waits for it to exit.
 */
public class PythonDemo {
    public static void main(String[] args) {
        // Arguments passed to the script.
        // NOTE(review): connection settings, including the password, are hard-coded
        // for the demo; they are passed on the command line as-is.
        String host = "localhost";
        String port = "3306";
        String user = "root";
        String passwd = "123456";
        String path = "C:/";
        String database = "dic_coll_consume";
        String start_date = "2017-08-01";
        String end_date = "2017-09-01";

        String[] command = { "python", "C:\\model_train.py", host, port, user, passwd, path, database,
                start_date, end_date };
        try {
            Process pr = Runtime.getRuntime().exec(command);
            Thread stdoutReader = print(pr.getInputStream());
            Thread stderrReader = print(pr.getErrorStream());

            // Wait for the script to exit, then for both readers to finish draining.
            int exitCode = pr.waitFor();
            stdoutReader.join();
            stderrReader.join();
            if (exitCode != 0) {
                System.err.println("python exited with code " + exitCode);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            e.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Starts a thread that prints every line of the stream, followed by
     * {@code "end"} once the stream is exhausted.
     *
     * <p>The bytes are decoded as UTF-8 directly. The original decoded with the
     * platform charset and then re-encoded/decoded each line, which corrupts
     * non-ASCII text when the platform charset is not UTF-8.
     *
     * @param stream child-process output stream to drain; closed when reading ends
     * @return the started reader thread, so the caller can join it
     */
    private static Thread print(InputStream stream) {
        Thread reader = new Thread(() -> {
            try (BufferedReader in = new BufferedReader(
                    new InputStreamReader(stream, java.nio.charset.StandardCharsets.UTF_8))) {
                String line;
                while ((line = in.readLine()) != null) {
                    System.out.println(line);
                }
                System.out.println("end");
            } catch (IOException e) {
                e.printStackTrace();
            }
        });
        reader.start();
        return reader;
    }

}

这里遇到的主要问题是:Python 脚本运行时会输出大量信息,而子进程标准输出/错误输出的缓冲区很小;如果只调用 pr.waitFor() 而不读取这些输出,缓冲区写满后子进程就会阻塞,程序随之卡住。后来改用单独的线程分别读取并打印标准输出和错误输出,问题立刻就解决了。

你可能感兴趣的:(python)