回归树几乎所有参数、属性、接口都和分类树一样。需要注意的是在回归树中没有标签分布是否均衡的问题,因此回归树没有class_weight参数
交叉验证是用来观察模型的稳定性的一种方法,我们将数据划分为n份,依次使用其中一份作为测试集,其他n-1份作为训练集,多次计算模型的精确性来评估模型的平均准确程度。训练集和测试集的划分会干扰模型的结果,因此用交叉验证n次的结果求出的平均值,是对模型效果的一个更好的度量。
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
# 利用boston房价数据集来测试交叉检验
# 对于分类模型,交叉验证返回的是模型的精确度
# 对于回归模型,交叉验证默认返回的是R^2
# 可以改变scoring参数改变返回数据
# 一般使用MSE(均方误差)
boston = load_boston()
# 定义模型
dtr = DecisionTreeRegressor()
# 交叉检验会自动分割数据据,所以不需要人为对数据据进行划分
# cv 表示划分数据集的份数,推荐5或10
r2_cv = cross_val_score(dtr, boston['data'], boston['target'], cv=10)
# r2 越接近1越好
neg_MSE_cv = cross_val_score(dtr, boston['data'], boston['target'], cv=10, scoring='neg_mean_squared_error')
# 返回负的均方误差,均方误差越接近0越好
print(r2_cv)
print(neg_MSE_cv)
通过numpy随机生成80个数据,并计算正弦值,将正弦值作为target,x值为features。
给生成的target添加一些噪声。
训练模型,分别限制最大深度为3 和5,再用训练数据来测试模型
对numpy不了解的可以参考numpy笔记
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 确保每次随机出来的数据是一样的,方便调试
np.random.seed(1)
# 随机生成0-10之间的50个数字
x = np.random.uniform(0, 10, size=(80, 1))
# 将数据排序,方便后面画图操作
x.sort(axis=0)
# 计算正弦值
y = np.sin(x).reshape(80)
# 画出散点图看数据分布
plt.scatter(x, y)
# 给数据添加噪声
# 每5个点将其数值变为1-2*n,n为0-1之间的随机数
y[::5] = 1 - 2 * (np.random.rand(y[::5].shape[0]))
plt.scatter(x, y)
# 训练模型,最大深度为三
dcr_3 = DecisionTreeRegressor(max_depth=3)
dcr_3.fit(x, y)
# 最大深度为5
dcr_5 = DecisionTreeRegressor(max_depth=5)
dcr_5.fit(x, y)
# 用训练集来检测模型
plt.figure(figsize=(10, 8))
plt.scatter(x, y, label='row data')
y3_predict = dcr_3.predict(x)
plt.plot(x, y3_predict, color='red', label='3_depth')
y5_predict = dcr_5.predict(x)
plt.plot(x, y5_predict, color='green', label='5_depth')
plt.legend(loc='best', fontsize=20)
# 很明显的看出,最大深度为3 的决策树,的拟合程度是基本符合原始数据分布的
# 最大深度为5的决策树,对原始数据拟合很好,过多的考虑了数据中的噪声点
# 这种情况就是过拟合了