LightGBM Java实现在线预测

LightGBM是三大知名GBDT的实现之一,支持二分类,多分类。与XGBoost相比,LGBM不需要通过所有样本计算信息增益,而且内置特征降维技术,支持高效率的并行训练,并且具有更快的训练速度、更低的内存消耗、更好的准确率、支持分布式可以快速处理海量数据等优点。 但在Java方面的支持不如XGBoost,没有封装好的Java在线预测包。

至于XGB和LGB原理和优缺点自行百度,不在本文范围内。

近期因为公司上线了很多XGBoost模型,在XGBoost训练消耗大量内存,为了节约资源选用LGBM替代XGBoost。在线预测服务就需要用Java封装训练好的LGBM模型,供线上实时预测使用。 在网上百度大多实现方式都是将模型封装为PMML格式,再在预测服务里预测结果。但是PMML版本模型单次预测需要100ms以上,显然不能满足性能需求。

于是展开Google大法,发现微软开源的mmlspark库(https://github.com/Azure/mmlspark.git),其中有一个包可以将LightGBM部署在spark环境中分布式训练。使用swig封装LightGBM的接口,然后使用jni的方式在spark中调用。赶紧找到打包好的maven lib。


      com.microsoft.ml.lightgbm
      lightgbmlib
      2.3.180

实现代码:

package com.tuhu.algo.etl.features.model;

import com.microsoft.ml.lightgbm.*;
import org.apache.commons.lang3.StringUtils;

import java.io.IOException;

/**
 * 

* * @Author: fc.w * @Date: 2020/11/14 16:29 */ public class LightGBMModelLoad { private SWIGTYPE_p_void boosterPtr; private String modelString; public LightGBMModelLoad(String modelString) { this.modelString = modelString; initModel(); } public void initModel() { try { init(modelString); } catch (Exception e) { throw new RuntimeException("模型加载失败", e); } } public void init(String modelString) throws Exception { initEnv(); if (StringUtils.isEmpty(modelString)) { throw new Exception("the inpute model string must not null"); } this.boosterPtr = getBoosterPtrFromModelString(modelString); } private void initEnv() throws IOException { String osPrefix = NativeLoader.getOSPrefix(); new NativeLoader("/com/microsoft/ml/lightgbm").loadLibraryByName(osPrefix + "_lightgbm"); new NativeLoader("/com/microsoft/ml/lightgbm").loadLibraryByName(osPrefix + "_lightgbm_swig"); } private void validate(int result) throws Exception { if (result == -1) { throw new Exception("Booster LoadFromString" + "call failed in LightGBM with error: " + lightgbmlib.LGBM_GetLastError()); } } private SWIGTYPE_p_void getBoosterPtrFromModelString(String lgbModelString) throws Exception { SWIGTYPE_p_p_void boosterOutPtr = lightgbmlib.voidpp_handle(); SWIGTYPE_p_int numItersOut = lightgbmlib.new_intp(); validate( lightgbmlib.LGBM_BoosterLoadModelFromString(lgbModelString, numItersOut, boosterOutPtr) ); return lightgbmlib.voidpp_value(boosterOutPtr); } /** * 预测 * @param data 批量向量 * @param numRows 预测行数 * @param numFeatures 向量大小 * @return 批量预测结果 */ public double[] predictForMat(double[] data, int numRows, int numFeatures) { int data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64; int isRowMajor = 1; String datasetParams = ""; SWIGTYPE_p_double scoredDataOutPtr = lightgbmlib.new_doubleArray(numRows * numFeatures); SWIGTYPE_p_long_long scoredDataLengthLongPtr = lightgbmlib.new_int64_tp(); lightgbmlib.int64_tp_assign(scoredDataLengthLongPtr, numRows * numFeatures); SWIGTYPE_p_double doubleArray = lightgbmlib.new_doubleArray(data.length); for (int i = 0; i < data.length; i++) { lightgbmlib.doubleArray_setitem(doubleArray, i, data[i]); } SWIGTYPE_p_void pdata = lightgbmlib.double_to_voidp_ptr(doubleArray); try { lightgbmlib.LGBM_BoosterPredictForMat( boosterPtr, pdata, data64bitType, numRows, numFeatures, isRowMajor, 0, -1, datasetParams, scoredDataLengthLongPtr, scoredDataOutPtr); return predToArray(scoredDataOutPtr, numRows); } catch (Exception e) { e.printStackTrace(); System.out.println(lightgbmlib.LastErrorMsg()); } finally { lightgbmlib.delete_doublep(doubleArray); lightgbmlib.delete_doublep(scoredDataOutPtr); lightgbmlib.delete_int64_tp(scoredDataLengthLongPtr); } return new double[numRows]; } private double[] predToArray(SWIGTYPE_p_double scoredDataOutPtr, int numRows) { double[] res = new double[numRows]; for (int i = 0; i < numRows; i++) { res[i] = lightgbmlib.doubleArray_getitem(scoredDataOutPtr, i); } return res; } }

资料
LightGBM官网
https://github.com/Azure/mmlspark
无痛看懂LightGBM原文

你可能感兴趣的:(LightGBM Java实现在线预测)