class sklearn.tree.DecisionTreeClassifier(*, criterion=‘gini’, splitter=‘best’, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None, presort=‘deprecated’, ccp_alpha=0.0)
scikit-learn 使用 CART 算法的优化版本
注意:当使用entropy时,sklearn实际计算的是基于信息熵的信息增益(information gain),即父节点的信息熵和子节点的信息熵之差。比起基尼系数,信息熵对不纯度更加敏感,对不纯度的惩罚最强。但在实际使用中,信息熵和基尼系数的效果基本相同。信息熵的计算比基尼系数缓慢一些,因为基尼系数的计算不涉及对数。另外,由于信息熵对不纯度更加敏感,所以信息熵作为指标时,决策树的生长会更加精细,因此对于高维数据或者噪音数据很多的数据,信息熵容易过拟合,基尼系数在这种情况下往往效果比较好。当模型欠拟合的时候,可考虑使用信息熵。
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split #训练集和测试集的切分函数
#导入数据集并探索数据
wine = load_wine() #实例化数据集
wine
{'data': array([[1.423e+01, 1.710e+00, 2.430e+00, ..., 1.040e+00, 3.920e+00,
1.065e+03],
[1.320e+01, 1.780e+00, 2.140e+00, ..., 1.050e+00, 3.400e+00,
1.050e+03],
[1.316e+01, 2.360e+00, 2.670e+00, ..., 1.030e+00, 3.170e+00,
1.185e+03],
...,
[1.327e+01, 4.280e+00, 2.260e+00, ..., 5.900e-01, 1.560e+00,
8.350e+02],
[1.317e+01, 2.590e+00, 2.370e+00, ..., 6.000e-01, 1.620e+00,
8.400e+02],
[1.413e+01, 4.100e+00, 2.740e+00, ..., 6.100e-01, 1.600e+00,
5.600e+02]]),
'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2]),
'target_names': array(['class_0', 'class_1', 'class_2'], dtype='
wine.data.shape
(178, 13)
wine.target
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2])
pd.concat([pd.DataFrame(wine.data), pd.DataFrame(wine.target)],axis=1).head()
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 0 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.80 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065.0 | 0 |
1 | 13.20 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.40 | 1050.0 | 0 |
2 | 13.16 | 2.36 | 2.67 | 18.6 | 101.0 | 2.80 | 3.24 | 0.30 | 2.81 | 5.68 | 1.03 | 3.17 | 1185.0 | 0 |
3 | 14.37 | 1.95 | 2.50 | 16.8 | 113.0 | 3.85 | 3.49 | 0.24 | 2.18 | 7.80 | 0.86 | 3.45 | 1480.0 | 0 |
4 | 13.24 | 2.59 | 2.87 | 21.0 | 118.0 | 2.80 | 2.69 | 0.39 | 1.82 | 4.32 | 1.04 | 2.93 | 735.0 | 0 |
wine.feature_names
['alcohol',
'malic_acid',
'ash',
'alcalinity_of_ash',
'magnesium',
'total_phenols',
'flavanoids',
'nonflavanoid_phenols',
'proanthocyanins',
'color_intensity',
'hue',
'od280/od315_of_diluted_wines',
'proline']
wine.target_names
array(['class_0', 'class_1', 'class_2'], dtype='
#切分训练集和测试集
xtrain, xtest, ytrain, ytest = train_test_split(wine.data, wine.target, test_size=0.3)
xtrain.shape
(124, 13)
xtest.shape
(54, 13)
#模型训练
clf = tree.DecisionTreeClassifier(criterion='entropy')
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest) #返回预测的准确度
score
0.8703703703703703
# 绘制决策树
tree.plot_tree(clf.fit(xtrain, ytrain),feature_names=wine.feature_names, class_names=wine.target_names )
[Text(167.4, 195.696, 'proline <= 953.5\nentropy = 1.571\nsamples = 124\nvalue = [40, 49, 35]\nclass = class_1'),
Text(111.60000000000001, 152.208, 'color_intensity <= 3.825\nentropy = 1.202\nsamples = 88\nvalue = [4, 49, 35]\nclass = class_1'),
Text(55.800000000000004, 108.72, 'entropy = 0.0\nsamples = 44\nvalue = [0, 44, 0]\nclass = class_1'),
Text(167.4, 108.72, 'flavanoids <= 1.4\nentropy = 0.934\nsamples = 44\nvalue = [4, 5, 35]\nclass = class_2'),
Text(111.60000000000001, 65.232, 'entropy = 0.0\nsamples = 35\nvalue = [0, 0, 35]\nclass = class_2'),
Text(223.20000000000002, 65.232, 'proline <= 679.0\nentropy = 0.991\nsamples = 9\nvalue = [4, 5, 0]\nclass = class_1'),
Text(167.4, 21.744, 'entropy = 0.0\nsamples = 5\nvalue = [0, 5, 0]\nclass = class_1'),
Text(279.0, 21.744, 'entropy = 0.0\nsamples = 4\nvalue = [4, 0, 0]\nclass = class_0'),
Text(223.20000000000002, 152.208, 'entropy = 0.0\nsamples = 36\nvalue = [36, 0, 0]\nclass = class_0')]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-msI2sR3a-1589803106915)(output_16_1.png)]
#特征重要性
clf.feature_importances_
array([0. , 0. , 0. , 0. , 0. ,
0. , 0.1650984 , 0. , 0. , 0.33215949,
0. , 0. , 0.50274211])
clf = tree.DecisionTreeClassifier(criterion='gini', random_state=42)
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest)
score
0.8888888888888888
用于在每个节点上选择拆分的策略。 best在特征的所有划分点中找出最优的划分点。random是随机的在部分划分点中找局部最优的划分点。默认的"best"适合样本量不大的时候,而如果样本数据量非常大,此时决策树构建推荐"random"。
clf = tree.DecisionTreeClassifier(criterion='entropy', splitter='random', random_state=42)
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest)
score
0.8888888888888888
tree.plot_tree(clf.fit(xtrain, ytrain),feature_names=wine.feature_names, class_names=wine.target_names )
[Text(139.5, 199.32, 'od280/od315_of_diluted_wines <= 2.123\nentropy = 1.571\nsamples = 124\nvalue = [40, 49, 35]\nclass = class_1'),
Text(63.77142857142857, 163.07999999999998, 'hue <= 0.936\nentropy = 0.669\nsamples = 40\nvalue = [0, 7, 33]\nclass = class_2'),
Text(31.885714285714286, 126.83999999999999, 'od280/od315_of_diluted_wines <= 1.785\nentropy = 0.323\nsamples = 34\nvalue = [0, 2, 32]\nclass = class_2'),
Text(15.942857142857143, 90.6, 'entropy = 0.0\nsamples = 24\nvalue = [0, 0, 24]\nclass = class_2'),
Text(47.82857142857143, 90.6, 'flavanoids <= 0.872\nentropy = 0.722\nsamples = 10\nvalue = [0, 2, 8]\nclass = class_2'),
Text(31.885714285714286, 54.359999999999985, 'entropy = 0.0\nsamples = 7\nvalue = [0, 0, 7]\nclass = class_2'),
Text(63.77142857142857, 54.359999999999985, 'ash <= 2.617\nentropy = 0.918\nsamples = 3\nvalue = [0, 2, 1]\nclass = class_1'),
Text(47.82857142857143, 18.119999999999976, 'entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]\nclass = class_1'),
Text(79.71428571428572, 18.119999999999976, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = class_2'),
Text(95.65714285714286, 126.83999999999999, 'alcohol <= 13.386\nentropy = 0.65\nsamples = 6\nvalue = [0, 5, 1]\nclass = class_1'),
Text(79.71428571428572, 90.6, 'entropy = 0.0\nsamples = 5\nvalue = [0, 5, 0]\nclass = class_1'),
Text(111.6, 90.6, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = class_2'),
Text(215.22857142857143, 163.07999999999998, 'alcohol <= 13.013\nentropy = 1.138\nsamples = 84\nvalue = [40, 42, 2]\nclass = class_1'),
Text(159.42857142857144, 126.83999999999999, 'flavanoids <= 0.622\nentropy = 0.324\nsamples = 42\nvalue = [1, 40, 1]\nclass = class_1'),
Text(143.4857142857143, 90.6, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = class_2'),
Text(175.37142857142857, 90.6, 'od280/od315_of_diluted_wines <= 3.584\nentropy = 0.165\nsamples = 41\nvalue = [1, 40, 0]\nclass = class_1'),
Text(159.42857142857144, 54.359999999999985, 'entropy = 0.0\nsamples = 39\nvalue = [0, 39, 0]\nclass = class_1'),
Text(191.31428571428572, 54.359999999999985, 'nonflavanoid_phenols <= 0.233\nentropy = 1.0\nsamples = 2\nvalue = [1, 1, 0]\nclass = class_0'),
Text(175.37142857142857, 18.119999999999976, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = class_1'),
Text(207.25714285714287, 18.119999999999976, 'entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]\nclass = class_0'),
Text(271.0285714285714, 126.83999999999999, 'hue <= 0.772\nentropy = 0.437\nsamples = 42\nvalue = [39, 2, 1]\nclass = class_0'),
Text(239.14285714285714, 90.6, 'total_phenols <= 1.633\nentropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]\nclass = class_1'),
Text(223.2, 54.359999999999985, 'entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]\nclass = class_2'),
Text(255.0857142857143, 54.359999999999985, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = class_1'),
Text(302.9142857142857, 90.6, 'hue <= 1.308\nentropy = 0.169\nsamples = 40\nvalue = [39, 1, 0]\nclass = class_0'),
Text(286.9714285714286, 54.359999999999985, 'entropy = 0.0\nsamples = 39\nvalue = [39, 0, 0]\nclass = class_0'),
Text(318.8571428571429, 54.359999999999985, 'entropy = 0.0\nsamples = 1\nvalue = [0, 1, 0]\nclass = class_1')]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-J6EoZ066-1589803106922)(output_22_1.png)]
score = clf.score(xtrain, ytrain)
score
1.0
在不加限制的情况下,决策树很容易出现过拟合的现象。为解决这一问题,往往需要对决策树进行剪枝。正确的剪枝策略是优化决策树算法的核心。在sklearn中有以下几个参数可对决策树进行剪枝:
max_depth
限制树的最大深度,超过设定深度的树枝全部剪掉。默认为None,则结点将一直生长,直到所有叶子都是纯的,或者直到所有叶子都包含少于min_samples_split个样本。这是用得最广泛的剪枝参数,在高纬度低样本量时非常有效。一般可从max_depth=3开始尝试。
min_samples_leaf
该参数限定了一个结点在分枝后的子节点中都必须包含至少min_samples_leaf个训练样本,否则分枝就不会发生。与max_depth一起搭配使用,可以让模型变得更加平滑。这个参数的数量设置得太小容易引起过拟合,太大则会阻止模型学习数据。一般可从min_samples_leaf=5开始尝试。如果叶子结点中的样本量变化很大,可输入浮点数作为样本量的百分比来使用。同时,该参数可以保证每个叶子的最小尺寸,可以咋回归问题中避免低方差,过拟合的叶子结点出现。对于类别不多的分类问题, 1通常就是最佳选择。
min_samples_split
限定一个结点必须要包含至少min_samples_split个训练样本,这个结点才允许被分枝,否则分枝就不会发生。
clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=42, splitter='random', max_depth=3, min_samples_leaf=10, min_samples_split=10)
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest)
score
0.8703703703703703
tree.plot_tree(clf.fit(xtrain, ytrain))
[Text(153.45000000000002, 190.26, 'X[11] <= 2.123\nentropy = 1.571\nsamples = 124\nvalue = [40, 49, 35]'),
Text(83.7, 135.9, 'X[1] <= 3.601\nentropy = 0.669\nsamples = 40\nvalue = [0, 7, 33]'),
Text(55.800000000000004, 81.53999999999999, 'X[9] <= 6.71\nentropy = 0.84\nsamples = 26\nvalue = [0, 7, 19]'),
Text(27.900000000000002, 27.180000000000007, 'entropy = 0.997\nsamples = 15\nvalue = [0, 7, 8]'),
Text(83.7, 27.180000000000007, 'entropy = 0.0\nsamples = 11\nvalue = [0, 0, 11]'),
Text(111.60000000000001, 81.53999999999999, 'entropy = 0.0\nsamples = 14\nvalue = [0, 0, 14]'),
Text(223.20000000000002, 135.9, 'X[0] <= 13.155\nentropy = 1.138\nsamples = 84\nvalue = [40, 42, 2]'),
Text(167.4, 81.53999999999999, 'X[5] <= 1.774\nentropy = 0.574\nsamples = 46\nvalue = [4, 41, 1]'),
Text(139.5, 27.180000000000007, 'entropy = 0.469\nsamples = 10\nvalue = [0, 9, 1]'),
Text(195.3, 27.180000000000007, 'entropy = 0.503\nsamples = 36\nvalue = [4, 32, 0]'),
Text(279.0, 81.53999999999999, 'X[7] <= 0.306\nentropy = 0.35\nsamples = 38\nvalue = [36, 1, 1]'),
Text(251.10000000000002, 27.180000000000007, 'entropy = 0.235\nsamples = 26\nvalue = [25, 1, 0]'),
Text(306.90000000000003, 27.180000000000007, 'entropy = 0.414\nsamples = 12\nvalue = [11, 0, 1]')]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qlwkFC2d-1589803106929)(output_29_1.png)]
clf.score(xtrain, ytrain)
0.8870967741935484
max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃。和max_depth异曲同工。max_features是用来限制高纬度数据的过拟合的剪枝参数,但其方法比较暴力。在不知道决策树中各个特征的重要性的情况下,强行设定这个参数可能会导致模型学习不足。如果希望通过降维的方式防止过拟合,建议使用PCA, ICA或者特征选择模块中的降维算法。
min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分枝不会发生。
使用超参数的学习曲线进行判断。超参数的学习曲线,是一条以超参数的取值为横坐标,模型的度量指标为纵坐标的曲线。它是用来衡量不同超参数取值下模型的表现的线。
import matplotlib.pyplot as plt
test=[]
for i in range(10):
clf = tree.DecisionTreeClassifier(max_depth=i+1, criterion='entropy', random_state=42, splitter='random')
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest)
test.append(score)
plt.plot(range(1,11), test, color='red', label='max_depth')
plt.legend()
plt.show()
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-q98gx9Im-1589803106941)(output_35_0.png)]
属性是在模型训练之后,能够调用查看的模型的各种性质。
对决策树来说,最重要的 是feature_importances_,能够查看各个特征对模型的重要性。
fit
score
apply: 输入测试集,返回每个测试样本所在的叶子结点的索引
clf.apply(xtest)
array([14, 25, 16, 25, 25, 25, 16, 25, 16, 16, 25, 25, 16, 3, 25, 16, 25,
8, 5, 16, 25, 3, 3, 7, 3, 16, 3, 25, 25, 3, 25, 16, 16, 25,
25, 25, 25, 16, 16, 16, 16, 16, 25, 3, 3, 3, 25, 25, 25, 16, 25,
16, 3, 19], dtype=int64)
clf.predict(xtest)
array([2, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 2, 0, 1, 0, 2, 2, 1, 0, 2,
2, 1, 2, 1, 2, 0, 0, 2, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 2,
2, 2, 0, 0, 0, 1, 0, 1, 2, 0])
#计算预测正确的个数
(clf.predict(xtest)==ytest).sum()
48
#计算预测的准确率
(clf.predict(xtest)==ytest).mean()
0.8888888888888888
test
[0.6111111111111112,
0.7222222222222222,
0.6666666666666666,
0.8703703703703703,
0.8888888888888888,
0.8888888888888888,
0.8888888888888888,
0.8888888888888888,
0.8888888888888888,
0.8888888888888888]
#选择最大树深=4
clf = tree.DecisionTreeClassifier(max_depth=4, criterion='entropy', random_state=42, splitter="random")
clf = clf.fit(xtrain, ytrain)
score = clf.score(xtest, ytest)
score
0.8703703703703703
样本不均衡是指在一组数据集中,标签的一类天生占有很大的比例,但我们有捕捉出特定的分类的需求的状况。
分类模型天生会倾向于多数的类,让多数类更容易被判断正确,少数类被牺牲掉。因为对于模型而言,样本量越大的标签可以学习的信息越多,算法就会更加依赖于从多数类中学到的信息来进行判断。如果我们希望捕获少数类,模型就会失败
解决方法:
SMOTE方法:该方法通过将少数类的特征重新组合,创造出更多的少数类样本。但是该方法会增加样本的总数,会影响计算速度。因此,提出了更加针对少数类的指标来优化模型
有了权重之后,样本量就不在单纯地记录数目, 而是受输入的权重影响。因此,这时候的剪枝需要使用基于权重的剪枝参数min_weight_fraction_leaf。该参数的使用将比min_sample_leaf更少偏向主类。若样本是加权的,则使用基于权重的预修剪标准更容易优化树结构,这确保叶结点至少包含样本权重的总和的一小部分。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier as DTC
from sklearn.datasets import make_blobs #创建团状数据
class_1 = 500 #类别1只有500个样本
class_2 = 50 #类别2有50个样本
centers = [[0,0],[2.0,2.0]] #设置两个类别的中心
clusters_std = [1.5, 0.5] #设定两个类别的方差,通常来说,样本量比价大的类别会更加松散
x,y = make_blobs(n_samples=[class_1, class_2], centers=centers, cluster_std=clusters_std, random_state=430, shuffle=False)
#查看数据集
plt.scatter(x[:, 0], x[:, 1], c=y, cmap='rainbow', s=10)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Sj66mm7B-1589803106947)(output_59_1.png)]
#不设定class_weight
clf = DTC(max_depth=4)
clf.fit(x, y)
clf.predict(x)
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
#设定class_weight
wclf = DTC(max_depth=4, class_weight={1:10}) #设定1类样本权重为10,默认0类样本权重为1
wclf.fit(x,y)
wclf.predict(x)
array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
#计算两个模型的准确度
clf.score(x, y)
0.9709090909090909
wclf.score(x,y)
0.9272727272727272
样本平衡后,少数类几乎都被分类正确了,但是多数类有许多被分错了。从准确率的角度来看,不做样本平衡的时候准确率反而更高。这是因为样本平衡后,为了更有效地捕捉出少数类,模型误伤了许多多数类样本。因此,如果目的在于模型整体的准确率,那么就要拒绝样本平衡。
预测正确的数除以总样本
(clf.predict(x)[-50:]).sum()
41
(clf.predict(x)[:500]).sum()
7
表示预测的少数类样本中,预测正确的少数类样本所占的比例。该比值可以 衡量“将多数类判错后所需付出的成本”
# 所有判断正确并确实为1的样本/所有被判断为1的样本
# 对于没有class_weight的决策树而言:
(y[y == clf.predict(x)]==1).sum()/(clf.predict(x)== 1).sum()
0.8541666666666666
表示在真实的少数类样本中,被正确预测的少数类样本所占的比例。
# 对于有class_weight的决策树来说:
(y[y == wclf.predict(x)] == 1).sum()/(y == 1).sum()
1.0
可以看出,做样本平衡之后的模型,捕捉出了100%的少数类点。
如果希望不计一切代价找出少数类(如潜在的犯罪者),就会追求高召回率,相反则不需要。召回率和精确度是此消彼长的,两者之间的平衡代表了捕捉少数类的需求和尽量不要误伤多数类的需求的平衡。
是精确度和召回率的调和平均数。该参数倾向于靠近两个数中比较小的一个数,因此追求高F1 measure能够保证精确度和召回率都比较高。该参数的值分布在[0,1]之间。
F − m e a s u r e = 2 1 P r e c i s i o n + 1 R e c a l l = 2 ∗ P r e c i s i o n ∗ R e c a l l P r e c i s i o n + R e c a l l F\ -\ measure=\frac{2}{\frac{1}{Precision}+\frac{1}{Recall}}\ = \ \frac{2*Precision*Recall}{Precision+Recall} F − measure=Precision1+Recall12 = Precision+Recall2∗Precision∗Recall
表示真实的多数类样本中,被正确预测的多数类样本所占的比例
# 所有被正确预测为0的样本/所有的0样本
# 对于没有class_weight的决策树而言:
(y[y == clf.predict(x)] == 0).sum()/(y == 0).sum()
0.986
#对于有class_weight的决策树来说:
(y[y == wclf.predict(x)]==0).sum()/(y == 0).sum()
0.92
#sklearn中的混淆矩阵
from sklearn.metrics import confusion_matrix as cm
cm(y, clf.predict(x))
array([[493, 7],
[ 9, 41]], dtype=int64)
cm(y, wclf.predict(x))
array([[460, 40],
[ 0, 50]], dtype=int64)
#sklearn中的准确率sklearn.metrics.accuracy_score
from sklearn .metrics import accuracy_score
accuracy_score(y, clf.predict(x))
0.9709090909090909
#sklearn中的精确度sklearn.metrics.precision_score
from sklearn.metrics import precision_score
precision_score(y, clf.predict(x))
0.8541666666666666
#sklearn中的召回率sklearn.metrics.recall_score
from sklearn.metrics import recall_score
recall_score(y, clf.predict(x))
0.82
#sklearn中的精确度-召回率平衡曲线
from sklearn.metrics import plot_precision_recall_curve
from matplotlib import pyplot as plt
disp = plot_precision_recall_curve( clf,x,y)
disp.ax_.set_title('2-class Precision-Recall curve')
Text(0.5, 1.0, '2-class Precision-Recall curve')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yF8Zujcy-1589803106959)(output_84_1.png)]
#sklearn中的精确度-召回率平衡曲线
from sklearn.metrics import plot_precision_recall_curve
from matplotlib import pyplot as plt
disp = plot_precision_recall_curve( wclf,x,y)
disp.ax_.set_title('2-class Precision-Recall curve')
Text(0.5, 1.0, '2-class Precision-Recall curve')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-r9GogUrS-1589803106966)(output_85_1.png)]
#sklearn中的F1 measure
from sklearn.metrics import f1_score
f1_score(y, clf.predict(x))
0.836734693877551
class sklearn.tree.DecisionTreeRegressor(*, criterion=‘mse’, splitter=‘best’, max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, presort=‘deprecated’, ccp_alpha=0.0)[source]¶
回归树和分类树的参数几乎一样。值得注意的是,在回归树种,没有标签分布是否均衡的问题,因此没有class_weight这样的参数。
回归树的接口score返回的是P平方,并不是MSE
from sklearn.datasets import load_boston
from sklearn.model_selection import cross_val_score #导入交叉验证的函数
from sklearn.tree import DecisionTreeRegressor
boston = load_boston()
regressor = DecisionTreeRegressor(random_state=0)
cross_val_score(regressor, boston.data, boston.target, cv=10, scoring = 'neg_mean_squared_error')
array([-16.41568627, -10.61843137, -18.30176471, -55.36803922,
-16.01470588, -44.70117647, -12.2148 , -91.3888 ,
-57.764 , -36.8134 ])
boston = load_boston()
boston
{'data': array([[6.3200e-03, 1.8000e+01, 2.3100e+00, ..., 1.5300e+01, 3.9690e+02,
4.9800e+00],
[2.7310e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9690e+02,
9.1400e+00],
[2.7290e-02, 0.0000e+00, 7.0700e+00, ..., 1.7800e+01, 3.9283e+02,
4.0300e+00],
...,
[6.0760e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
5.6400e+00],
[1.0959e-01, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9345e+02,
6.4800e+00],
[4.7410e-02, 0.0000e+00, 1.1930e+01, ..., 2.1000e+01, 3.9690e+02,
7.8800e+00]]),
'target': array([24. , 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15. ,
18.9, 21.7, 20.4, 18.2, 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6,
15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21. , 12.7, 14.5, 13.2,
13.1, 13.5, 18.9, 20. , 21. , 24.7, 30.8, 34.9, 26.6, 25.3, 24.7,
21.2, 19.3, 20. , 16.6, 14.4, 19.4, 19.7, 20.5, 25. , 23.4, 18.9,
35.4, 24.7, 31.6, 23.3, 19.6, 18.7, 16. , 22.2, 25. , 33. , 23.5,
19.4, 22. , 17.4, 20.9, 24.2, 21.7, 22.8, 23.4, 24.1, 21.4, 20. ,
20.8, 21.2, 20.3, 28. , 23.9, 24.8, 22.9, 23.9, 26.6, 22.5, 22.2,
23.6, 28.7, 22.6, 22. , 22.9, 25. , 20.6, 28.4, 21.4, 38.7, 43.8,
33.2, 27.5, 26.5, 18.6, 19.3, 20.1, 19.5, 19.5, 20.4, 19.8, 19.4,
21.7, 22.8, 18.8, 18.7, 18.5, 18.3, 21.2, 19.2, 20.4, 19.3, 22. ,
20.3, 20.5, 17.3, 18.8, 21.4, 15.7, 16.2, 18. , 14.3, 19.2, 19.6,
23. , 18.4, 15.6, 18.1, 17.4, 17.1, 13.3, 17.8, 14. , 14.4, 13.4,
15.6, 11.8, 13.8, 15.6, 14.6, 17.8, 15.4, 21.5, 19.6, 15.3, 19.4,
17. , 15.6, 13.1, 41.3, 24.3, 23.3, 27. , 50. , 50. , 50. , 22.7,
25. , 50. , 23.8, 23.8, 22.3, 17.4, 19.1, 23.1, 23.6, 22.6, 29.4,
23.2, 24.6, 29.9, 37.2, 39.8, 36.2, 37.9, 32.5, 26.4, 29.6, 50. ,
32. , 29.8, 34.9, 37. , 30.5, 36.4, 31.1, 29.1, 50. , 33.3, 30.3,
34.6, 34.9, 32.9, 24.1, 42.3, 48.5, 50. , 22.6, 24.4, 22.5, 24.4,
20. , 21.7, 19.3, 22.4, 28.1, 23.7, 25. , 23.3, 28.7, 21.5, 23. ,
26.7, 21.7, 27.5, 30.1, 44.8, 50. , 37.6, 31.6, 46.7, 31.5, 24.3,
31.7, 41.7, 48.3, 29. , 24. , 25.1, 31.5, 23.7, 23.3, 22. , 20.1,
22.2, 23.7, 17.6, 18.5, 24.3, 20.5, 24.5, 26.2, 24.4, 24.8, 29.6,
42.8, 21.9, 20.9, 44. , 50. , 36. , 30.1, 33.8, 43.1, 48.8, 31. ,
36.5, 22.8, 30.7, 50. , 43.5, 20.7, 21.1, 25.2, 24.4, 35.2, 32.4,
32. , 33.2, 33.1, 29.1, 35.1, 45.4, 35.4, 46. , 50. , 32.2, 22. ,
20.1, 23.2, 22.3, 24.8, 28.5, 37.3, 27.9, 23.9, 21.7, 28.6, 27.1,
20.3, 22.5, 29. , 24.8, 22. , 26.4, 33.1, 36.1, 28.4, 33.4, 28.2,
22.8, 20.3, 16.1, 22.1, 19.4, 21.6, 23.8, 16.2, 17.8, 19.8, 23.1,
21. , 23.8, 23.1, 20.4, 18.5, 25. , 24.6, 23. , 22.2, 19.3, 22.6,
19.8, 17.1, 19.4, 22.2, 20.7, 21.1, 19.5, 18.5, 20.6, 19. , 18.7,
32.7, 16.5, 23.9, 31.2, 17.5, 17.2, 23.1, 24.5, 26.6, 22.9, 24.1,
18.6, 30.1, 18.2, 20.6, 17.8, 21.7, 22.7, 22.6, 25. , 19.9, 20.8,
16.8, 21.9, 27.5, 21.9, 23.1, 50. , 50. , 50. , 50. , 50. , 13.8,
13.8, 15. , 13.9, 13.3, 13.1, 10.2, 10.4, 10.9, 11.3, 12.3, 8.8,
7.2, 10.5, 7.4, 10.2, 11.5, 15.1, 23.2, 9.7, 13.8, 12.7, 13.1,
12.5, 8.5, 5. , 6.3, 5.6, 7.2, 12.1, 8.3, 8.5, 5. , 11.9,
27.9, 17.2, 27.5, 15. , 17.2, 17.9, 16.3, 7. , 7.2, 7.5, 10.4,
8.8, 8.4, 16.7, 14.2, 20.8, 13.4, 11.7, 8.3, 10.2, 10.9, 11. ,
9.5, 14.5, 14.1, 16.1, 14.3, 11.7, 13.4, 9.6, 8.7, 8.4, 12.8,
10.5, 17.1, 18.4, 15.4, 10.8, 11.8, 14.9, 12.6, 14.1, 13. , 13.4,
15.2, 16.1, 17.8, 14.9, 14.1, 12.7, 13.5, 14.9, 20. , 16.4, 17.7,
19.5, 20.2, 21.4, 19.9, 19. , 19.1, 19.1, 20.1, 19.9, 19.6, 23.2,
29.8, 13.8, 13.3, 16.7, 12. , 14.6, 21.4, 23. , 23.7, 25. , 21.8,
20.6, 21.2, 19.1, 20.6, 15.2, 7. , 8.1, 13.6, 20.1, 21.8, 24.5,
23.1, 19.7, 18.3, 21.2, 17.5, 16.8, 22.4, 20.6, 23.9, 22. , 11.9]),
'feature_names': array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='
boston.data.shape
(506, 13)
regressor = DecisionTreeRegressor(random_state=0)
cross_val_score(regressor, boston.data, boston.target, cv=10,scoring='neg_mean_squared_error' )
array([-16.41568627, -10.61843137, -18.30176471, -55.36803922,
-16.01470588, -44.70117647, -12.2148 , -91.3888 ,
-57.764 , -36.8134 ])
一维回归的图像绘制
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
rng = np.random.RandomState(1) #使用numpys设置随机数种子
x = np.sort(5*rng.rand(80, 1), axis=0) #sort对numpy进行排序
y = np.sin(x).ravel() #生成正弦曲线之后将y降维
y[::5] += 3*(0.5 - rng.rand(16)) #对标准的正弦结果y加上噪音
plt.figure()
plt.scatter(x, y, s=20, edgecolor='black', c='darkorange', label='data')
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-WYcnlEqg-1589803106969)(output_101_1.png)]
regr_1 = DecisionTreeRegressor(max_depth=1)
regr_2 = DecisionTreeRegressor(max_depth=10)
regr_1.fit(x,y)
DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=1,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=None, splitter='best')
regr_2.fit(x,y)
DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=10,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=None, splitter='best')
np.arange(0.0, 5.0, 0.01)
array([0. , 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1 ,
0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2 , 0.21,
0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3 , 0.31, 0.32,
0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4 , 0.41, 0.42, 0.43,
0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5 , 0.51, 0.52, 0.53, 0.54,
0.55, 0.56, 0.57, 0.58, 0.59, 0.6 , 0.61, 0.62, 0.63, 0.64, 0.65,
0.66, 0.67, 0.68, 0.69, 0.7 , 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
0.77, 0.78, 0.79, 0.8 , 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
0.88, 0.89, 0.9 , 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
0.99, 1. , 1.01, 1.02, 1.03, 1.04, 1.05, 1.06, 1.07, 1.08, 1.09,
1.1 , 1.11, 1.12, 1.13, 1.14, 1.15, 1.16, 1.17, 1.18, 1.19, 1.2 ,
1.21, 1.22, 1.23, 1.24, 1.25, 1.26, 1.27, 1.28, 1.29, 1.3 , 1.31,
1.32, 1.33, 1.34, 1.35, 1.36, 1.37, 1.38, 1.39, 1.4 , 1.41, 1.42,
1.43, 1.44, 1.45, 1.46, 1.47, 1.48, 1.49, 1.5 , 1.51, 1.52, 1.53,
1.54, 1.55, 1.56, 1.57, 1.58, 1.59, 1.6 , 1.61, 1.62, 1.63, 1.64,
1.65, 1.66, 1.67, 1.68, 1.69, 1.7 , 1.71, 1.72, 1.73, 1.74, 1.75,
1.76, 1.77, 1.78, 1.79, 1.8 , 1.81, 1.82, 1.83, 1.84, 1.85, 1.86,
1.87, 1.88, 1.89, 1.9 , 1.91, 1.92, 1.93, 1.94, 1.95, 1.96, 1.97,
1.98, 1.99, 2. , 2.01, 2.02, 2.03, 2.04, 2.05, 2.06, 2.07, 2.08,
2.09, 2.1 , 2.11, 2.12, 2.13, 2.14, 2.15, 2.16, 2.17, 2.18, 2.19,
2.2 , 2.21, 2.22, 2.23, 2.24, 2.25, 2.26, 2.27, 2.28, 2.29, 2.3 ,
2.31, 2.32, 2.33, 2.34, 2.35, 2.36, 2.37, 2.38, 2.39, 2.4 , 2.41,
2.42, 2.43, 2.44, 2.45, 2.46, 2.47, 2.48, 2.49, 2.5 , 2.51, 2.52,
2.53, 2.54, 2.55, 2.56, 2.57, 2.58, 2.59, 2.6 , 2.61, 2.62, 2.63,
2.64, 2.65, 2.66, 2.67, 2.68, 2.69, 2.7 , 2.71, 2.72, 2.73, 2.74,
2.75, 2.76, 2.77, 2.78, 2.79, 2.8 , 2.81, 2.82, 2.83, 2.84, 2.85,
2.86, 2.87, 2.88, 2.89, 2.9 , 2.91, 2.92, 2.93, 2.94, 2.95, 2.96,
2.97, 2.98, 2.99, 3. , 3.01, 3.02, 3.03, 3.04, 3.05, 3.06, 3.07,
3.08, 3.09, 3.1 , 3.11, 3.12, 3.13, 3.14, 3.15, 3.16, 3.17, 3.18,
3.19, 3.2 , 3.21, 3.22, 3.23, 3.24, 3.25, 3.26, 3.27, 3.28, 3.29,
3.3 , 3.31, 3.32, 3.33, 3.34, 3.35, 3.36, 3.37, 3.38, 3.39, 3.4 ,
3.41, 3.42, 3.43, 3.44, 3.45, 3.46, 3.47, 3.48, 3.49, 3.5 , 3.51,
3.52, 3.53, 3.54, 3.55, 3.56, 3.57, 3.58, 3.59, 3.6 , 3.61, 3.62,
3.63, 3.64, 3.65, 3.66, 3.67, 3.68, 3.69, 3.7 , 3.71, 3.72, 3.73,
3.74, 3.75, 3.76, 3.77, 3.78, 3.79, 3.8 , 3.81, 3.82, 3.83, 3.84,
3.85, 3.86, 3.87, 3.88, 3.89, 3.9 , 3.91, 3.92, 3.93, 3.94, 3.95,
3.96, 3.97, 3.98, 3.99, 4. , 4.01, 4.02, 4.03, 4.04, 4.05, 4.06,
4.07, 4.08, 4.09, 4.1 , 4.11, 4.12, 4.13, 4.14, 4.15, 4.16, 4.17,
4.18, 4.19, 4.2 , 4.21, 4.22, 4.23, 4.24, 4.25, 4.26, 4.27, 4.28,
4.29, 4.3 , 4.31, 4.32, 4.33, 4.34, 4.35, 4.36, 4.37, 4.38, 4.39,
4.4 , 4.41, 4.42, 4.43, 4.44, 4.45, 4.46, 4.47, 4.48, 4.49, 4.5 ,
4.51, 4.52, 4.53, 4.54, 4.55, 4.56, 4.57, 4.58, 4.59, 4.6 , 4.61,
4.62, 4.63, 4.64, 4.65, 4.66, 4.67, 4.68, 4.69, 4.7 , 4.71, 4.72,
4.73, 4.74, 4.75, 4.76, 4.77, 4.78, 4.79, 4.8 , 4.81, 4.82, 4.83,
4.84, 4.85, 4.86, 4.87, 4.88, 4.89, 4.9 , 4.91, 4.92, 4.93, 4.94,
4.95, 4.96, 4.97, 4.98, 4.99])
x_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
#np.newaxis与reshape(-1,1)作用相同
y_1 = regr_1.predict(x_test)
y_2 = regr_2.predict(x_test)
plt.figure()
plt.scatter(x, y, s=20, edgecolor='black', c='darkorange', label='data')
plt.plot(x_test, y_1,color='cornflowerblue', label='max_depth=2', linewidth=2)
plt.plot(x_test, y_2,color='yellowgreen', label='max_depth=10', linewidth=2)
plt.xlabel('data')
plt.ylabel('target')
plt.title('Decision Tree Regression')
plt.legend()
plt.show()
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rCX6QMbi-1589803106975)(output_107_0.png)]
参考: