Java算法一:线性回归方程式拟合数据曲线及预测数据

摘要

        在Java中,使用线性回归算法,基于已有的数据拟合出回归方式式趋势图,及预测数据。

        该算法,可通过传入项数的最高次N,来拟合出对应的二元N次方程式。得到方程式以后,可通过传入X数据,来计算出对应的Y轴数据。

package com.unkown.orchestrator.controller;

import com.alibaba.fastjson.JSONObject;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;

import java.lang.reflect.Array;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 * @Description: 
 * @author:lvxiaobu
 * @date:2023/8/1 13:06
 */
public class PolynomialRegression {

    private PolynomialCurveFitter fitter; //最高项次数
    private double[] coefficients; // 各项数的常数值

    public PolynomialRegression(int degree) {
        fitter = PolynomialCurveFitter.create(degree);
    }

    public void fit(List xData, List yData) {
        WeightedObservedPoints points = new WeightedObservedPoints();
        for (int i = 0; i < xData.size(); i++) {
            points.add(xData.get(i), yData.get(i));
        }

        // 计算各项的常数项,如 y=ax^2 + bx +c 中的a、b、c
        coefficients = fitter.fit(points.toList());

        // 输出拟合后的公式
        String fun = "f(x) = ";
        for (int i = coefficients.length - 1; i >= 0; i--) {
            String add = coefficients[i] > 0 ? "+" : "";
            String x = i > 0 ? "x^" + i : "";
            if (i == coefficients.length - 1) {
                fun += (coefficients[i] + x);
            } else {
                fun += (add + coefficients[i] + x);
            }
        }
        System.out.println("拟合公式为:"+fun);
    }

    /**
     * @Description: 基于方程式 及 传入的X数据,计算对应的Y轴数据
     * @author:lvxiaobu
     * @date:2023/8/1 13:18
     */
    public List predict(List preX) {
        DecimalFormat df = new DecimalFormat("#.00");
        List preY = new ArrayList<>();
        for (int index = 0;index < preX.size(); index++){
            double y = (double) 0;
            for (int i = 0; i < coefficients.length; i++) {
                y += coefficients[i] * Math.pow(preX.get(index), i);
            }
            y = Double.parseDouble(df.format(y));
            preY.add(y);
        }

        return preY;
    }

    public static void main(String[] args) {
        // 提供已有数据
        double[] xData = {2, 4, 6, 8, 10,12,14};
        List xDatas = Arrays.stream(xData).boxed().collect(Collectors.toList());
        double[] yData = {11.20, 13.40, 17.60, 24.80, 30, 38,49,52};
        List yDatas = Arrays.stream(yData).boxed().collect(Collectors.toList());

        // 声明生成的线性回归方程式的最高项次数
        PolynomialRegression regression = new PolynomialRegression(2); // 生成
        regression.fit(xDatas, yDatas); // 计算方程式中的各项的常数值.如 y=ax^2 + bx +c 中的a、b、c

        // 提供需要基于方程式计算的x数据
        double[] preXData = {2, 4, 6, 8, 10,12,14,16,18,20};
        List preXDatas = Arrays.stream(preXData).boxed().collect(Collectors.toList());
        // 预测Y轴对应数据
        List preY = regression.predict(preXDatas);
        System.out.println(JSONObject.toJSONString(preY));
    }
}

你可能感兴趣的:(JAVA,算法,java,线性回归)