决策树(Decision Tree,又称为判定树)算法是机器学习中常见的一类算法,是一种以树结构形式表达的预测分析模型。决策树属于监督学习(Supervised learning),根据处理数据类型的不同,决策树又为分类决策树与回归决策树。最早的的决策树算法是由Hunt等人于1966年提出,Hunt算法是许多决策树算法的基础,包括ID3、C4.5和CART等。
决策树的主要思想是根据已知数据构建一棵树,通过对待分类或回归的样本进行逐步的特征判断,最终将其分类或回归至叶子节点。
决策树学习的关键在于如何选择最优的划分属性,所谓的最优划分属性,对于二元分类而言,就是尽量使划分的样本属于同一类别,即“纯度”最高的属性。那么如何来度量特征(features)的纯度,这时候就要用到“信息熵(information entropy)”。先来看看信息熵的定义:假如当前样本集D中第k类样本所占的比例为Pk(k=1,2,3,......|y|),|y|为类别的总数(对于二元分类来说,|y|=2)。则样本集的信息熵为:
Ent(D)的值越小,则D的纯度越高。再来看一个概念信息增益(information gain),假定离散属性有V个可能的取值{a1,a2,......av},如果使用特征a来对数据集D进行划分,则会产生V个分支结点, 其中第v(小v)个结点包含了数据集D中所有在特征a上取值为av的样本总数,记为Dv。因此可以根据上面信息熵的公式计算出信息熵,再考虑到不同的分支结点所包含的样本数量不同,给分支节点赋予权重|Dv|/|D|,从这个公式能够发现包含的样本数越多的分支结点的影响越大(这样是信息增益的一个缺点,即信息增益对可取值数目多的特征有偏好(即该属性能取得值越多,信息增益,越偏向这个),待会后面会举例子说明,这个缺点可以用信息增益率来解决)。因此,能够计算出特征a对样本集D进行划分所获得的“信息增益”:
一般而言,信息增益越大,则表示使用特征a对数据集划分所获得的“纯度提升”越大。所以信息增益可以用于决策树划分属性的选择,其实就是选择信息增益最大的属性,ID3算法就是采用的信息增益来划分属性。
熵:物理意义是体系混乱程度的度量。
信息熵:表示事物不确定性的度量标准,可以根据数学中的概率计算,出现的概率就大,出现的机会就多,不确定性就小(信息熵小)。
(1)信息增益(ID3使用的划分方式)
假设训练数据集D和特征A,根据如下步骤计算信息增益:
第一步:计算数据集D的经验熵:
其中,|Ck|为第k类样本的数目,|D|为数据集D的数目。
第二步:计算特征A对数据集D的经验条件熵H(D|A):
第三步:计算信息增益:
一般而言,信息增益越大,则意味着使用属性A来进行划分所获得的“纯度提升” 越大。因此,我们可使用信息增益来进行决策树的划分属性选择。ID3决策树学习算法就是以信息增益为准则来选择划分属性的。
(2)信息增益率(C4.5所用划分准则)
特征A对于数据集D的信息增益比定义为:
其中,
称为数据集D关于A的取值熵。
增益率准则就可取值数目较少的属性有所偏好,因此,C4.5算法并不是直接选择增益率最大的候选划分属性,而是使用了一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。
(3)基尼指数
分类问题中,假设有K个类,样本点属于k的概率Pk,则概率分布的基尼指数:
二分类问题:
对给定的样本集合D,基尼指数:
CART决策树使用“基尼指数”来选择划分属性。数据集D的纯度可用基尼值来度量,Gini(D)越小,则数据集的纯度越高。CART生成的是二叉树,计算量相对来说不是很大,可以处理连续和离散变量,能够对缺失值进行处理。
分类决策树可用于处理离散型数据,回归决策树可用于处理连续型数据。最常见的决策树算法包含ID3算法、C4.5算法以及CART算法。
剪枝:顾名思义就是给决策树 "去掉" 一些判断分支,同时在剩下的树结构下仍然能得到不错的结果。之所以进行剪枝,是为了防止或减少 "过拟合现象" 的发生,是决策树具有更好的泛化能力。
具体做法:去掉过于细分的叶节点,使其回退到父节点,甚至更高的节点,然后将父节点或更高的叶节点改为新的叶节点。
剪枝的两种方法:
注意:决策树的生成只考虑局部最优,相对地,决策树的剪枝则考虑全局最优。
简单易于理解和解释:决策树算法的模型可以直观地表示为树形结构,易于理解和解释,可以帮助人们做出可靠的决策。
适用于多类型的数据:决策树算法可以处理包含分类和数值型特征的数据,也可以用于处理多分类问题。
可以处理缺失值和异常值:决策树算法可以自动处理缺失值和异常值,而无需进行额外的数据预处理。
能够处理高维数据:决策树算法可以有效地处理具有大量特征的数据集,而不需要进行特征选择或降维。
计算复杂度较低:决策树算法的训练和预测的计算复杂度相对较低,尤其适用于大规模数据集。
容易过拟合:决策树算法容易在训练数据上过拟合,导致在新数据上的泛化能力较差。可以通过剪枝等技术来减少过拟合的风险。
对输入数据的微小变化敏感:决策树算法对输入数据的微小变化非常敏感,这可能导致不稳定的预测结果。
忽略属性之间的相关性:决策树算法假设每个特征在判断类别时都是独立的,忽略了特征之间的相关性。这在某些情况下可能导致模型的性能下降。
可能产生不一致的分类结果:决策树算法对于某些数据集,可能会生成不一致的分类结果,这是由于训练数据的微小变化可能导致完全不同的树结构。
难以处理连续性特征:决策树算法对于连续性特征的处理相对困难,需要进行离散化或采用其他处理方法。
决策树算法的优缺点在实际应用中可能会受到具体问题和数据集的影响,因此在选择算法时需要根据具体情况进行综合考虑。
决策树算法在机器学习和数据挖掘领域有广泛的应用场景,包括但不限于以下几个方面:
分类问题:决策树算法可以用于解决分类问题,例如根据一组特征将实例分为不同的类别。它在垃圾邮件过滤、疾病诊断、客户分类等领域都有应用。
回归问题:决策树算法也可以用于解决回归问题,例如根据一组特征预测一个连续的数值。它在房价预测、销量预测等领域可以发挥作用。
特征选择:决策树算法可以通过特征选择的方式确定哪些特征对于解决问题是最重要的。通过分析决策树的结构和特征重要性,可以帮助我们理解数据集并选择最相关的特征。
缺失值处理:决策树算法可以有效地处理带有缺失值的数据集。在决策树的构建过程中,可以根据可用的特征进行分割,从而处理缺失值问题。
异常检测:决策树算法可以用于检测异常值。通过构建决策树,可以将异常值划分为与正常值不同的分支,从而帮助我们发现异常情况。
推荐系统:决策树算法可以用于构建个性化的推荐系统。通过分析用户的特征和历史行为,决策树可以推断出用户的兴趣和偏好,进而提供个性化的推荐内容。
需要注意的是,决策树算法适用于特征具有离散值或连续值的数据集。同时,决策树算法也有其局限性,对于特征关联性较强的数据集可能表现不佳。因此,在选择算法时需要综合考虑数据集的特点和问题的需求。
(1)数据集介绍
Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)三种中的哪一品种。
数据内容:
sepal_len sepal_wid petal_len petal_wid label
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
.. ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2
(2)代码实现
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 构建决策树
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 评估性能
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: ", accuracy)
(3)分类结果
Accuracy: 1.0
Process finished with exit code 0
达到了100%的准确率
(1)代码实现
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
rng = np.random.RandomState(1)
X = np.sort(5*rng.rand(80,1),axis = 0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
# Plot the results
plt.figure()
plt.scatter(X, y, s=20, edgecolor="black",
c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
(2)结果展示