多项式回归(Polynomial Regression)(附代码)

如果一个方程,自变量的指数大于1,那么所有拟合这个方程的点就符合多项式回归。

figure 7

多项式回归有个很重要的因素就是指数(degree)。如果我们发现数据的分布大致是一条曲线,那么很可能符合多项式回归,但是我们不知道degree是多少。所以我们只能一个个去试,直到找到最拟合分布的degree。这个过程我们可以交给数据科学软件完成。需要注意的是,如果degree选择过大的话可能会导致函数过于拟合, 意味着对数据或者函数未来的发展很难预测,也许指向不同的方向。

这个回归的计算需要用到矩阵数据结构。有的编程语言可能需要导入外库。

figure 8

我们对所有拟合这个公式的点,用矩阵表示他们的关系

figure 9

如果用矩阵符号表示:

figure 10

多项式回归向量的系数(使用最小二乘法):

figure 11

Java 和 Python 代码如下:

//需要安装jama包,这里是下载地址: http://math.nist.gov/javanumerics/jama/
import Jama.Matrix;
import Jama.QRDecomposition;

public class PR {

    private final int N;
    private final int degree;
    private final Matrix beta;
    private double SSE;
    private double SST;

    public PR(double[] x, double[] y, int degree) {
        this.degree = degree;
        N = x.length;

        // build Vandermonde matrix
        double[][] vandermonde = new double[N][degree+1];
        for (int i = 0; i < N; i++) {
            for (int j = 0; j <= degree; j++) {
                vandermonde[i][j] = Math.pow(x[i], j);
            }
        }
        Matrix X = new Matrix(vandermonde);

        // 从向量中增加一个矩阵
        Matrix Y = new Matrix(y, N);

        // 找到最小的平方值
        QRDecomposition qr = new QRDecomposition(X);
        beta = qr.solve(Y);


        // 得到y的平均值
        double sum = 0.0;
        for (int i = 0; i < N; i++)
            sum += y[i];
        double mean = sum / N;

        // total variation to be accounted for
        for (int i = 0; i < N; i++) {
            double dev = y[i] - mean;
            SST += dev*dev;
        }

        // variation not accounted for
        Matrix residuals = X.times(beta).minus(Y);
        SSE = residuals.norm2() * residuals.norm2();
    }

    public double beta(int j) {
        return beta.get(j, 0);
    }

    public int degreee() {
        return degree;
    }

    public double R2() {
        return 1.0 - SSE/SST;
    }

    public double predict(double x) {

        double y = 0.0;
        for (int j = degree; j>=0; j--) {
            y = beta(j) + (x*y);
        }
        return y;
    }

    public String toString() {
        String s = "";
        int j = degree;

        // 忽略系数为0.
        while (Math.abs(beta(j)) < 1E-5)
            j--;

        // create remaining terms
        for (j = j; j >= 0; j--) {
            if      (j == 0) s += String.format("%.2f ", beta(j));
            else if (j == 1) s += String.format("%.2f N + ", beta(j));
            else             s += String.format("%.2f N^%d + ", beta(j), j);
        }
        return s + "  (R^2 = " + String.format("%.3f", R2()) + ")";
    }

}

ref:
Java代码使用了《算法》中的代码,可以在普林斯顿的算法课上下载:Polynomial Regression

你可能感兴趣的:(statistics)