决策树是一种常见的机器学习算法,它的原理就是以事物的属性为节点,属性的不同取值为分支构造一棵多叉树。初始的多叉树可以学习到训练集上的事物的所有属性,我们通过相关的方法对这棵树进行剪枝和层次调整后,使之对于训练集中未出现过的新数据具有一定的分类能力,即泛化。
由1中对决策树的描述,我们可以知道生成决策树的核心问题就是分支的选择,即每一步我们应该选择待分类事物的哪一个属性作为分支属性,我们应该把对事物影响大的属性尽可能放在树的高层,这是因为越往下走剩下的样本就会越少对事物影响大的属性应该在树中对大多数样本起到分类作用。基于不同的分支选择方法,决策树主要有三类,即ID3算法,C4.5算法和CART算法。
ID3中使用的信息增益偏向取值较多的属性,如果有“编号”这类属性,那ID3算法会把编号作为最优属性,这很荒诞,C4.5算法就使用增益率代替增益,解决这种问题。增益率公式如下:
此外,为了避免偏向于取值数目少的属性,C4.5算法并不是直接选取增益率最大的属性进行分支,而是启发式的,先找出信息增益高于平均值的属性,再从中选出增益率最高的作为分支属性。
from numpy.lib.arraypad import pad
from sklearn.datasets import load_iris
import pandas as pd
from pandas import plotting
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree
#加载数据集
data = load_iris()
#转换成DataFrame形式
df = pd.DataFrame(data.data, columns=data.feature_names)
#添加品种列
df['Species'] = data.target
#查看数据集信息
df.info()
#查看前5条数据
df.head()
#查看各特征列的摘要信息
df.describe()
#设置颜色主题
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864']
# 绘制violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True) # 删除上方和右方坐标轴上不需要的边框,这在matplotlib中是无法通过参数实现的
sns.violinplot(x='Species', y=df.columns[0], data=df, palette=antV, ax=axes[0, 0])
sns.violinplot(x='Species', y=df.columns[1], data=df, palette=antV, ax=axes[0, 1])
sns.violinplot(x='Species', y=df.columns[2], data=df, palette=antV, ax=axes[1, 0])
sns.violinplot(x='Species', y=df.columns[3], data=df, palette=antV, ax=axes[1, 1])
plt.show()
# 绘制pointplot
f, axes = plt.subplots(2, 2, figsize=(8, 6), sharex=True)
sns.despine(left=True)
sns.pointplot(x='Species', y=df.columns[0], data=df, color=antV[1], ax=axes[0, 0])
sns.pointplot(x='Species', y=df.columns[1], data=df, color=antV[1], ax=axes[0, 1])
sns.pointplot(x='Species', y=df.columns[2], data=df, color=antV[1], ax=axes[1, 0])
sns.pointplot(x='Species', y=df.columns[3], data=df, color=antV[1], ax=axes[1, 1])
plt.show()
# g = sns.pairplot(data=df, palette=antV, hue= 'Species')
# 安德鲁曲线
plt.subplots(figsize = (8,6))
plotting.andrews_curves(df, 'Species', colormap='cool')
plt.show()
# 加载数据集
data = load_iris()
# 转换成.DataFrame形式
df = pd.DataFrame(data.data, columns = data.feature_names)
# 添加品种列
df['Species'] = data.target
# 用数值替代品种名作为标签
target = np.unique(data.target)
target_names = np.unique(data.target_names)
targets = dict(zip(target, target_names))
df['Species'] = df['Species'].replace(targets)
# 提取数据和标签
X = df.drop(columns="Species")
y = df["Species"]
feature_names = X.columns
labels = y.unique()
X_train, test_x, y_train, test_lab = train_test_split(X,y,
test_size = 0.4,
random_state = 42)
model = DecisionTreeClassifier(max_depth =3, random_state = 42)
model.fit(X_train, y_train)
# 以文字形式输出树
text_representation = tree.export_text(model)
print(text_representation)
# 用图片画出
plt.figure(figsize=(30,10), facecolor ='g') #
a = tree.plot_tree(model,
feature_names = feature_names,
class_names = labels,
rounded = True,
filled = True,
fontsize=14)
plt.show()
优缺点
优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题
适用数据类型:数值型和标称型
https://github.com/datawhalechina/machine-learning-toy-code/blob/main/ml-with-sklearn/DecisionTree/DecisionTree.ipynb
https://mp.weixin.qq.com/s/kxAuVAhnimskmT667JjrFA
https://blog.csdn.net/baidu_38406307/article/details/102879578?spm=1001.2014.3001.5501