统计学习方法 p12 多项式拟合 python实现

看李航老师《统计学习方法》这本书,第12页举了一个多项式拟合的问题,自己怎么都推导不出来,上网查发现书上有误。
拟合问题描述如下:

统计学习方法 p12 多项式拟合 python实现_第1张图片
统计学习方法 p12 多项式拟合 python实现_第2张图片

书上的推导就不贴了。
正确推导在知乎上有具体解答:

统计学习方法 p12 多项式拟合 python实现_第3张图片

W即通过求解方程得到

上面方程矩阵展开如下:
统计学习方法 p12 多项式拟合 python实现_第4张图片

按照这个思路的python代码实现如下:

import numpy as np
import matplotlib.pyplot as plt

# 原始曲线
x_plot = np.linspace(-0.05, 1.05, 100)
y_plot = 6 * x_plot ** 3 - 5 * x_plot ** 2 + 2
# 样本点
x_train = np.linspace(0, 1, 10)
y_train = 6 * x_train ** 3 - 5 * x_train ** 2 + 2
# 加入噪声
x_train_noise = x_train * (1 + (np.random.random(10) - 0.5) * 0.2)
y_train_noise = y_train * (1 + (np.random.random(10) - 0.5) * 0.2)

# 拟合函数
def curve_fitting(order, x_train_noise, y_train_noise):
    # 求各阶xi的值 
    x_element = np.ones((2*order+1, len(x_train_noise)))
    for i in range(2*order):
        x_element[i+1,:] = x_element[i, :] * x_train_noise
    # 求各阶xi值的和
    x_element_sum = x_element.sum(axis=1)
    # 构建XTX矩阵
    x_matrix = np.ones((order+1, order+1))
    for i in range(x_matrix.shape[0]):
        x_matrix[i, :] = np.asarray(x_element_sum[i:i+order+1])
    
    # 构建y矩阵
    y_matrix = x_element[:order+1, :] @ np.asarray(y_train_noise)
    
    # 求W矩阵
    w_matrix = np.linalg.solve(x_matrix, y_matrix)
    
    # 求拟合曲线上的坐标点
    x_split = np.linspace(-0.02, 1.0, 100)
    y_split = []
    for x in x_split:
        x_acum = 1
        y = 0
        for i in range(order+1):
            y += x_acum * w_matrix[i]
            x_acum *= x
        y_split.append(y)
        
    return x_split, y_split


x_order_3, y_order_3 = curve_fitting(3, x_train_noise, y_train_noise)
x_order_9, y_order_9 = curve_fitting(9, x_train_noise, y_train_noise)

# 绘图
plt.plot(x_plot, y_plot, label='original')
plt.scatter(x_train_noise, y_train_noise)
plt.plot(x_order_3, y_order_3, label='order=3')
plt.plot(x_order_9, y_order_9, label='order=9')
plt.legend()

绘图结果如下:
统计学习方法 p12 多项式拟合 python实现_第5张图片

吴恩达老师的机器学习的视频中也举了这么个例子,看视频时有一丝怀疑高阶多项式拟合真的像老师手画的那样么,通过python实现确认就是那样的!

参考:

  1. 李航. (2012). 统计学习方法. 清华大学出版社. 北京
  2. https://www.zhihu.com/question/23483726
  3. https://blog.csdn.net/xiaolewennofollow/article/details/46757657

你可能感兴趣的:(统计学习方法 p12 多项式拟合 python实现)