线性回归2——正则化(含代码)

一、频率角度与贝叶斯角度线性回归2——正则化(含代码)_第1张图片

二、代码

'''
我们用多项式来拟合y=sin4x
'''
import numpy as np
from scipy.optimize import leastsq
import matplotlib.pyplot as plt

# 目标函数
def real_func(x):
    return np.sin(4*x)

# 多项式
def fake_func(w, x):
    f = np.poly1d(w)
    return f(x)

# 误差函数
def residuals_func(w, x, y):
    ret = fake_func(w, x) - y
    return ret

# 加入正则化的误差函数
regularization = 0.0001
def regularization_residuals_fun(w, x, y):
    ret = fake_func(w, x) - y
    ret = np.append(ret,np.sqrt(regularization * np.square(w)))
    return ret

# 为了便于观察,加上噪声的十个点
X = np.linspace(0, 1, 10)
Y = [np.random.normal(0, 0.1) + y1 for y1 in real_func(X)]

x_points = np.linspace(0, 1, 1000)

def fitting(P=0):   # P为多项式的次数
    # 随机初始化多项式参数
    p_init = np.random.rand(P + 1)  # 生成p+1个随机数的列表,这样poly1d函数返回的多项式次数就是p(例如y=ax+b,为1次,初始化a,b两项)
    # 最小二乘法
    p_lsq = leastsq(residuals_func, p_init, args=(X, Y))  # # 三个参数:误差函数、函数参数列表、数据点
    print('多项式的参数:', p_lsq[0])
    regularize_p_lsq = leastsq(regularization_residuals_fun, p_init, args=(X, Y))  # # 三个参数:误差函数、函数参数列表、数据点
    print('正则化多项式的参数:', regularize_p_lsq[0])
    # 可视化
    plt.plot(x_points, real_func(x_points), 'blue',label='real line')  # 真实曲线
    plt.plot(x_points, fake_func(p_lsq[0], x_points), 'orange', label='fake line')  # 拟合曲线
    plt.plot(x_points, fake_func(regularize_p_lsq[0], x_points), 'green', label='regularization line')  # 拟合正则化后的曲线
    plt.plot(X, Y, 'ro', label='noise')  # 十个噪点分布
    plt.legend()
    plt.show()
    return p_lsq

end = fitting(P=9)  # 九次多项式

线性回归2——正则化(含代码)_第2张图片
线性回归2——正则化(含代码)_第3张图片

你可能感兴趣的:(机器学习推导,正则化,机器学习)