【scikit-learn】决策树

目录

  • 决策树
    • 决策树的分类
    • 相关概念
    • 决策树的构造
      • 特征选择:
      • 决策树的生成:
      • 决策树的裁剪
    • 决策树的优缺点
      • 优点
      • 缺点
  • 决策树生成算法
    • 以ID3算法为例
  • 实战

决策树

决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一颗熵值下降最快的树,到叶子节点处,熵值为0。是一种有监督学习。决策树呈树形结构,在分类问题中,表示基于特征对实例进行分类的过程。学习时,利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策模型进行分类。

决策树的分类

决策树可以分为两类,主要取决于它目标变量的类型。

  • 离散性决策树:离散性决策树,其目标变量是离散的,如性别:男或女等;
  • 连续性决策树:连续性决策树,其目标变量是连续的,如工资、价格、年龄等;

相关概念

  • 根结点(Root Node):它表示整个样本集合,并且该节点可以进一步划分成两个或多个子集。
  • 拆分(Splitting):表示将一个结点拆分成多个子集的过程。
  • 决策结点(Decision Node):当一个子结点进一步被拆分成多个子节点时,这个子节点就叫做决策结点。
  • 叶子结点(Leaf/Terminal Node):无法再拆分的结点被称为叶子结点。
  • 剪枝(Pruning):移除决策树中子结点的过程就叫做剪枝,跟拆分过程相反。
  • 分支/子树(Branch/Sub-Tree):一棵决策树的一部分就叫做分支或子树。
  • 父结点和子结点(Paren and Child Node):一个结点被拆分成多个子节点,这个结点就叫做父节点;其拆分后的子结点也叫做子结点。

决策树的构造

决策树的构造过程一般分为3个部分,分别是特征选择、决策树生产和决策树裁剪。

特征选择:

特征选择表示从众多的特征中选择一个特征作为当前节点分裂的标准,如何选择特征有不同的量化评估方法,从而衍生出不同的决策树,如ID3(通过信息增益选择特征)、C4.5(通过信息增益比选择特征)、CART(通过Gini指数选择特征)等。
目的(准则):使用某特征对数据集划分之后,各数据子集的纯度要比划分钱的数据集D的纯度高(也就是不确定性要比划分前数据集D的不确定性低)。

决策树的生成:

根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止生长。这个过程实际上就是使用满足划分准则的特征不断的将数据集划分成纯度更高,不确定行更小的子集的过程。对于当前数据集的每一次划分,都希望根据某个特征划分之后的各个子集的纯度更高,不确定性更小。

决策树的裁剪

决策树容易过拟合,一般需要剪枝来缩小树结构规模、缓解过拟合。

决策树的优缺点

优点

  • 具有可读性,如果给定一个模型,那么过呢据所产生的决策树很容易推理出相应的逻辑表达。
  • 分类速度快,能在相对短的时间内能够对大型数据源做出可行且效果良好的结果。

缺点

  • 对未知的测试数据未必有好的分类、泛化能力,即可能发生过拟合现象,此时可采用剪枝或随机森林。

决策树生成算法

特征选择依据
ID3 信息熵
C4.5 信息增益比
CART 基尼系数

以ID3算法为例

ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。

【scikit-learn】决策树_第1张图片
(1)数据准备:需要对数值型数据进行离散化

(2)ID3算法构建决策树:

  • 如果数据集类别完全相同,则停止划分
  • 否则,继续划分决策树:
    • 计算信息熵和信息增益来选择最好的数据集划分方法;
    • 划分数据集
    • 创建分支节点
    • 对每个分支进行判定是否类别相同,如果相同停止划分,不同按照上述方法进行划分。

实战

iris数据集

鸢尾花(iris)数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。

from sklearn.datasets import load_iris
import seaborn as sns
from pandas import plotting

# 加载数据集
data = load_iris() 
# 转换成.DataFrame形式
df = pd.DataFrame(data.data, columns = data.feature_names)
# 添加品种列
df['Species'] = data.target
# 查看数据集信息
df.info()
# 查看前5条数据
df.head()
# 查看各特征列的摘要信息
df.describe()

输出

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   sepal length (cm)  150 non-null    float64
 1   sepal width (cm)   150 non-null    float64
 2   petal length (cm)  150 non-null    float64
 3   petal width (cm)   150 non-null    float64
 4   Species            150 non-null    int64  
dtypes: float64(4), int64(1)
memory usage: 6.0 KB
sepal length (cm)	sepal width (cm)	petal length (cm)	petal width (cm)	Species
count	150.000000	150.000000	150.000000	150.000000	150.000000
mean	5.843333	3.057333	3.758000	1.199333	1.000000
std	0.828066	0.435866	1.765298	0.762238	0.819232
min	4.300000	2.000000	1.000000	0.100000	0.000000
25%	5.100000	2.800000	1.600000	0.300000	0.000000
50%	5.800000	3.000000	4.350000	1.300000	1.000000
75%	6.400000	3.300000	5.100000	1.800000	2.000000
max	7.900000	4.400000	6.900000	2.500000	2.000000

通过Violinplot和 Pointplot,分别从数据分布和斜率,观察各特征与品种之间的关系

# 设置颜色主题
antV = ['#1890FF', '#2FC25B', '#FACC14', '#223273', '#8543E0', '#13C2C2', '#3436c7', '#F04864'] 

# 绘制violinplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True) # 删除上方和右方坐标轴上不需要的边框,这在matplotlib中是无法通过参数实现的
sns.violinplot(x='Species', y=df.columns[0], data=df, palette=antV, ax=axes[0, 0])
sns.violinplot(x='Species', y=df.columns[1], data=df, palette=antV, ax=axes[0, 1])
sns.violinplot(x='Species', y=df.columns[2], data=df, palette=antV, ax=axes[1, 0])
sns.violinplot(x='Species', y=df.columns[3], data=df, palette=antV, ax=axes[1, 1])

plt.show()

# 绘制pointplot
f, axes = plt.subplots(2, 2, figsize=(8, 8), sharex=True)
sns.despine(left=True)
sns.pointplot(x='Species', y=df.columns[0], data=df, color=antV[1], ax=axes[0, 0])
sns.pointplot(x='Species', y=df.columns[1], data=df, color=antV[1], ax=axes[0, 1])
sns.pointplot(x='Species', y=df.columns[2], data=df, color=antV[1], ax=axes[1, 0])
sns.pointplot(x='Species', y=df.columns[3], data=df, color=antV[1], ax=axes[1, 1])

plt.show()

# g = sns.pairplot(data=df, palette=antV, hue= 'Species')

【scikit-learn】决策树_第2张图片
【scikit-learn】决策树_第3张图片

plt.subplots(figsize = (10,8))
plotting.andrews_curves(df, 'Species', colormap='cool')

plt.show()

【scikit-learn】决策树_第4张图片

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

#load in the data
data = load_iris()
#convert to a dataframe
df = pd.DataFrame(data.data, columns = data.feature_names)
#create the species column
df['Species'] = data.target


#replace this with the actual names
target = np.unique(data.target)
target_names = np.unique(data.target_names)
targets = dict(zip(target, target_names))
df['Species'] = df['Species'].replace(targets)

#extract features and target variables
x = df.drop(columns="Species")
y = df["Species"]
#save the feature name and target variables
feature_names = x.columns
labels = y.unique()
#split the dataset
from sklearn.model_selection import train_test_split
X_train, test_x, y_train, test_lab = train_test_split(x,y,
                                                 test_size = 0.4,
                                                 random_state = 42)
model = DecisionTreeClassifier(max_depth =3, random_state = 42)
model.fit(X_train, y_train)      
#import relevant packages
from sklearn import tree
import matplotlib.pyplot as plt
#plt the figure, setting a black background
plt.figure(figsize=(30,10), facecolor ='g') # facecolor设置背景色
#create the tree plot
a = tree.plot_tree(model,
                   #use the feature names stored
                   feature_names = feature_names,
                   #use the class names stored
                   class_names = labels,
                   rounded = True,
                   filled = True,
                   fontsize=14)
#show the plot
plt.show()                                          

【scikit-learn】决策树_第5张图片

你可能感兴趣的:(python,决策树,机器学习,scikit-learn)