Machine Learning Basics 04: Linear Regression Theory and a Java Implementation of Multiple Linear Regression

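Linear regression theory:

The linear regression formula is y = b + w*x, where w is the weight and b is the bias.

In an implementation the formula can be written as y = w[0] * x[0] + w[1] * x[1] with x[0] = 1, so that the bias becomes w[0] and every parameter is solved for in the same way. With a small change the formula becomes y = w[0] * x[0] + w[1] * x[1] + ... + w[n]*x[n], which is multiple regression.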
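The trick of folding the bias into the weight array can be sketched as below. This is a minimal illustration rather than the article's code: the class name `AugmentedPredict` and the method `predict` are hypothetical names chosen for this example, and the sample weights are made up.

```java
// Hypothetical helper, named for this sketch only.
final class AugmentedPredict {

    // Computes y = w[0]*x[0] + w[1]*x[1] + ... + w[n]*x[n] with x[0] fixed to 1,
    // so w[0] plays the role of the bias b.
    static double predict(double[] w, double[] features) {
        if (w.length != features.length + 1) {
            throw new IllegalArgumentException("expected " + (w.length - 1) + " features");
        }
        double y = w[0]; // w[0] * x[0], where x[0] = 1
        for (int i = 0; i < features.length; i++) {
            y += w[i + 1] * features[i];
        }
        return y;
    }

    public static void main(String[] args) {
        // Simple regression: b = 1, w = 2, x = 3
        System.out.println(predict(new double[] {1.0, 2.0}, new double[] {3.0})); // prints 7.0
        // Multiple regression: b = 1, weights 2 and 3, features 3 and 4
        System.out.println(predict(new double[] {1.0, 2.0, 3.0}, new double[] {3.0, 4.0})); // prints 19.0
    }
}
```

The same method handles one feature or many, which is why a single implementation can serve both simple and multiple regression.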

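The parameters are optimized by gradient descent over many iterations. Each iteration computes the gradient of the error with respect to every parameter. The procedure has the following steps:

1. Compute the predicted value from the current parameters and the training data

        preY = sum(w[n] * x[n])

2. Compute the gradient

        weight_gradient[n] = sum(2 * (preY - y) * x[n] / N), summed over all training rows, where N is the total number of training rows

3. Update the parameters:

        weight[n] = weight[n] - a * weight_gradient[n], where a is the learning rate. The learning rate is usually chosen between 0 and 1 and tuned according to the training data and how training behaves.

4. Iterate

        Each iteration computes the gradient over the whole training set once and updates the parameters once. Repeating this moves the function steadily toward the minimum error.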
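The four steps above can be sketched on a tiny in-memory dataset. This is only an illustration under stated assumptions: the class `GradientDescentSketch`, the method `fit`, the sample points (generated from y = 1 + 2*x), the learning rate 0.05 and the 5000 iterations are all made up for this example and are not taken from the article.

```java
// Hypothetical sketch of batch gradient descent for linear regression.
final class GradientDescentSketch {

    // Each row of xs must already start with x[0] = 1 so that w[0] acts as the bias.
    static double[] fit(double[][] xs, double[] ys, double learningRate, int iterations) {
        int rows = xs.length;
        double[] w = new double[xs[0].length]; // Java initializes the weights to 0
        for (int it = 0; it < iterations; it++) {
            double[] grad = new double[w.length];
            for (int r = 0; r < rows; r++) {
                // Step 1: prediction with the current weights
                double preY = 0.0;
                for (int k = 0; k < w.length; k++) {
                    preY += w[k] * xs[r][k];
                }
                // Step 2: accumulate the gradient of the mean squared error
                for (int k = 0; k < w.length; k++) {
                    grad[k] += 2 * (preY - ys[r]) * xs[r][k] / rows;
                }
            }
            // Step 3: one update per pass over the data
            for (int k = 0; k < w.length; k++) {
                w[k] -= learningRate * grad[k];
            }
        } // Step 4: repeat
        return w;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, 0}, {1, 1}, {1, 2}, {1, 3}}; // x[0] = 1, x[1] = feature
        double[] ys = {1, 3, 5, 7};                       // y = 1 + 2*x
        double[] w = fit(xs, ys, 0.05, 5000);
        System.out.println(w[0] + ", " + w[1]); // should be close to 1 and 2
    }
}
```

With a learning rate that is too large for the data, the same loop diverges instead of converging, which is why the rate has to be tuned per dataset.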

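Implementing linear regression (in Java; the same code works for simple and multiple regression):

        1. Read the data, which is stored as CSV. The leading columns are x and the last column is y.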

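// Imports needed at the top of the source file
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;

// Reads every CSV row into a double[]; returns an empty list if the file is missing
public List<double[]> readTrainFile(String filepath) {
		File trainFile = new File(filepath);
		List<double[]> resultList = new ArrayList<>();
		if (trainFile.exists()) {
			// try-with-resources closes the reader even if parsing fails
			try (BufferedReader reader = new BufferedReader(new FileReader(trainFile))) {
				String line;
				while ((line = reader.readLine()) != null) {
					// Skip blank lines, which would otherwise fail to parse
					if (line.trim().isEmpty()) {
						continue;
					}
					String[] strs = line.split(",");
					double[] lines = new double[strs.length];
					for (int i = 0; i < strs.length; i++) {
						lines[i] = Double.parseDouble(strs[i]);
					}
					resultList.add(lines);
				}
			} catch (Exception e) {
				e.printStackTrace();
			}
		}
		return resultList;
	}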

        2. Train. The learning rate and the number of iterations must be supplied; the method returns the parameter array.

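public double[] train(String filepath, double learningRate, int iterationNum) {
		List<double[]> trainData = readTrainFile(filepath);
		if (trainData.isEmpty()) {
			throw new IllegalArgumentException("No training data read from " + filepath);
		}
		// A row holds the x columns plus y, so its length equals the number of
		// feature weights plus one bias weight. Java initializes the array to 0.
		double[] weights = new double[trainData.get(0).length];
		weights = updateWeights(trainData, weights, learningRate, iterationNum);
		return weights;
	}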

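        3. Compute the weights. On each pass over the dataset, compute the gradient by gradient descent and update each weight by subtracting learning rate * gradient.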

public double[] updateWeights(List<double[]> trainData, double[] weights, double learningRate, int iterationNum) {

		for (int i = 0; i < iterationNum; i++) {
			double[] weights_gradient = new double[weights.length];
			for (int j = 0; j < trainData.size(); j++) {
				double[] line = trainData.get(j);
				// Build x with x[0] = 1 for the bias, followed by the feature columns
				double[] x = new double[line.length];
				x[0] = 1;
				double y = line[line.length - 1];
				for(int n = 1; n < x.length; n++) {
					x[n] = line[n - 1];
				}
				// Predict preY from the current weights and this row
				double preY = 0.0;
				for(int n = 0; n < weights.length; n++) {
					preY += x[n] * weights[n];
				}
				// Accumulate the gradient of the mean squared error
				for(int n = 0; n < weights.length; n++) {
					weights_gradient[n]+=2 * (preY - y) * x[n] / (double)trainData.size();
				}
			}
			// Update the weights
			for(int j = 0; j < weights.length; j++) {
				weights[j] = weights[j] - learningRate * weights_gradient[j];
			}
			// Print the loss every 100 iterations
			if (i % 100 == 0) {
				double loss = computeError(trainData, weights);
				System.out.println(loss);
			}
		}
		return weights;
	}

        4. Compute the error (the mean squared error over the training data)

public double computeError(List<double[]> trainData, double[] weights) {
		double error= 0.0;
		for(int i = 0; i < trainData.size(); i++) {
			double[] line = trainData.get(i);
			double preY = 0.0;
			// Build x with x[0] = 1 for the bias, followed by the feature columns
			double[] x = new double[line.length];
			x[0] = 1;
			double y = line[line.length - 1];
			for(int n = 1; n < x.length; n++) {
				x[n] = line[n - 1];
			}
			for(int j = 0; j < line.length; j++) {
				preY +=  weights[j] * x[j];
			}
			// Squared difference between the actual and predicted value
			error += (y - preY) * (y - preY);
		}
		return error / (double)trainData.size();
	}

        5. Test program, which prints the trained parameters

public static void main(String[] args) {
		MultiLineRegression lineRegression = new MultiLineRegression();
		double[] weights = lineRegression.train("E:/index/traindata.csv", 0.001, 1000);
		for(double w : weights) {
			System.out.print(w + ",");
		}
	}
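To try the test program without an existing dataset, a small CSV in the expected layout (x columns first, y last) can be generated. The sketch below is an assumption-laden illustration: the class `TrainDataSketch`, its methods, and the target y = 1 + 2*x1 + 3*x2 are all invented for this example and do not come from the article.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

// Hypothetical generator for a training file in the "x1,x2,y" layout.
final class TrainDataSketch {

    // Builds rows "x1,x2,y" with y = 1 + 2*x1 + 3*x2 (an arbitrary made-up target).
    static List<String> makeRows(int count) {
        List<String> rows = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            double x1 = i * 0.1;
            double x2 = (i % 5) * 0.2;
            double y = 1 + 2 * x1 + 3 * x2;
            rows.add(x1 + "," + x2 + "," + y);
        }
        return rows;
    }

    // Writes the rows to a temporary file and returns its path.
    static Path writeTempCsv(int count) throws IOException {
        Path file = Files.createTempFile("traindata", ".csv");
        Files.write(file, makeRows(count));
        return file;
    }

    public static void main(String[] args) throws IOException {
        Path file = writeTempCsv(50);
        System.out.println(file.toAbsolutePath());
    }
}
```

The printed path can be passed to `train` in place of the hard-coded one. The learning rate and iteration count may need adjusting before the printed weights approach 1, 2 and 3.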
