Machine Learning with Scikit-Learn and Tensorflow 6.7 正则化超参数

书籍信息
Hands-On Machine Learning with Scikit-Learn and Tensorflow
出版社: O’Reilly Media, Inc, USA
平装: 566页
语种: 英语
ISBN: 1491962291
条形码: 9781491962299
商品尺寸: 18 x 2.9 x 23.3 cm
ASIN: 1491962291

系列博文为书籍中文翻译
代码以及数据下载:https://github.com/ageron/handson-ml

决策树几乎没有关于训练数据的假设(相反,线性回归假设训练数据呈现线性规律)。如果不加以限制,那么决策树会尝试精确拟合训练数据,导致过拟合(线性回归模型拟合能力有限,过拟合的可能性较小,但是可能会欠拟合)。

为了避免过拟合,我们需要对决策树加以限制。例如,我们可以通过max_depth限制决策树的最大深度(默认情况最大深度没有限制)。

决策树存在其他参数对模型加以限制:
(1)min_samples_split:结点分裂需要的最小样本数量
(2)min_samples_leaf:叶结点需要的最小样本数量
(3)min_weight_fraction_leaf:min_samples_leaf的比例形式
(4)max_leaf_nodes:最大叶结点数量
(5)max_features:分裂结点时考虑的最大特征数量
增加以min_起头的超参数或减少以max_起头的超参数可以避免模型的过拟合

注释:
避免过拟合的其他的思路包括首先训练决策树,然后进行剪枝。如果只有叶结点的结点其子结点对模型的效果提升不显著,那么这些子结点是不必要的。剪枝过程在所有不必要的结点修剪后完成。

下面的实例说明决策树超参数避免过拟合的作用。左边的决策树没有加以限制,右边的决策树通过min_samples_leaf=4加以限制。可以发现,左边的决策树存在过拟合的问题,右边的决策树效果比较理想。

from sklearn.datasets import make_moons
Xm, ym = make_moons(n_samples=100, noise=0.25, random_state=53)

deep_tree_clf1 = DecisionTreeClassifier(random_state=42)
deep_tree_clf2 = DecisionTreeClassifier(min_samples_leaf=4, random_state=42)
deep_tree_clf1.fit(Xm, ym)
deep_tree_clf2.fit(Xm, ym)

def plot_decision_boundary(clf, X, y, axes):
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap, linewidth=10)
    plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo")
    plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs")
    plt.axis(axes)
    plt.xlabel(r"$x_1$", fontsize=18)
    plt.ylabel(r"$x_2$", fontsize=18, rotation=0)

plt.figure(figsize=(11, 4))
plt.subplot(121)
plot_decision_boundary(deep_tree_clf1, Xm, ym, axes=[-1.5, 2.5, -1, 1.5])
plt.title("No restrictions", fontsize=16)
plt.subplot(122)
plot_decision_boundary(deep_tree_clf2, Xm, ym, axes=[-1.5, 2.5, -1, 1.5])
plt.title("min_samples_leaf = {}".format(deep_tree_clf2.min_samples_leaf), fontsize=14)
plt.show()

Machine Learning with Scikit-Learn and Tensorflow 6.7 正则化超参数_第1张图片

你可能感兴趣的:(机器学习)