线性回归-非线性模型

原文章链接
看了博主的文章,了解到可以根据线性回归的特性,实现对非线性回归模型的预测;
数据是自定义函数:y= x^2 - 2x + 1 计算得到:
原始数据:
线性回归-非线性模型_第1张图片

如下图,红色是原始数据,黑色是采用线性回归直接预测结果,可见直接线性回归的数据是有问题,均方差也很大。
线性回归-非线性模型_第2张图片

线性回归-非线性模型_第3张图片
于是根据python自带函数,将参数x进行阶数提升:得到常数项1,x,x^2

poly = PolynomialFeatures(degree=2)  # 设置自变量阶数
x_poly = poly.fit_transform(X)  # 使用多项式特征器对x进行转换

结果:
线性回归-非线性模型_第4张图片
再将此X值代入线性模型即可得到非线性的预测结果:

plt.scatter(X, prediction1, color='black', label="linear data")
plt.scatter(X, prediction2, color='b', label="no linear data", alpha=0.5)  # alpha代表透明度(0-1)
plt.legend(loc='upper left')  # 图例
plt.show()

结果:
线性回归-非线性模型_第5张图片
结合原始数据,发现预测结果基本重合:
线性回归-非线性模型_第6张图片
(红色+蓝色变为紫色!)

完整代码实现:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

data = pd.read_excel(r'非线性.xlsx')
# print(data.head(10))

# plt.show()

X = np.array(data.x).reshape(-1, 1)

estimate1 = LinearRegression()
estimate1.fit(X, data.y)
print("预测1参数K:" + str(estimate1.coef_))
print("预测1截距:" + str(estimate1.intercept_))

# 用训练的线性模型预测y值
prediction1 = estimate1.predict(X)
print("预测结果1:")
print(prediction1)

RMSE1 = np.sqrt(mean_squared_error(prediction1, data.y))
print("均方差1:" + str(RMSE1))
# plt.scatter(X, prediction1, color='black')
# plt.show()

from sklearn.preprocessing import PolynomialFeatures

# 设置阶数
poly = PolynomialFeatures(degree=2)  # 设置自变量阶数
x_poly = poly.fit_transform(X)  # 使用多项式特征器对x进行转换
print(x_poly)

estimate2 = LinearRegression()
estimate2.fit(x_poly, data.y)

print("预测2参数K:" + str(estimate2.coef_))
print("预测2截距:" + str(estimate2.intercept_))

prediction2 = estimate2.predict(x_poly)
print("预测结果2:")
print(prediction2)
RMSE2 = np.sqrt(mean_squared_error(prediction2, data.y))
print("均方差2:" + str(RMSE2))

plt.scatter(data.x, data.y, color="r", label="orignal data")  # label是图例名称
plt.scatter(X, prediction1, color='black', label="linear data")
plt.scatter(X, prediction2, color='b', label="no linear data", alpha=0.5)  # alpha代表透明度(0-1)
plt.legend(loc='upper left')  # 图例
plt.show()

你可能感兴趣的:(线性回归,算法,回归)