java实现逻辑回归

阅读更多

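pom.xml

<dependencies>
    <dependency>
        <groupId>org.ujmp</groupId>
        <artifactId>ujmp-core</artifactId>
        <version>0.3.0</version>
    </dependency>
    <dependency>
        <groupId>org.jfree</groupId>
        <artifactId>jfreechart</artifactId>
        <version>1.5.0</version>
    </dependency>
</dependencies>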

 

    LogisticRegression主类

package logisticregression;

import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

public class LogisticRegression {

	/** Learning rate used by {@link #train(double[][], double[])}. */
	private static final double DEFAULT_STEP = 0.01;

	/** Number of gradient iterations used by {@link #train(double[][], double[])}. */
	private static final int DEFAULT_MAX_CYCLE = 5000000;

	/**
	 * Trains a logistic-regression model with full-batch gradient ascent, using a
	 * learning rate of 0.01 and 5,000,000 iterations.
	 *
	 * @param data        training samples, one row per sample; every row must have
	 *                    the same length
	 * @param classValues target value of each sample (presumably 0 or 1 — TODO
	 *                    confirm callers never pass other values)
	 * @return the weights {bias, w1, ..., wd}, or {@code null} if the input is
	 *         invalid (null, empty, mismatched lengths, or ragged/null rows)
	 */
	public static double[] train(double[][] data, double[] classValues) {
		return train(data, classValues, DEFAULT_STEP, DEFAULT_MAX_CYCLE);
	}

	/**
	 * Trains a logistic-regression model with full-batch gradient ascent.
	 *
	 * <p>Each iteration applies {@code w += step * X^T (y - sigmoid(X w))}, where
	 * {@code X} is {@code data} with a leading column of 1.0 (the bias term).
	 *
	 * @param data        training samples, one row per sample; every row must have
	 *                    the same length
	 * @param classValues target value of each sample
	 * @param step        learning rate applied to every update
	 * @param maxCycle    number of iterations to run; a value {@code <= 0} returns
	 *                    the initial weights
	 * @return the weights {bias, w1, ..., wd}, or {@code null} if the input is
	 *         invalid (null, empty, mismatched lengths, or ragged/null rows)
	 */
	public static double[] train(double[][] data, double[] classValues, double step, int maxCycle) {
		if (!isValidTrainingSet(data, classValues)) {
			return null;
		}
		int sampleCount = data.length;
		int featureCount = data[0].length;

		// Design matrix with a constant 1.0 in column 0 so weight 0 acts as the bias.
		Matrix matrData = DenseMatrix.Factory.zeros(sampleCount, featureCount + 1);
		Matrix matrLabel = DenseMatrix.Factory.zeros(sampleCount, 1);
		for (int i = 0; i < sampleCount; i++) {
			matrData.setAsDouble(1.0, i, 0);
			matrLabel.setAsDouble(classValues[i], i, 0);
			for (int j = 0; j < featureCount; j++) {
				matrData.setAsDouble(data[i][j], i, j + 1);
			}
		}

		// Fixed, non-zero starting point: 1.0 for indices 0..d-1 and -0.5 for index d.
		Matrix matrWeights = DenseMatrix.Factory.zeros(featureCount + 1, 1);
		for (int j = 0; j < featureCount; j++) {
			matrWeights.setAsDouble(1.0, j, 0);
		}
		matrWeights.setAsDouble(-0.5, featureCount, 0);

		// X^T never changes between iterations, so compute it once.
		Matrix dataTransposed = matrData.transpose();
		for (int i = 0; i < maxCycle; i++) {
			Matrix h = sigmoid(matrData.mtimes(matrWeights));
			Matrix difference = matrLabel.minus(h);
			matrWeights = matrWeights.plus(dataTransposed.mtimes(difference).times(step));
		}

		double[] rtn = new double[featureCount + 1];
		for (int i = 0; i < rtn.length; i++) {
			rtn[i] = matrWeights.getAsDouble(i, 0);
		}
		return rtn;
	}

	/** Returns true when the inputs form a non-empty, rectangular training set with one label per row. */
	private static boolean isValidTrainingSet(double[][] data, double[] classValues) {
		if (data == null || classValues == null || data.length == 0 || data.length != classValues.length) {
			return false;
		}
		if (data[0] == null) {
			return false;
		}
		int featureCount = data[0].length;
		for (double[] row : data) {
			if (row == null || row.length != featureCount) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Applies the logistic function element-wise.
	 *
	 * @param sourceMatrix input matrix (not modified)
	 * @return a new matrix of the same size containing {@code sigmoid(x)} for each element
	 */
	public static Matrix sigmoid(Matrix sourceMatrix) {
		long rowCount = sourceMatrix.getRowCount();
		long columnCount = sourceMatrix.getColumnCount();
		Matrix rtn = DenseMatrix.Factory.zeros(rowCount, columnCount);
		for (long i = 0; i < rowCount; i++) {
			for (long j = 0; j < columnCount; j++) {
				rtn.setAsDouble(sigmoid(sourceMatrix.getAsDouble(i, j)), i, j);
			}
		}
		return rtn;
	}

	/**
	 * Logistic function {@code 1 / (1 + e^-x)}.
	 *
	 * @param source input value
	 * @return a value in [0, 1]; very large |x| saturates to exactly 0.0 or 1.0
	 */
	public static double sigmoid(double source) {
		return 1.0 / (1 + Math.exp(-source));
	}

	/**
	 * Evaluates a trained model on one sample.
	 *
	 * @param sourceData feature vector of the sample (without the bias term)
	 * @param model      weights as returned by {@code train}: {bias, w1, ..., wd}
	 * @return {@code sigmoid(bias + sum(w_i * x_i))}
	 * @throws NullPointerException     if either argument is null
	 * @throws IllegalArgumentException if {@code model.length != sourceData.length + 1}
	 */
	public static double getValue(double[] sourceData, double[] model) {
		if (sourceData == null) {
			throw new NullPointerException("sourceData");
		}
		if (model == null) {
			// train() returns null for invalid training input, so surface that clearly.
			throw new NullPointerException("model (train() returns null for invalid training data)");
		}
		if (model.length != sourceData.length + 1) {
			throw new IllegalArgumentException("model has " + model.length + " weights but a sample with "
					+ sourceData.length + " features needs " + (sourceData.length + 1));
		}
		double logisticRegressionValue = model[0];
		for (int i = 0; i < sourceData.length; i++) {
			logisticRegressionValue += sourceData[i] * model[i + 1];
		}
		return sigmoid(logisticRegressionValue);
	}

}

 

    逻辑回归测试类

package logisticregression;

import common.ScatterPlot;

public class LogisicRegressionTest {

	/** Number of samples used to draw the decision-boundary line. */
	private static final int BOUNDARY_POINTS = 21;

	/**
	 * Demo entry point. Trains a model on a small hard-coded 2-D data set and
	 * prints the learned weights and the model output for the point (0, 0).
	 * It then plots both classes together with the learned decision boundary.
	 */
	public static void main(String[] args) {
		double[][] sourceData = new double[][] { { -1, 1 }, { 0, 1 }, { 1, -1 }, { 1, 0 }, { 0, 0.1 }, { 0, -0.1 }, { -1, -1.1 }, { 1, 0.9 } };
		double[] classValue = new double[] { 1, 1, 0, 0, 1, 0, 0, 0 };
		double[] modle = LogisticRegression.train(sourceData, classValue);
		if (modle == null) {
			throw new IllegalStateException("LogisticRegression.train rejected the training data");
		}
		double logicValue = LogisticRegression.getValue(new double[] { 0, 0 }, modle);
		System.out.println("---model---");
		for (double weight : modle) {
			System.out.println(weight);
		}
		System.out.println("-----------");
		System.out.println(logicValue);

		double minX = Double.POSITIVE_INFINITY;
		double maxX = Double.NEGATIVE_INFINITY;
		for (double[] row : sourceData) {
			minX = Math.min(minX, row[0]);
			maxX = Math.max(maxX, row[0]);
		}

		double[][] c0 = pointsWithClass(sourceData, classValue, 0);
		double[][] c1 = pointsWithClass(sourceData, classValue, 1);
		double[][] c2 = decisionBoundary(modle, minX, maxX, BOUNDARY_POINTS);

		// Series names are matched to chartData by index: "0" -> class-0 points,
		// "1" -> class-1 points, "L" -> boundary line.
		// NOTE(review): assumes ScatterPlot.showScatterPlotChart pairs names[i] with
		// data[i]; confirm in common.ScatterPlot.
		String[] c = new String[] { "0", "1", "L" };
		double[][][] chartData = new double[][][] { c0, c1, c2 };

		ScatterPlot.showScatterPlotChart("LogisticRegression", c, chartData);
	}

	/**
	 * Collects the 2-D points whose label equals {@code label}, keeping input order.
	 * The result is returned as {xs, ys}.
	 */
	private static double[][] pointsWithClass(double[][] data, double[] labels, double label) {
		int count = 0;
		for (double value : labels) {
			if (value == label) {
				count++;
			}
		}
		double[][] points = new double[2][count];
		int next = 0;
		for (int i = 0; i < data.length; i++) {
			if (labels[i] == label) {
				points[0][next] = data[i][0];
				points[1][next] = data[i][1];
				next++;
			}
		}
		return points;
	}

	/**
	 * Samples the line {@code w0 + w1*x + w2*y = 0} at {@code pointCount} evenly
	 * spaced x values in [fromX, toX], returned as {xs, ys}.
	 * Assumes a 2-feature model; if {@code model[2] == 0} the y values are infinite or NaN.
	 */
	private static double[][] decisionBoundary(double[] model, double fromX, double toX, int pointCount) {
		double[][] line = new double[2][pointCount];
		for (int k = 0; k < pointCount; k++) {
			// Derive x from the integer index rather than accumulating a float step, so
			// rounding can never add or drop a sample.
			double x = pointCount == 1 ? fromX : fromX + (toX - fromX) * k / (pointCount - 1);
			line[0][k] = x;
			line[1][k] = (-model[0] - model[1] * x) / model[2];
		}
		return line;
	}

}

 

 

 

 

 

 

你可能感兴趣的:(机器学习)