线性回归适用于数据成线性分布的回归问题,如果样本是非线性分布,线性回归就不再使用,转而可以采用非线性模型进行回归,比如多项式回归
与线性模型,多项式模型引入了高次项:
y = w 0 + w 1 x + w 2 x 2 + w 3 x 3 + . . . + w n x n y = w_0 + w_1x + w_2x^2 + w_3x^3 + ... + w_nx^n y=w0+w1x+w2x2+w3x3+...+wnxn
即:
y = ∑ i = 1 n w i x i y = \sum_{i=1}^{n}w_ix^i y=i=1∑nwixi
多项式回归模型可以理解为线性回归的扩展,即在线性回归模型中添加了新的特征从而增加了模型的表达能力。比如:有 x 1 , x 2 , x 3 x_1,x_2,x_3 x1,x2,x3三个特征,分别表示房子的长,宽,高,则房屋价格可以使用线性回归模型表示为:
y = w 0 + w 1 x 1 + w 2 x 2 + w 3 x 3 y = w_0 + w_1x_1 + w_2x_2 + w_3x_3 y=w0+w1x1+w2x2+w3x3
对于房屋价格同样也可以使用房屋的体积表示,即使用多项式回归模型表示为:
y = w 0 + w 1 x + w 2 x 2 + w 3 x 3 y = w_0 + w_1x + w_2x^2 + w_3x^3 y=w0+w1x+w2x2+w3x3
因此,n元一次线性模型和一元n次多项式模型一定程度上可以相互转换。
对于多项式回归模型同样使用梯度下降对损失函数进行优化,寻找到最优的一组参数 w 0 , w 1 , w 2 , . . , w n w_0, w_1, w_2, .. , w_n w0,w1,w2,..,wn就可以将一元n次多项式转换为n元一次多项式进而求线性回归。
import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as sm
import matplotlib.pyplot as plt
import sklearn.pipeline as pl # 管线模块
import sklearn.preprocessing as sp
train_x, train_y = [], []
with open('D:\python\data\poly_sample.txt', 'r') as f:
for line in f.readlines():
data = [float(substr) for substr in line.split(',')]
train_x.append(data[:-1])
train_y.append(data[-1])
# 将数据转换为numpy数组格式
train_x = np.array(train_x)
train_y = np.array(train_y)
# 链接两个模型(可以看出,实现此多项式回归正是基于线性回归模型扩展得到)
model = pl.make_pipeline(sp.PolynomialFeatures(3),
lm.LinearRegression())
# 利用数据训练多项式回归器
model.fit(train_x, train_y)
# 根据训练模型预测输出
pred_y = model.predict(train_x)
# 评估模型(使用R2系数)
score_r2 = sm.r2_score(train_y, pred_y)
print('score_r2: %f'%score_r2) # score_r2: 0.903666
# 测试模型
test_x = np.linspace(train_x.min(), train_x.max(), 1000)
pred_test_y = model.predict(test_x.reshape(-1,1))
print('--------可视化--------')
plt.figure('Polynomial Regression', facecolor='gray')
plt.title('Polynomial Regression', fontsize=18)
plt.xlabel('x', fontsize=18)
plt.ylabel('y', fontsize=18)
plt.tick_params(labelsize=10)
plt.grid(linestyle=':')
plt.scatter(train_x, train_y, c='red', alpha=0.8, s=60, label='Sample')
plt.plot(test_x, pred_test_y, c='blue', label='Regression')
plt.legend()
plt.show()