如果您的数据点显然不适合线性回归(穿过数据点之间的直线),那么多项式回归可能是理想的选择。
像线性回归一样,多项式回归使用变量 x 和 y 之间的关系来找到绘制数据点线的最佳方法。
Python 有一些方法可以找到数据点之间的关系并画出多项式回归线。我将向您展示如何使用这些方法而不是通过数学公式。
在下面的例子中,我们注册了 18 辆经过特定收费站的汽车。
我们已经记录了汽车的速度和通过时间(小时)。
x 轴表示一天中的小时,y 轴表示速度:
import matplotlib.pyplot as plt
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
plt.scatter(x, y)
plt.show()
显然使用线性回归是不可能的
导入 numpy 和 matplotlib,然后画出多项式回归线:
import numpy
import matplotlib.pyplot as plt
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
myline = numpy.linspace(1, 22, 100)
plt.scatter(x, y)
plt.plot(myline, mymodel(myline))
plt.show()
polyfit 多项式曲线拟合 资料:numpy-poly1d、polyfit、polyval多项式使用
关于这两个函数的用法由于笔者功底暂时有限无法详细解释,各位读者可自行百度.
结果:
可以看出拟合效果非常好
导入所需模块:
import numpy
import matplotlib.pyplot as plt
创建表示 x 和 y 轴值的数组:
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
NumPy 有一种方法可以让我们建立多项式模型:
mymodel = numpy.poly1d(numpy.polyfit(x, y, 3))
然后指定行的显示方式,我们从位置 1 开始,到位置 22 结束:
myline = numpy.linspace(1, 22, 100)
绘制原始散点图:
plt.scatter(x, y)
画出多项式回归线:
plt.plot(myline, mymodel(myline))
显示图表:
plt.show()
定义:衡量模型拟合度的一个量,是一个比例式,比例区间为[0,1],越接近1,表示模型拟合度越高
R 2 = 1 − ( 观 测 值 − 预 测 值 ) 2 ( 观 测 值 − 观 测 值 全 体 的 平 均 ) 2 R^2=1-\frac{(观测值-预测值)^2}{(观测值-观测值全体的平均)^2} R2=1−(观测值−观测值全体的平均)2(观测值−预测值)2
可以使用sklearn实现
R2=sklearn.linear_model.score(x,y)
print(R2)
值得注意的是R2越接近于1说明拟合效果越好
注意每种方法使用起来都有一些区别(有些方法使用范围有限,例如numpy无法用于多元回归),这需要通过实践来熟悉
使用numpy库中的poly1d和polyfit函数进行回归预测
资料:numpy-poly1d、polyfit、polyval多项式使用
import matplotlib.pyplot as plt
from scipy import stats
import matplotlib
import numpy as np
matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 用黑体显示中文
x = [1,2,3,5,6,7,8,9,10,12,13,14,15,16,18,19,21,22]
y = [100,90,80,60,60,55,60,65,70,70,75,76,78,79,90,99,99,100]
slope, intercept, r, p, std_err = stats.linregress(x, y)
myline = np.linspace(1, 30, 100)
for i in [3]:
mymodel = np.poly1d(np.polyfit(x, y, i))
plt.xlim(0,25)
plt.ylim(50,110)
plt.plot(x, mymodel(x),label="{0}".format(i))
plt.scatter(x, y,color="r")
plt.legend()
plt.grid()
plt.show()
同样是上面的数据我们使用sklearn来看看
代码: