决策树Python 实现——决策树回归

原文以及代码来自于,本人对每段代码进行了详细注释,希望对初学者有用。
https://scikit-learn.org/stable/auto_examples/tree/plot_tree_regression.html
决策树回归 Decision Tree Regression
带有决策树的 1D 回归。
决策树用于拟合正数曲线和加噪声观测。因此,它学习接近主数曲线的局部线性回归。
我们可以看到,如果树的最大深度(由最大深度参数控制)设置得过高,则决策树会学习训练数据的细节,并从噪声中学习,即它们过度拟合。
决策树Python 实现——决策树回归_第1张图片

print(__doc__)

# Import the necessary modules and libraries
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt

# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
'''
numpy.random.RandomState()是一个伪随机数生成器。那么伪随机数是什么呢? ()括号内是seed,确保不同电脑上产生相同的伪随机数
伪随机数是用确定性的算法计算出来的似来自[0,1]均匀分布的随机数序列。并不真正的随机,但具有类似于随机数的统计特征,如均匀性、独立性等。
运行一下下面两个
rng = np.random.RandomState(1)
x = rng.rand(4)
y = rng.rand(4)
rng = np.random.RandomState(1)
x = rng.rand(4)
rng = np.random.RandomState(1)
y = rng.rand(4)

.sort axis =0 每列从小到大排列
import numpy as np
x=np.array([[0,12,48],[4,18,14],[7,1,99]])
np.sort(x)
Out[61]: 
array([[ 0, 12, 48],
       [ 4, 14, 18],
       [ 1,  7, 99]])
np.sort(x, axis=0)
Out[62]: 
array([[ 0,  1, 14],
       [ 4, 12, 48],
       [ 7, 18, 99]])
np.sort(x, axis=1)
Out[63]: 
array([[ 0, 12, 48],
       [ 4, 14, 18],
       [ 1,  7, 99]]) 
       
.ravel() 将多维数组降位一维

y[::5] 从开始到结束,每隔5个数,第0个开始取出算起,更加详细的取数组元素方式可以参看下面链接
https://blog.csdn.net/sinat_34474705/article/details/74458605
'''

# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)

# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
'''
[:, np.newaxis], 将一行数据转换为一列数据,每行是一个一维输入,每列是一个feature,这个例子只有一个feature
'''


# Plot the results
plt.figure()
plt.scatter(X, y, s=20, edgecolor="black",
            c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

你可能感兴趣的:(决策树回归)