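pom.xml — required dependencies:
<dependencies>
    <dependency>
        <groupId>org.ujmp</groupId>
        <artifactId>ujmp-core</artifactId>
        <version>0.3.0</version>
    </dependency>
    <dependency>
        <groupId>org.jfree</groupId>
        <artifactId>jfreechart</artifactId>
        <version>1.5.0</version>
    </dependency>
</dependencies>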
LogisticRegression 主类 (main class)
package logisticregression; import org.ujmp.core.DenseMatrix; import org.ujmp.core.Matrix; public class LogisticRegression { public static double[] train(double[][] data, double[] classValues) { if (data != null && classValues != null && data.length == classValues.length) { Matrix matrWeights = DenseMatrix.Factory.zeros(data[0].length + 1, 1); Matrix matrData = DenseMatrix.Factory.zeros(data.length, data[0].length + 1); Matrix matrLable = DenseMatrix.Factory.zeros(data.length, 1); for (int i = 0; i < data.length; i++) { matrData.setAsDouble(1.0, i, 0); matrLable.setAsDouble(classValues[i], i, 0); for (int j = 0; j < data[0].length; j++) { matrData.setAsDouble(data[i][j], i, j + 1); if (i == 0) { matrWeights.setAsDouble(1.0, j, 0); } } } matrWeights.setAsDouble(-0.5, data[0].length, 0); double step = 0.01; int maxCycle = 5000000; for (int i = 0; i < maxCycle; i++) { Matrix h = sigmoid(matrData.mtimes(matrWeights)); Matrix difference = matrLable.minus(h); matrWeights = matrWeights.plus(matrData.transpose().mtimes(difference).times(step)); } double[] rtn = new double[(int) matrWeights.getRowCount()]; for (long i = 0; i < matrWeights.getRowCount(); i++) { rtn[(int) i] = matrWeights.getAsDouble(i, 0); } return rtn; } return null; } public static Matrix sigmoid(Matrix sourceMatrix) { Matrix rtn = DenseMatrix.Factory.zeros(sourceMatrix.getRowCount(), sourceMatrix.getColumnCount()); for (int i = 0; i < sourceMatrix.getRowCount(); i++) { for (int j = 0; j < sourceMatrix.getColumnCount(); j++) { rtn.setAsDouble(sigmoid(sourceMatrix.getAsDouble(i, j)), i, j); } } return rtn; } public static double sigmoid(double source) { return 1.0 / (1 + Math.exp(-1 * source)); } public static double getValue(double[] sourceData, double[] model) { double logisticRegressionValue = model[0]; for (int i = 0; i < sourceData.length; i++) { logisticRegressionValue = logisticRegressionValue + sourceData[i] * model[i + 1]; } logisticRegressionValue = sigmoid(logisticRegressionValue); return 
logisticRegressionValue; } }
逻辑回归测试类 (logistic regression test class)
package logisticregression;

import common.ScatterPlot;

/**
 * Demo that trains {@link LogisticRegression} on a small 2-D data set, prints the model and the
 * prediction for (0, 0), and plots both classes together with the learned decision boundary.
 */
public class LogisicRegressionTest {

    /** Number of sample points drawn along the decision-boundary line. */
    private static final int BOUNDARY_POINTS = 21;

    /** Range of x over which the decision boundary is drawn. */
    private static final double BOUNDARY_X_MIN = -1;
    private static final double BOUNDARY_X_MAX = 1;

    public static void main(String[] args) {
        double[][] sourceData = new double[][] { { -1, 1 }, { 0, 1 }, { 1, -1 }, { 1, 0 },
                { 0, 0.1 }, { 0, -0.1 }, { -1, -1.1 }, { 1, 0.9 } };
        double[] classValue = new double[] { 1, 1, 0, 0, 1, 0, 0, 0 };
        double[] model = LogisticRegression.train(sourceData, classValue);
        double logicValue = LogisticRegression.getValue(new double[] { 0, 0 }, model);
        System.out.println("---model---");
        for (double weight : model) {
            System.out.println(weight);
        }
        System.out.println("-----------");
        System.out.println(logicValue);

        // Each series is laid out as row 0 = x coordinates, row 1 = y coordinates.
        double[][] class0 = pointsWithLabel(sourceData, classValue, 0);
        double[][] class1 = pointsWithLabel(sourceData, classValue, 1);

        // Decision boundary: model[0] + model[1]*x + model[2]*y = 0, solved for y.
        // An integer counter sizes the array exactly. The old `x += 0.1` loop only produced
        // 21 points because of floating-point rounding luck.
        double[][] boundary = new double[2][BOUNDARY_POINTS];
        for (int k = 0; k < BOUNDARY_POINTS; k++) {
            double x = BOUNDARY_X_MIN
                    + (BOUNDARY_X_MAX - BOUNDARY_X_MIN) * k / (BOUNDARY_POINTS - 1);
            boundary[0][k] = x;
            boundary[1][k] = (-model[0] - model[1] * x) / model[2];
        }

        // Series names are parallel to chartData by index ("L" labels the boundary line).
        // Previously class-0 points were labeled "1" and class-1 points "0".
        String[] c = new String[] { "0", "1", "L" };
        double[][][] chartData = new double[][][] { class0, class1, boundary };
        ScatterPlot.showScatterPlotChart("LogisticRegression", c, chartData);
    }

    /**
     * Collects the points whose class value equals {@code label}, preserving input order.
     *
     * @param points 2-D points, one {x, y} pair per row
     * @param labels class value per point, parallel to {@code points}
     * @param label class value to select
     * @return a {@code double[2][count]} array: row 0 holds x values, row 1 holds y values
     */
    private static double[][] pointsWithLabel(double[][] points, double[] labels, double label) {
        int count = 0;
        for (double value : labels) {
            if (value == label) {
                count++;
            }
        }
        double[][] series = new double[2][count];
        int next = 0;
        for (int i = 0; i < points.length; i++) {
            if (labels[i] == label) {
                series[0][next] = points[i][0];
                series[1][next] = points[i][1];
                next++;
            }
        }
        return series;
    }
}