利用Python实现多项式的曲线拟合。
假设训练集由x的N次观测 x 1 , x 2 , . . . , x n x_1,x_2,...,x_n x1,x2,...,xn得到,x均匀分布于区间[0,1]。对应的观测集为 t 1 , t 2 , . . . , t n t_1,t_2,...,t_n t1,t2,...,tn,目标函数为 s i n ( 2 π x ) sin(2πx) sin(2πx)。
所以,为了通过训练集和观测集拟合出预测函数,使其尽可能接近目标函数,我们通过训练集加上随机高斯噪声输入到目标函数得到。
首先,图一中分别绘制了 t = s i n ( 2 π x ) t = sin(2πx) t=sin(2πx) 标准曲线(如绿线所示)和添加了噪声的观测集(样本包含10个点,如蓝点所示)。
import numpy as np
import matplotlib.pyplot as plt
#标准曲线
x = np.linspace(0, 1, 100)
t = np.sin(2 * np.pi * x)
#采样函数
def get_data(N):
x_n = np.linspace(0,1,N)
t_n = np.sin(2 * np.pi * x_n) + np.random.normal(scale=0.15, size=N) #add Gaussian Noise
return x_n, t_n
#绘制部分组件函数
def draw_ticks():
plt.tick_params(labelsize=15)
plt.xticks(np.linspace(0, 1, 2))
plt.yticks(np.linspace(-1, 1, 3))
plt.ylim(-1.5, 1.5)
font = {'family':'Times New Roman','size':20}
plt.xlabel('x', font)
plt.ylabel('t',font, rotation='horizontal')
#采样
x_10, t_10 = get_data(10)
#图像绘制部分
plt.figure(1, figsize=(8,5))
plt.plot(x, t, 'g',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3, label="training data")
draw_ticks()
plt.title('Figure 1 : sample curve', font)
plt.savefig('1.png', dpi=400)
绿色的曲线为要拟合的目标函数。然后,使用多项式函数来拟合生成的数据。多项式定义如下:
y ( x , w ) = w 0 + w 1 x + w 2 x 2 + . . . + w M x M = ∑ j = 1 M w j x j y(x,w)=w_0 +w_1x+w_2x^2+...+w_Mx^M=\sum\limits_{j=1}^{M}w_jx^j y(x,w)=w0+w1x+w2x2+...+wMxM=j=1∑Mwjxj
M是多项式的阶数,ω0,…,ωM 是多项式的系数,记为W。然后使用均方误差作为误差函数对拟合出的多项式进行评估,公式如下:
E ( W ) = 1 2 ∑ n = 1 N ( y ( x n , W ) − t n ) 2 = 1 2 ( X W − T ) T ( X W − T ) E(W)=\frac{1}{2}\sum\limits_{n=1}^{N}(y(x_n,W)-t_n)^2=\frac{1}{2}(XW-T)^T(XW-T) E(W)=21n=1∑N(y(xn,W)−tn)2=21(XW−T)T(XW−T)
表示为矩阵形式:
W = [ w 0 w 1 ⋮ w m ] , X = [ 1 x 1 ⋯ x 1 m 1 x 2 ⋯ x 2 m ⋮ ⋮ ⋱ ⋮ 1 x n ⋯ x n m ] W= \left[ \begin{matrix} w_0 \\ w_1 \\ \vdots\\ w_m \end{matrix} \right],X= \left[ \begin{matrix} 1 & x_1 & \cdots & x_1^m \\ 1 & x_2 & \cdots & x_2^m \\ \vdots & \vdots & \ddots & \vdots \\ 1 & x_n & \cdots & x_n^m \\ \end{matrix} \right] W=⎣⎢⎢⎢⎡w0w1⋮wm⎦⎥⎥⎥⎤,X=⎣⎢⎢⎢⎡11⋮1x1x2⋮xn⋯⋯⋱⋯x1mx2m⋮xnm⎦⎥⎥⎥⎤
拟合数据的目的即为最小化误差函数,因为误差函数是多项式系数W的二次函数,所以存在唯一最小值,且在导数为零处取得。对W求导并令导数为零得到:
∂ E ( W ) ∂ W = X T X W − X T T \frac{\partial E(W)}{\partial W}=X^TXW-X^TT ∂W∂E(W)=XTXW−XTT
W = ( X T X ) − 1 X T T W = (X^TX)^{-1}X^TT W=(XTX)−1XTT
故可以通过矩阵运算得到W。
#拟合函数(lamda默认为0,即无正则项)
def regress(M, N, x, x_n, t_n, lamda=0):
print("-----------------------M=%d, N=%d-------------------------" %(M,N))
order = np.arange(M+1)
order = order[:, np.newaxis]
e = np.tile(order, [1,N])
XT = np.power(x_n, e)
X = np.transpose(XT)
a = np.matmul(XT, X) + lamda*np.identity(M+1) #X.T * X
b = np.matmul(XT, t_n) #X.T * T
w = np.linalg.solve(a,b) #aW = b => (X.T * X) * W = X.T * T
print("W:")
print(w)
e2 = np.tile(order, [1,x.shape[0]])
XT2 = np.power(x, e2)
p = np.matmul(w, XT2)
return p
分别选择 M = 0, 1, 3, 9 不同多项式阶数对数据进行拟合。图中红线为拟合结果。
#M=0, N=10
p = regress(0, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(2, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 2 : M = 0, N = 10', font)
plt.text(0.8, 0.9,'M = 0', font, style = 'italic')
plt.savefig('2.png', dpi=400)
#M=1, N=10
p = regress(1, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(3, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 3 : M = 1, N = 10', font)
plt.text(0.8, 0.9,'M = 1', font, style = 'italic')
plt.savefig('3.png', dpi=400)
#M=3, N=10
p = regress(3, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(4, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.title('Figure 4 : M = 3, N = 10', font)
plt.text(0.8, 0.9,'M = 3', font, style = 'italic')
plt.savefig('4.png', dpi=400)
#M=9, N=10
p = regress(9, 10, x, x_10, t_10)
#图像绘制部分
plt.figure(5, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,'M = 9', font, style = 'italic')
plt.title('Figure 5 : M = 9, N = 10', font)
plt.savefig('5.png', dpi=400)
当模型复杂度确定时,考虑利用更多的观测点(15个和100个)对9阶多项式进行拟合。
M=9
N=15
x_15, t_15 = get_data(N)
p = regress(M, N, x, x_15, t_15)
#图像绘制部分
plt.figure(6, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_15, t_15, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.65,'N = 15', font, style = 'italic')
plt.title('Figure 6 : M = 9, N = 15', font)
plt.savefig('6.png', dpi=400)
M=9
N=100
x_100, t_100 = get_data(N)
p = regress(M, N, x, x_100, t_100)
#图像绘制部分
plt.figure(7, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_100, t_100, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.65,'N = 100', font, style = 'italic')
plt.title('Figure 7 : M = 9, N = 100', font)
plt.savefig('7.png', dpi=400)
可以看到,数据规模的增加能够有效的减轻模型的过拟合问题。但是实际应用中可能无法获得足够数据量。
除了增加数据量来减轻过拟合的影响,还可以通过正则化方法。在定义误差函数时增加惩罚项,使多项式系数被有效控制,不会过大。
误差函数变为如下形式:
E ~ ( w ) = 1 2 ∑ n = 1 N { y ( x n , w ) − t n } 2 + λ 2 ∣ ∣ w ∣ ∣ 2 \widetilde{E}(w)=\frac{1}{2}\sum\limits_{n=1}^{N}\{y(x_n,w)-t_n\}^2+\frac{\lambda}{2}||w||^2 E (w)=21n=1∑N{y(xn,w)−tn}2+2λ∣∣w∣∣2
求导置零得到:
W = ( X T X + λ E m + 1 ) − 1 X T T W = (X^T X + λE_{m+1})^{-1}X^TT W=(XTX+λEm+1)−1XTT
然后,我们进行当多项式阶数 M = 9 M = 9 M=9 时,有 N = 10 N = 10 N=10 个采样点的情况下,λ较小和较大时(如 l n λ = − 18 lnλ = -18 lnλ=−18 和 l n λ = 0 lnλ = 0 lnλ=0 ) 时对过拟合现象的实验。
M=9
N=10
x_10, t_10 = get_data(N)
#lnλ = 0
p = regress(M, N, x, x_10, t_10, np.exp(0))
#图像绘制部分
plt.figure(8, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,' lnλ = 0', font, style = 'italic')
plt.title('Figure 8 : M = 9, N = 10, lnλ = 0', font)
plt.savefig('8.png', dpi=400)
#lnλ = -18
p = regress(M, N, x, x_10, t_10, np.exp(-18))
#图像绘制部分
plt.figure(9, figsize=(8,5))
plt.plot(x, t, 'g', x, p, 'r',linewidth=3)
plt.scatter(x_10, t_10, color='', marker='o', edgecolors='b', s=100, linewidth=3)
draw_ticks()
plt.text(0.8, 0.9,' lnλ = -18', font, style = 'italic')
plt.title('Figure 9 : M = 9, N = 10, lnλ = -18', font)
plt.savefig('9.png', dpi=400)
结果显示,加上了正则项后,λ 较小时有效地改善了高阶多项式的过拟合现象,但是当 λ 过大时会过度抑制模型系数。所以,根据模型的复杂度来进行合适的正则化对于拟合结果非常重要。
ps:第一次用Markdown写博,还挺酷的哈哈哈~