对应python中curve_fit的多元线性回归java实现

对应python中curve_fit的多元线性回归java实现

  • python中的拟合方法
  • java中实现多元线性拟合方法
    • 参考文章
    • 源代码及说明
    • 关于代码中一些参数的说明

python中的拟合方法

在python中实现拟合很方便,使用curve_fit,填好公式,样本数据和结果集,初始猜想和边界,很快就能实现,如下示例:
curve_fit(fit_function, a_tuple, b, p0=init_guess, bounds=(lb_tuple, ub_tuple), maxfev=1000)
fit_function如:theta1* x1 + theta2 * x2 + …
根据需要调整公式,可以实现多元线性拟合或非线性拟合。

java中实现多元线性拟合方法

在java中没有找到像python那样便捷地实现拟合的方法。针对多元线性拟合,经过多方面查找和实验,有一些小小的经验,记录备查。

参考文章

着重感谢下面文章的作者,基于他们的文章学到了很多知识:
https://zhuanlan.zhihu.com/p/25765735
https://www.cnblogs.com/donaldlee2008/p/5861796.html
https://my.oschina.net/u/1778239/blog/1858397

后面2篇文章的代码都能正常运行,得到的结果类似,测试过程中对代码做了的一些调整,记录备查。

源代码及说明

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;

public class LinearRegression {

/*
 * 二元训练数据示例:
 *   x0        x1        x2        y
    0.0       1.0       2.0       7.2
    0.0       2.0       1.0       4.9
    0.0       3.0       0.0       2.6
    0.0       4.0       1.0       6.3
    0.0       5.0      -1.0       1.0
    0.0       6.0       0.0       4.7
    0.0       7.0      -2.0      -0.6
	三元训练数据示例:
	0.0,0.8672012113286763,0.9497918028837882,0.9497918028837882,1
	0.0,1,0.13043478260869565,0.22608695652173913,0.5339804347826087
	0.0,1,0.8,0.9,0.7957633333333335
	0.0,1,1,1,1

    注意!!!!x1,x2,x3 ……,y 列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。
    x0,x1,x2,x3 ……是“特征”,y是结果

    h(x) = theta0 * x0 + theta1* x1 + theta2 * x2 + theta3 * x3 ……
    theta0,theta1,theta2,theta2,theta3…… 是想要训练出来的参数
     此程序采用“梯度下降法”
 *
 */

private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据  行数
private int column;//训练数据 列数
private double [] theta;//参数theta
private double alpha;//训练步长
private int iteration;//迭代次数

public LinearRegression(String fileName)
{
    int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
    int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

    trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
    this.row=rowoffile;
    this.column=columnoffile+1;

    this.alpha = 0.001;//步长默认为0.001
    this.iteration=100000;//迭代次数默认为 100000

    theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
    initialize_theta(0.5);

    loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
public LinearRegression(String fileName,double alpha,int iteration)
{
    int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
    int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

    trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于0
    this.row=rowoffile;
    this.column=columnoffile+1;

    this.alpha = alpha;
    this.iteration=iteration;

    theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
    initialize_theta(1.0/3.0);

    loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}

private int getRowNumber(String fileName)
{
    int count =0;
    File file = new File(fileName);
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new FileReader(file));
        while ( reader.readLine() != null)
            count++;
        reader.close();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e1) {
            }
        }
    }
    return count;

}

private int getColumnNumber(String fileName)
{
    int count =0;
    File file = new File(fileName);
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new FileReader(file));
        String tempString = reader.readLine();
        if(tempString!=null)
            count = tempString.split(",").length;
        reader.close();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e1) {
            }
        }
    }
    return count;
}

private void initialize_theta(double init_guess)//将theta各个参数全部初始化为init_guess
{
    for(int i=0;i0 )
    {
        //对每个theta i 求 偏导数
        double [] partial_derivative = compute_partial_derivative();//偏导数
        theta[0] = 0.0;//如果不需要第一个theta0,将这个值设置为0
        //更新每个theta
        for(int i =0; i< theta.length;i++) {
            double tmpTheta = theta[i]-alpha * partial_derivative[i];
            //加入了边界值,超过边界值则不再处理。如果不需要边界值,直接赋值即可。
            if(tmpTheta <1 && tmpTheta >-1){
                theta[i] = tmpTheta;
            }
        }
    }

    double[] thetaResult = new double[theta.length-1];
    for(int i=1; i

}

关于代码中一些参数的说明

去除theta0的影响:
由于不需要theta0,所以在初始化时,将样本数据的第一列设置为0,将theta0设置为0,每次求偏导后再次将theta0设置为0,去除该值的影响。
如果实际公式中需要theta0,则把样本数据的第一列设置为1.0,再把theta0赋0的代码去掉即可。

步长alpha和迭代次数iteration:
步长和迭代次数很重要,需要根据实际情况进行调整。
三元时,测试结果如下:
如果使用0.01的步长,大概100W次,才能得到与python类似的结果;
如果使用0.1的步长,大概10W次,可以得到与python类似的结果。

关于初始猜想值:
初始猜想值,原始代码设置为了1.0,根据实际需要,代码中调整为了 1/变量个数。

关于边界:
在python中,可以设置数据的上限和下限,在该代码中没有包含。
调整代码后,比较简单粗暴地进行了范围判断,超过范围后就不再赋值。如果哪位大侠有更好的判断数据范围的方法,请告知,不尽感谢!

你可能感兴趣的:(python转java,java学习记录,java,python)