在网上看到一位达人用 Java 写的一元线性回归实现,觉得挺有用:不少企业做数据挖掘时,不是会用到预测运营收入的功能吗?采用一元线性回归算法,就可以实现类似的功能。直接上代码吧:
1、定义一个DataPoint类,对X和Y坐标点进行封装:
/**
* File : DataPoint.java
* Author : zhouyujie
* Date : 2012-01-11 16:00:00
* Description : Java实现一元线性回归的算法,座标点实体类,(可实现统计指标的预测)
*/
package com.zyujie.dm;
/**
 * A single (x, y) observation.
 *
 * <p>Both coordinates are plain public, mutable {@code float} fields; the
 * class performs no validation.
 */
public class DataPoint {

    /** Horizontal coordinate of the observation. */
    public float x;

    /** Vertical coordinate of the observation. */
    public float y;

    /**
     * Creates a point from its two coordinates.
     *
     * @param x
     *            the x value
     * @param y
     *            the y value
     */
    public DataPoint(float x, float y) {
        this.y = y;
        this.x = x;
    }
}
2、下面是算法实现回归线:
/**
* File : RegressionLine.java
* Author : zhouyujie
* Date : 2012-01-11 16:00:00
* Description : Java实现一元线性回归的算法,回归线实现类,(可实现统计指标的预测)
*/
package com.zyujie.dm;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
/**
 * Simple (one-variable) least-squares regression line
 * {@code y = a1 * x + a0}.
 *
 * <p>Points are added one at a time. Running sums are updated on every add and
 * the coefficients are recomputed lazily the next time they are requested.
 * Every added point is also retained so that {@link #getR()} can compute the
 * residuals. This class is not synchronized.
 */
public class RegressionLine {

    /** Number of decimal places kept by {@link #getR()}. */
    private static final int R_SQUARED_SCALE = 4;

    /** Sum of x. */
    private double sumX;
    /** Sum of y. */
    private double sumY;
    /** Sum of x*x. */
    private double sumXX;
    /** Sum of x*y. */
    private double sumXY;
    /** Sum of y*y. */
    private double sumYY;

    /** Every point added since construction or the last reset, as {x, y}. */
    private final List<double[]> points = new ArrayList<double[]>();

    /** Smallest and largest coordinates seen so far; all 0 while empty. */
    private float minX;
    private float maxX;
    private float minY;
    private float maxY;

    /** Slope (a1) of the fitted line; only meaningful when coefsValid. */
    private double slope = Double.NaN;
    /** Intercept (a0) of the fitted line; only meaningful when coefsValid. */
    private double intercept = Double.NaN;
    /** True while slope and intercept reflect the current sums. */
    private boolean coefsValid;

    /**
     * Creates an empty regression line.
     */
    public RegressionLine() {
    }

    /**
     * Creates a regression line and adds every element of {@code data}.
     *
     * @param data
     *            the array of data points
     * @throws NullPointerException
     *             if {@code data} or one of its elements is {@code null}
     */
    public RegressionLine(DataPoint data[]) {
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Returns the current number of data points.
     *
     * @return the count
     */
    public int getDataPointCount() {
        return points.size();
    }

    /**
     * Returns the intercept a0.
     *
     * @return the value of a0, or NaN when fewer than two points were added
     */
    public float getA0() {
        validateCoefficients();
        return (float) intercept;
    }

    /**
     * Returns the slope a1.
     *
     * @return the value of a1, or NaN when fewer than two points were added
     */
    public float getA1() {
        validateCoefficients();
        return (float) slope;
    }

    /**
     * Returns the sum of the x values.
     *
     * @return the sum
     */
    public double getSumX() {
        return sumX;
    }

    /**
     * Returns the sum of the y values.
     *
     * @return the sum
     */
    public double getSumY() {
        return sumY;
    }

    /**
     * Returns the sum of the x*x values.
     *
     * @return the sum
     */
    public double getSumXX() {
        return sumXX;
    }

    /**
     * Returns the sum of the x*y values.
     *
     * @return the sum
     */
    public double getSumXY() {
        return sumXY;
    }

    /**
     * Returns the sum of the y*y values.
     *
     * @return the sum
     */
    public double getSumYY() {
        return sumYY;
    }

    /**
     * Returns the smallest x added so far, truncated toward zero.
     *
     * @return the minimum x, or 0 when no point has been added
     */
    public int getXMin() {
        return (int) minX;
    }

    /**
     * Returns the largest x added so far, truncated toward zero.
     *
     * @return the maximum x, or 0 when no point has been added
     */
    public int getXMax() {
        return (int) maxX;
    }

    /**
     * Returns the smallest y added so far, truncated toward zero.
     *
     * @return the minimum y, or 0 when no point has been added
     */
    public int getYMin() {
        return (int) minY;
    }

    /**
     * Returns the largest y added so far, truncated toward zero.
     *
     * @return the maximum y, or 0 when no point has been added
     */
    public int getYMax() {
        return (int) maxY;
    }

    /**
     * Adds a new data point: updates the sums and extrema, stores the point
     * and invalidates the cached coefficients.
     *
     * @param dataPoint
     *            the new data point
     * @throws NullPointerException
     *             if {@code dataPoint} is {@code null}
     */
    public void addDataPoint(DataPoint dataPoint) {
        // Widen before multiplying so the products are formed in double
        // precision instead of being rounded (or overflowing) as floats.
        final double x = dataPoint.x;
        final double y = dataPoint.y;

        sumX += x;
        sumY += y;
        sumXX += x * x;
        sumXY += x * y;
        sumYY += y * y;

        if (points.isEmpty()) {
            // The first point defines all four extrema.
            minX = maxX = dataPoint.x;
            minY = maxY = dataPoint.y;
        } else {
            minX = Math.min(minX, dataPoint.x);
            maxX = Math.max(maxX, dataPoint.x);
            minY = Math.min(minY, dataPoint.y);
            maxY = Math.max(maxY, dataPoint.y);
        }

        // Every point is stored, at full precision, so the stored list always
        // lines up with the sums used for the fit.
        points.add(new double[] { x, y });

        // Console echo kept exactly as before for existing users of this
        // class: integer-truncated coordinates, and only for points whose
        // coordinates are both non-zero.
        if (dataPoint.x != 0 && dataPoint.y != 0) {
            System.out.print((int) dataPoint.x + ",");
            System.out.println((int) dataPoint.y);
        }

        coefsValid = false;
    }

    /**
     * Returns the value of the regression line function at x.
     *
     * @param x
     *            the value of x
     * @return the value of the function at x, or NaN when fewer than two
     *         points were added
     */
    public float at(int x) {
        if (points.size() < 2) {
            return Float.NaN;
        }
        validateCoefficients();
        return (float) (intercept + slope * x);
    }

    /**
     * Resets this object to the empty state: clears all sums, the stored
     * points, the extrema and the cached coefficients.
     */
    public void reset() {
        points.clear();
        sumX = sumY = sumXX = sumXY = sumYY = 0;
        minX = maxX = minY = maxY = 0f;
        slope = intercept = Double.NaN;
        coefsValid = false;
    }

    /**
     * Recomputes slope and intercept from the running sums if they are stale.
     *
     * <p>a1 = (n*sumXY - sumX*sumY) / (n*sumXX - sumX^2) and
     * a0 = (sumY - a1*sumX) / n. With fewer than two points both become NaN.
     */
    private void validateCoefficients() {
        if (coefsValid) {
            return;
        }
        final int n = points.size();
        if (n >= 2) {
            slope = (n * sumXY - sumX * sumY) / (n * sumXX - sumX * sumX);
            intercept = (sumY - slope * sumX) / n;
        } else {
            slope = intercept = Double.NaN;
        }
        coefsValid = true;
    }

    /**
     * Returns the coefficient of determination R^2 = 1 - SSE/SST of the
     * fitted line, rounded to 4 decimal places, where SSE is the sum of
     * squared residuals over all stored points and
     * SST = sumYY - sumY^2 / n.
     *
     * <p>Despite the method name, the value is R^2, not the correlation
     * coefficient r. The result depends only on the current points, so
     * repeated calls return the same value.
     *
     * @return R^2, or NaN when it is undefined (fewer than two points, SST
     *         not positive, or no finite slope)
     */
    public double getR() {
        final int n = points.size();
        if (n < 2) {
            return Double.NaN;
        }
        validateCoefficients();

        // Local accumulator: nothing carries over between calls.
        double sse = 0;
        for (double[] point : points) {
            final double residual = point[1] - (intercept + slope * point[0]);
            sse += residual * residual;
        }

        final double sst = sumYY - (sumY * sumY) / n;
        if (!(sst > 0)) {
            // All y equal (or cancellation made SST non-positive).
            return Double.NaN;
        }
        return round(1 - sse / sst, R_SQUARED_SCALE);
    }

    /**
     * Rounds {@code v} half-up to {@code scale} decimal places, starting from
     * the decimal text produced by {@link Double#toString(double)}.
     *
     * @param v
     *            the value to round
     * @param scale
     *            number of decimal places, zero or greater
     * @return the rounded value; NaN and infinities are returned unchanged
     * @throws IllegalArgumentException
     *             if {@code scale} is negative
     */
    public double round(double v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException(
                    "The scale must be a positive integer or zero");
        }
        if (Double.isNaN(v) || Double.isInfinite(v)) {
            // BigDecimal cannot represent these values.
            return v;
        }
        return new BigDecimal(Double.toString(v))
                .setScale(scale, RoundingMode.HALF_UP).doubleValue();
    }

    /**
     * Rounds {@code v} half-up to {@code scale} decimal places, starting from
     * the decimal text produced by {@link Float#toString(float)}.
     *
     * @param v
     *            the value to round
     * @param scale
     *            number of decimal places, zero or greater
     * @return the rounded value; NaN and infinities are returned unchanged
     * @throws IllegalArgumentException
     *             if {@code scale} is negative
     */
    public float round(float v, int scale) {
        if (scale < 0) {
            throw new IllegalArgumentException(
                    "The scale must be a positive integer or zero");
        }
        if (Float.isNaN(v) || Float.isInfinite(v)) {
            return v;
        }
        // Float.toString gives the float's own shortest decimal form;
        // widening to double first would expose binary noise digits
        // (e.g. 2.1f -> 2.0999999046325684) to the rounding step.
        return new BigDecimal(Float.toString(v))
                .setScale(scale, RoundingMode.HALF_UP).floatValue();
    }
}
3、线性回归测试类:
/**
* File : LinearRegression.java
* Author : zhouyujie
* Date : 2012-01-11 16:00:00
* Description : Java实现一元线性回归的算法,线性回归测试类,(可实现统计指标的预测)
*/
package com.zyujie.dm;
/**
 * Linear regression demo: builds the regression line for a fixed set of data
 * points and prints the sums, the line equation and the value returned by
 * {@code RegressionLine.getR()}.
 *
 * <p>Requires DataPoint.java and RegressionLine.java.
 *
 * <p>To compute the least-squares regression line for the given points, the
 * sums SumX, SumY, SumXX and SumXY are needed (SumXX = Sum(X^2)).
 *
 * <p>The regression line is f(x) = a1*x + a0, with n the number of points:
 *
 * <pre>
 * a1 = (n*SumXY - SumX*SumY) / (n*SumXX - (SumX)^2)
 * a0 = (SumY - a1*SumX) / n
 *      (equivalently a0 = averageY - a1*averageX)
 * </pre>
 *
 * <p>Drawing the line: two points determine a straight line. The first point
 * is (0, a0); for the second, substitute any x1 into the equation to get y1
 * and join (0, a0) to (x1, y1). To make the line cross the whole chart, x1
 * can be the largest abscissa Xmax, giving (0, a0) and (Xmax, Y). If
 * y = a1*Xmax + a0 exceeds the largest ordinate Ymax, use y = Ymax instead and
 * solve for x, giving (0, a0) and (X, Ymax).
 *
 * <p>Goodness of fit (R^2 as in Excel):
 *
 * <pre>
 * R^2 = 1 - E
 * E   = SSE / SST
 * SSE = sum((Yi - Y)^2)
 * SST = sumYY - (sumY*sumY) / n
 * </pre>
 */
public class LinearRegression {

    /**
     * Main program: feeds five sample points into a regression line and
     * prints the results.
     *
     * @param args
     *            the array of runtime arguments (unused)
     */
    public static void main(String args[]) {
        // Sample observations as {x, y}, added in this order.
        final float[][] samples = {
            { 1, 136 },
            { 2, 143 },
            { 3, 132 },
            { 4, 142 },
            { 5, 147 },
        };

        RegressionLine line = new RegressionLine();
        for (float[] sample : samples) {
            line.addDataPoint(new DataPoint(sample[0], sample[1]));
        }

        printSums(line);
        printLine(line);
    }

    /**
     * Prints the point count and the accumulated sums.
     *
     * @param line
     *            the regression line
     */
    private static void printSums(RegressionLine line) {
        final int count = line.getDataPointCount();
        System.out.println("\n数据点个数 n = " + count);

        final double sx = line.getSumX();
        System.out.println("\nSum x = " + sx);
        final double sy = line.getSumY();
        System.out.println("Sum y = " + sy);
        final double sxx = line.getSumXX();
        System.out.println("Sum xx = " + sxx);
        final double sxy = line.getSumXY();
        System.out.println("Sum xy = " + sxy);
        final double syy = line.getSumYY();
        System.out.println("Sum yy = " + syy);
    }

    /**
     * Prints the regression line equation followed by the value of
     * {@code getR()}.
     *
     * @param line
     *            the regression line
     */
    private static void printLine(RegressionLine line) {
        // Slope is fetched before intercept, matching the order in which
        // they appear in the printed equation.
        final float slope = line.getA1();
        final float intercept = line.getA0();
        System.out.println("\n回归线公式: y = " + slope + "x + " + intercept);

        final double fit = line.getR();
        System.out.println("误差: R^2 = " + fit);
    }

    // Worked predictions with the fitted line y = 2.1x + 133.7:
    // x = 6: 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3
    // x = 7: 2.1 * 7 + 133.7 = 14.7 + 133.7 = 148.4
}
我们运行测试类,得到运行结果:
1,136
2,143
3,132
4,142
5,147
数据点个数 n = 5
Sum x = 15.0
Sum y = 700.0
Sum xx = 55.0
Sum xy = 2121.0
Sum yy = 98142.0
回归线公式: y = 2.1x + 133.7
误差: R^2 = 0.3106
(注:0.3106 是对全部 5 个数据点计算的 R²,即 21²/(10×142);若残差平方和漏掉最后一个点,则会得到 0.3658。)
假如某公司:
1月收入,136万元
2月收入,143万元
3月收入,132万元
4月收入,142万元
5月收入,147万元
我们可以根据回归线公式:y = 2.1x + 133.7,预测出6月份收入:
y = 2.1 * 6 + 133.7 = 12.6 + 133.7 = 146.3