在python中实现拟合很方便,使用curve_fit,填好公式,样本数据和结果集,初始猜想和边界,很快就能实现,如下示例:
curve_fit(fit_function, a_tuple, b, p0=init_guess, bounds=(lb_tuple, ub_tuple), maxfev=1000)
fit_function如:theta1* x1 + theta2 * x2 + …
根据需要调整公式,可以实现多元线性拟合或非线性拟合。
在java中没有找到像python那样便捷地实现拟合的方法。针对多元线性拟合,经过多方面查找和实验,有一些小小的经验,记录备查。
着重感谢下面文章的作者,基于他们的文章学到了很多知识:
https://zhuanlan.zhihu.com/p/25765735
https://www.cnblogs.com/donaldlee2008/p/5861796.html
https://my.oschina.net/u/1778239/blog/1858397
后面2篇文章的代码都能正常运行,得到的结果类似,测试过程中对代码做了的一些调整,记录备查。
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
public class LinearRegression {
/*
* 二元训练数据示例:
* x0 x1 x2 y
0.0 1.0 2.0 7.2
0.0 2.0 1.0 4.9
0.0 3.0 0.0 2.6
0.0 4.0 1.0 6.3
0.0 5.0 -1.0 1.0
0.0 6.0 0.0 4.7
0.0 7.0 -2.0 -0.6
三元训练数据示例:
0.0,0.8672012113286763,0.9497918028837882,0.9497918028837882,1
0.0,1,0.13043478260869565,0.22608695652173913,0.5339804347826087
0.0,1,0.8,0.9,0.7957633333333335
0.0,1,1,1,1
注意!!!!x1,x2,x3 ……,y 列是用户实际输入的数据,x0是为了推导出来的公式统一,特地补上的一列。
x0,x1,x2,x3 ……是“特征”,y是结果
h(x) = theta0 * x0 + theta1* x1 + theta2 * x2 + theta3 * x3 ……
theta0,theta1,theta2,theta2,theta3…… 是想要训练出来的参数
此程序采用“梯度下降法”
*
*/
private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据 行数
private int column;//训练数据 列数
private double [] theta;//参数theta
private double alpha;//训练步长
private int iteration;//迭代次数
public LinearRegression(String fileName)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数
trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
this.row=rowoffile;
this.column=columnoffile+1;
this.alpha = 0.001;//步长默认为0.001
this.iteration=100000;//迭代次数默认为 100000
theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
initialize_theta(0.5);
loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
public LinearRegression(String fileName,double alpha,int iteration)
{
int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的 行数
int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的 列数
trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于0
this.row=rowoffile;
this.column=columnoffile+1;
this.alpha = alpha;
this.iteration=iteration;
theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
initialize_theta(1.0/3.0);
loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
private int getRowNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
while ( reader.readLine() != null)
count++;
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count;
}
private int getColumnNumber(String fileName)
{
int count =0;
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = reader.readLine();
if(tempString!=null)
count = tempString.split(",").length;
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return count;
}
private void initialize_theta(double init_guess)//将theta各个参数全部初始化为init_guess
{
for(int i=0;i0 )
{
//对每个theta i 求 偏导数
double [] partial_derivative = compute_partial_derivative();//偏导数
theta[0] = 0.0;//如果不需要第一个theta0,将这个值设置为0
//更新每个theta
for(int i =0; i< theta.length;i++) {
double tmpTheta = theta[i]-alpha * partial_derivative[i];
//加入了边界值,超过边界值则不再处理。如果不需要边界值,直接赋值即可。
if(tmpTheta <1 && tmpTheta >-1){
theta[i] = tmpTheta;
}
}
}
double[] thetaResult = new double[theta.length-1];
for(int i=1; i
}
去除theta0的影响:
由于不需要theta0,所以在初始化时,将样本数据的第一列设置为0,将theta0设置为0,每次求偏导后再次将theta0设置为0,去除该值的影响。
如果实际公式中需要theta0,则把样本数据的第一列设置为1.0,再把theta0赋0的代码去掉即可。
步长alpha和迭代次数iteration:
步长和迭代次数很重要,需要根据实际情况进行调整。
三元时,测试结果如下:
如果使用0.01的步长,大概100W次,才能得到与python类似的结果;
如果使用0.1的步长,大概10W次,可以得到与python类似的结果。
关于初始猜想值:
初始猜想值,原始代码设置为了1.0,根据实际需要,代码中调整为了 1/变量个数。
关于边界:
在python中,可以设置数据的上限和下限,在该代码中没有包含。
调整代码后,比较简单粗暴地进行了范围判断,超过范围后就不再赋值。如果哪位大侠有更好的判断数据范围的方法,请告知,不尽感谢!