Java实现一元线性回归

Java实现一元线性回归

 
发布时间:2006.04.28 22:26     来源:月光软件站    作者:

 

 

最近在写一个荧光图像分析软件,需要自己拟合方程。一元回归线公式的算法参考了《Java数值方法》,拟合度R^2(决定系数)是自己写的,欢迎讨论。计算结果和Excel完全一致。

总共三个文件:

DataPoint.java

/**
 * A data point for interpolation and regression.
 */
/**
 * A single (x, y) sample used for interpolation and regression.
 */
public class DataPoint
{
    /** Abscissa (x coordinate) of the sample. */
    public float x;

    /** Ordinate (y coordinate) of the sample. */
    public float y;

    /**
     * Creates a sample at the given coordinates.
     *
     * @param xValue the x value
     * @param yValue the y value
     */
    public DataPoint(float xValue, float yValue)
    {
        x = xValue;
        y = yValue;
    }
}

/**
 * A least-squares regression line function.
 */

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

public class RegressionLine 
 //implements Evaluatable
{
    /** sum of x */     private double sumX;
    /** sum of y */     private double sumY;
    /** sum of x*x */   private double sumXX;
    /** sum of x*y */   private double sumXY;
    /** sum of y*y */   private double sumYY;
    /** sum of yi-y */   private double sumDeltaY;
    /** sum of sumDeltaY^2 */   private double sumDeltaY2;
    /**误差 */
    private double sse;  
    private double sst;  
    private double E;
    private String[] xy ;
    
    private ArrayList listX ;
    private ArrayList listY ;
    
    private int XMin,XMax,YMin,YMax;
    
    /** line coefficient a0 */  private float a0;
    /** line coefficient a1 */  private float a1;

    /** number of data points */        private int     pn ;
    /** true if coefficients valid */   private boolean coefsValid;

    /**
     * Constructor.
     */
    public RegressionLine() {
     XMax = 0;
     YMax = 0;
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
    }

    /**
     * Constructor.
     * @param data the array of data points
     */
    public RegressionLine(DataPoint data[])
    { 
     pn = 0;
     xy =new String[2];
     listX = new ArrayList();
     listY = new ArrayList();
        for (int i = 0; i < data.length; ++i) {
            addDataPoint(data[i]);
        }
    }

    /**
     * Return the current number of data points.
     * @return the count
     */
    public int getDataPointCount() { return pn; }

    /**
     * Return the coefficient a0.
     * @return the value of a0
     */
    public float getA0()
    {
        validateCoefficients();
        return a0;
    }

    /**
     * Return the coefficient a1.
     * @return the value of a1
     */
    public float getA1()
    {
        validateCoefficients();
        return a1;
    }

    /**
     * Return the sum of the x values.
     * @return the sum
     */
    public double getSumX() { return sumX; }

    /**
     * Return the sum of the y values.
     * @return the sum
     */
    public double getSumY() { return sumY; }

    /**
     * Return the sum of the x*x values.
     * @return the sum
     */
    public double getSumXX() { return sumXX; }

    /**
     * Return the sum of the x*y values.
     * @return the sum
     */
    public double getSumXY() { return sumXY; }
    
    public double getSumYY() { return sumYY; }
    
    public int getXMin() {
  return XMin;
 }

 public int getXMax() {
  return XMax;
 }

 public int getYMin() {
  return YMin;
 }

 public int getYMax() {
  return YMax;
 }
    
    /**
     * Add a new data point: Update the sums.
     * @param dataPoint the new data point
     */
    public void addDataPoint(DataPoint dataPoint)
    {
        sumX  += dataPoint.x;
        sumY  += dataPoint.y;
        sumXX += dataPoint.x*dataPoint.x;
        sumXY += dataPoint.x*dataPoint.y;
        sumYY += dataPoint.y*dataPoint.y;
        
        if(dataPoint.x > XMax){
         XMax = (int)dataPoint.x;
        }
        if(dataPoint.y > YMax){
         YMax = (int)dataPoint.y;
        }
        
        //把每个点的具体坐标存入ArrayList中,备用
        
        xy[0] = (int)dataPoint.x+ "";
        xy[1] = (int)dataPoint.y+ "";
        if(dataPoint.x!=0 && dataPoint.y != 0){
        System.out.print(xy[0]+",");
        System.out.println(xy[1]);        
        
        try{
        //System.out.println("n:"+n);
        listX.add(pn,xy[0]);
        listY.add(pn,xy[1]);
        }
        catch(Exception e){
         e.printStackTrace();
        }                
        
        /*
        System.out.println("N:" + n);
        System.out.println("ArrayList listX:"+ listX.get(n));
        System.out.println("ArrayList listY:"+ listY.get(n));
        */
        }        
        ++pn;
        coefsValid = false;
     }

    /**
     * Return the value of the regression line function at x.
     * (Implementation of Evaluatable.)
     * @param x the value of x
     * @return the value of the function at x
     */
    public float at(int x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }
    
    public float at(float x)
    {
        if (pn < 2) return Float.NaN;

        validateCoefficients();
        return a0 + a1*x;
    }

    /**
     * Reset.
     */
    public void reset()
    {
        pn = 0;
        sumX = sumY = sumXX = sumXY = 0;
        coefsValid = false;
    }

    /**
     * Validate the coefficients.
     * 计算方程系数 y=ax+b 中的a
     */
    private void validateCoefficients()
    {
        if (coefsValid) return;

        if (pn >= 2) {
            float xBar = (float) sumX/pn;
            float yBar = (float) sumY/pn;

            a1 = (float) ((pn*sumXY - sumX*sumY)
                            /(pn*sumXX - sumX*sumX));
            a0 = (float) (yBar - a1*xBar);
        }
        else {
            a0 = a1 = Float.NaN;
        }

        coefsValid = true;
    }
    
    /**
     * 返回误差
     */
    public double getR(){   
     //遍历这个list并计算分母
     for(int i = 0; i < pn -1; i++)    {         
      float Yi= (float)Integer.parseInt(listY.get(i).toString());
      float Y = at(Integer.parseInt(listX.get(i).toString())); 
      float deltaY = Yi - Y;    
      float deltaY2 = deltaY*deltaY;
      /*
      System.out.println("Yi:" + Yi);
      System.out.println("Y:" + Y);
      System.out.println("deltaY:" + deltaY);
      System.out.println("deltaY2:" + deltaY2);
      */
          
         sumDeltaY2 += deltaY2;
         //System.out.println("sumDeltaY2:" + sumDeltaY2);
         
     }     
      
     sst = sumYY - (sumY*sumY)/pn;     
        //System.out.println("sst:" + sst);
     E =1- sumDeltaY2/sst;
     
     
     return round(E,4) ;
    }
    
    //用于实现精确的四舍五入
    public double round(double v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue();

    }    
    
    public  float round(float v,int scale){

     if(scale<0){
     throw new IllegalArgumentException(
     "The scale must be a positive integer or zero");
     }
     
     BigDecimal b = new BigDecimal(Double.toString(v));
     BigDecimal one = new BigDecimal("1");
     return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).floatValue();

    }    
}

演示程序:

LinearRegression.java

/**
 * <p><b>Linear Regression</b>
 * <br> 
 * Demonstrate linear regression by constructing the regression line for a set
 * of data points.
 * 
 * <p>require DataPoint.java,RegressionLine.java 
 * 
 * <p>为了计算对于给定数据点的最小方差回线,需要计算SumX,SumY,SumXX,SumXY; (注:SumXX = Sum (X^2))
 * <p><b>回归直线方程如下: f(x)=a1x+a0   </b>
 * <p><b>斜率和截距的计算公式如下:</b>
 * <br>n: 数据点个数
 * <p>a1=(n(SumXY)-SumX*SumY)/(n*SumXX-(SumX)^2)
 * <br>a0=(SumY - SumY * a1)/n 
 * <br>(也可表达为a0=averageY-a1*averageX)
 * 
 * <p><b>画线的原理:两点成一直线,只要能确定两个点即可</b><br>
 *  第一点:(0,a0) 再随意取一个x1值代入方程,取得y1,连结(0,a0)和(x1,y1)两点即可。
 * 为了让线穿过整个图,x1可以取横坐标的最大值Xmax,即两点为(0,a0),(Xmax,Y)。如果y=a1*Xmax+a0,y大于
 * 纵坐标最大值Ymax,则不用这个点。改用y取最大值Ymax,算得此时x的值,使用(X,Ymax), 即两点为(0,a0),(X,Ymax)
 * 
 * <p><b>拟合度计算:(即Excel中的R^2)</b>
 * <p> *R2 = 1 - E
 * <p>误差E的计算:E = SSE/SST
 * <p>SSE=sum((Yi-Y)^2) SST=sumYY - (sumY*sumY)/n;
 * <p> 
 */
public class LinearRegression
{
    private static final int MAX_POINTS = 10;
    private double E;

    /**
  * Main program.
  * 
  * @param args
  *            the array of runtime arguments
  */
    public static void main(String args[])
    {
        RegressionLine line = new RegressionLine();

        line.addDataPoint(new DataPoint(20, 136));
        line.addDataPoint(new DataPoint(40, 143));
        line.addDataPoint(new DataPoint(60, 152));
        line.addDataPoint(new DataPoint(80, 162));
        line.addDataPoint(new DataPoint(100, 167));
        
        printSums(line);
        printLine(line);
    }

    /**
  * Print the computed sums.
  * 
  * @param line
  *            the regression line
  */
    private static void printSums(RegressionLine line)
    {
        System.out.println("\n数据点个数 n = " + line.getDataPointCount());
        System.out.println("\nSum x  = " + line.getSumX());
        System.out.println("Sum y  = " + line.getSumY());
        System.out.println("Sum xx = " + line.getSumXX());
        System.out.println("Sum xy = " + line.getSumXY());
        System.out.println("Sum yy = " + line.getSumYY());       
        
    }

    /**
  * Print the regression line function.
  * 
  * @param line
  *            the regression line
  */
    private static void printLine(RegressionLine line)
    {
        System.out.println("\n回归线公式:  y = " +
                           line.getA1() +
                           "x + " + line.getA0());
        System.out.println("拟合度:     R^2 = " + line.getR());
    } 
    
}

 

你可能感兴趣的:(java实现)