使用pmml跨平台部署机器学习模型Demo——房价预测

  基于房价数据,在python中训练得到一个线性回归的模型,在JavaWeb中加载模型完成房价预测的功能。


一、 训练、保存模型

工具:PyCharm-2017、Python-39、sklearn2pmml-0.76.1。

1.训练数据house_price.csv

No square_feet price
1 150 6450
2 200 7450
3 250 8450
4 300 9450
5 350 11450
6 400 15450
7 600 18450

2.训练、保存模型

import sklearn2pmml as pmml
from sklearn2pmml import PMMLPipeline
from sklearn import linear_model as lm
import os
import pandas as pd

def save_model(data, model_path):
    pipeline = PMMLPipeline([("regression", lm.LinearRegression())])
    pipeline.fit(data[["square_feet"]], data["price"])
    pmml.sklearn2pmml(pipeline, model_path, with_repr=True)

if __name__ == "__main__":
    data = pd.read_csv("house_price.csv")
    model_path = model_path = os.path.dirname(os.path.abspath(__file__)) + "/my_liner_model.pmml"
    save_model(data, model_path)
    print("模型保存完成。")

二、JavaWeb应用开发

工具:IntelliJ IDEA-2018、jdk-14.0.2、Tomcat-9.0.37。


创建maven项目,加入依赖项

    
        
            org.jpmml
            pmml-evaluator
            1.4.15
        
        
            com.sun.xml.bind
            jaxb-core
            2.2.11
        
        
            javax.xml
            jaxb-api
            2.1
        
        
            com.sun.xml.bind
            jaxb-impl
            2.2.11
        
        
            javax.servlet
            javax.servlet-api
            3.0.1
        
    

项目结构为
使用pmml跨平台部署机器学习模型Demo——房价预测_第1张图片


界面——index.jsp

<%@ page contentType="text/html;charset=UTF-8" language="java" %>


    使用pmml跨平台部署机器学习模型Demo


使用pmml跨平台部署机器学习模型Demo——房价预测

${price}

Servlet类——PredictServlet.java

package servlet;

import service.PredictService;
import service.imp.PredictServiceImp;

import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

@WebServlet("/PredictServlet")
public class PredictServlet extends HttpServlet {
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        PredictService predictService = new PredictServiceImp();

        String feet_str = request.getParameter("feet"); //获取前端传来的值
        int feet = Integer.parseInt(feet_str);

        double price = predictService.getPredictedPrice(feet); //预测

        //请求转发,返回结果
        request.setAttribute("price", price);
        request.getRequestDispatcher("/index.jsp").forward(request, response);
    }

    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doPost(request, response);
    }
}


Service接口——PredictService.java

package service;

public interface PredictService {
    public double getPredictedPrice(int feet);
}


Service实现类——PredictServiceImp.java

package service.imp;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import service.PredictService;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class PredictServiceImp implements PredictService {
    public double getPredictedPrice(int feet) {
        String model_path = "D:\\my_liner_model.pmml"; //pmml模型文件存放路径
        Evaluator model = loadModel(model_path); //加载模型
        Object r = predict(model, feet); //预测
        double result = Double.parseDouble(String.format("%.2f", r)); //格式化
        return result;
    }

    private static Evaluator loadModel(String model_path){
        PMML pmml = new PMML(); //定义PMML对象
        InputStream inputStream; //定义输入流
        try {
            inputStream = new FileInputStream(model_path); //输入流接到磁盘上的模型文件
            pmml = org.jpmml.model.PMMLUtil.unmarshal(inputStream); //将输入流解析为PMML对象
        }catch (Exception e){
            e.printStackTrace();
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); //实例化一个模型构造工厂
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml); //将PMML对象构造为Evaluator模型对象

        return evaluator;
    }

    private static Object predict(Evaluator evaluator, int feet){
        Map data = new HashMap(); //定义测试数据Map,存入各元自变量
        data.put("square_feet", feet); //键"square_feet"为自变量的名称,应与训练数据中的自变量名称一致

        List inputFieldList = evaluator.getInputFields(); //得到模型各元自变量的属性列表
        Map arguments = new LinkedHashMap();
        for (InputField inputField : inputFieldList) { //遍历各元自变量的属性列表
            FieldName inputFieldName = inputField.getName();
            Object rawValue = data.get(inputFieldName.getValue()); //取出该元变量的值
            FieldValue inputFieldValue = inputField.prepare(rawValue); //将值加入该元自变量属性中
            arguments.put(inputFieldName, inputFieldValue); //变量名和变量值的对加入LinkedHashMap
        }
        Map result = evaluator.evaluate(arguments); //进行预测
        List targetFieldList = evaluator.getTargetFields(); //得到模型各元因变量的属性列表
        FieldName targetFieldName = targetFieldList.get(0).getName(); //第一元因变量名称
        Object targetFieldValue = result.get(targetFieldName); //由因变量名称得到值

        return targetFieldValue;
    }
}


三、运行测试

  将python中训练得到的pmml模型文件置于D盘根目录下,将文件中的xmlns=".../PMML-4_4"修改为xmlns=".../PMML-4_3"。


启动运行,浏览器访问http://localhost/,进入页面
使用pmml跨平台部署机器学习模型Demo——房价预测_第2张图片


输入房子英尺数,点击‘预测房价’按钮,展示出预测价格
使用pmml跨平台部署机器学习模型Demo——房价预测_第3张图片

使用pmml跨平台部署机器学习模型Demo——房价预测_第4张图片


打包下载:
https://download.csdn.net/download/Albert201605/45648664


End.

你可能感兴趣的:(使用pmml跨平台部署机器学习模型Demo——房价预测)