统计学习方法学习笔记4——决策树模型

目录

1.概述

2.决策树的优缺点:

2.1.决策树的优点:

2.2.决策树的缺点:

3.决策树算法模型

3.1.特征选择的准则

3.2.树的生成

3.3.树的剪枝

4.决策树在sklearn中的类

4.1.分类

4.2.回归

5.书本案例sklearn实现


1.概述

决策树是一种用来分类和回归的无参监督学习方法,其目的是创建一种模型从数据特征中简单的决策规则来预测一个目标变量的值;

决策树的宗旨在于构建一个与训练数据集你和比较好的模型,同时保证模型的复杂度比较小;

决策树学习算法包括3部分:特征的选择、树的生成和树的剪枝。

统计学习方法学习笔记4——决策树模型_第1张图片

 

2.决策树的优缺点:

2.1.决策树的优点:

  1. 便于理解和解释。树的结构可以可视化出来。
  2. 训练需要的数据少。其他机器学习模型通常需要数据规范化,比如构建虚拟变量和移除缺失值,不过请注意,这种模型不支持缺失值。
  3. 由于训练决策树的数据点的数量导致了决策树的使用开销呈指数分布(训练树模型的时间复杂度是参与训练数据点的对数值)。
  4. 能够处理数值型数据和分类数据。其他的技术通常只能用来专门分析某一种变量类型的数据集。详情请参阅算法。
  5. 能够处理多路输出的问题。
  6. 使用白盒模型。如果某种给定的情况在该模型中是可以观察的,那么就可以轻易的通过布尔逻辑来解释这种情况。相比之下,在黑盒模型中的结果就是很难说明清 楚地。
  7. 可以通过数值统计测试来验证该模型。这对事解释验证该模型的可靠性成为可能。
  8. 即使该模型假设的结果与真实模型所提供的数据有些违反,其表现依旧良好。

2.2.决策树的缺点:

  1. 决策树模型容易产生一个过于复杂的模型,这样的模型对数据的泛化性能会很差。这就是所谓的过拟合.一些策略像剪枝、设置叶节点所需的最小样本数或设置数的最大深度是避免出现 该问题最为有效地方法。
  2. 决策树可能是不稳定的,因为数据中的微小变化可能会导致完全不同的树生成。这个问题可以通过决策树的集成来得到缓解。
  3. 在多方面性能最优和简单化概念的要求下,学习一棵最优决策树通常是一个NP难问题。因此,实际的决策树学习算法是基于启发式算法,例如在每个节点进 行局部最优决策的贪心算法。这样的算法不能保证返回全局最优决策树。这个问题可以通过集成学习来训练多棵决策树来缓解,这多棵决策树一般通过对特征和样本有放回的随机采样来生成。
  4. 有些概念很难被决策树学习到,因为决策树很难清楚的表述这些概念。例如XOR,奇偶或者复用器的问题。
  5. 如果某些类在问题中占主导地位会使得创建的决策树有偏差。因此,我们建议在拟合前先对数据集进行平衡。

3.决策树算法模型

3.1.特征选择的准则

  • 熵:衡量集合的不确定度,不确定度越大,熵值就越大。

$$ H(p)=H(X)=-\sum_{i=1}^{n}p_i\log p_i $$

熵值只与X的分布有关,与X的取值无关。

  • 条件熵:

随机变量(X,Y)的联合概率分布为:

$P(X=x_i,Y=y_j)=p_{ij}, i=1,2,\dots ,n;j=1,2,\dots ,m$

条件熵H(Y|X)表示在已知随机变量X的条件下随机变量Y的不确定性:

$$ H(Y|X)=\sum_{i=1}^np_iH(Y|X=x_i) $$ 其中$p_i=P(X=x_i),i=1,2,\dots ,n$

经验熵, 经验条件熵

     当熵和条件熵中的概率由数据估计(特别是极大似然估计)得到时,所对应的熵与条件熵分别称为经验熵和经验条件熵就是从已知              的数据计算得到的结果。

  • 信息增益(对应ID3算法):

特征A对训练数据集D的信息增益g(D|A),定义为集合D的经验熵H(D)与特征A给定的条件下D的经验条件熵H(D|A)之差:

$$ g(D,A)=H(D)-H(D|A) $$

熵与条件熵的差称为互信息.

决策树中的信息增益等价于训练数据集中的类与特征的互信息。

考虑ID这种特征, 本身是唯一的。按照ID做划分, 得到的经验条件熵为0, 会得到最大的信息增益。所以, 按照信息增益的准则来选择特征, 可能会倾向于取值比较多的特征。

  • 信息增益比(对应C4.5算法):

$$ g_R(D,A)=\frac{g(D,A)}{H_A(D)}\ H_A(D)=-\sum_{i=1}^n\frac{D_i}{D}log_2\frac{D_i}{D} $$

 

  • 基尼指数(对应CART剪枝树算法,但是此处的基尼指数不同于经济学中的基尼指数,但是经济学中的基尼指数也可以用来衡量机器学习中的一些问题,也具有一定的效果)

$$ Gini(p) = \sum_{k=1}^Kp_k(1-p_k)=1-\sum_{k=1}^Kp_k^2 $$

3.2.树的生成

  • ID3算法:

输入:训练数据集D, 特征集A,阈值$\epsilon$

输出:决策树T

算法流程:

  1. 如果D属于同一类C_k,T为单节点树,类C_k作为该节点的类标记,返回T
  2. 如果A是空集,置T为单节点树,实例数最多的类作为该节点类标记,返回T
  3. 计算g, 选择信息增益最大的特征A_g
  4. 如果A_g的信息增益小于$\epsilon$,T为单节点树,$D$中实例数最大的类C_k作为类标记,返回T
  5. A_g划分若干非空子集D_i
  6. D_i训练集,A-A_g为特征集,递归调用前面步骤,得到T_i,返回T_i
  • C4.5算法:

输入:训练数据集D, 特征集A,阈值$\epsilon$

输出:决策树T

算法流程:

  1. 如果D属于同一类C_k,T为单节点树,类C_k作为该节点的类标记,返回T
  2. 如果A是空集, 置T为单节点树,实例数最多的作为该节点类标记,返回T
  3. 计算g, 选择信息增益比最大的特征A_g
  4. 如果A_g信息增益比小于$\epsilon$,T为单节点树,D中实例数最大的类C_k作为类标记,返回T
  5. A_g划分若干非空子集D_i
  6. D_i训练集,A-A_g为特征集,递归调用前面步骤,得到T_i,返回T_i ID3和C4.5在生成上,差异只在准则的差异。
  • 最小二乘回归树生成算法:

输入:训练数据集D

输出:回归树f(x)

算法流程:

  1. 遍历变量j,对固定的切分变量j扫描切分点s,得到满足上面关系的(j,s):

    $$ \min\limits_{j,s}\left[\min\limits_{c_1}\sum\limits_{x_i\in R_1(j,s)}(y_i-c_1)^2+\min\limits_{c_2}\sum\limits_{x_i\in R_2(j,s)}(y_i-c_2)^2\right] $$

  2. 用选定的$(j,s)$, 划分区域并决定相应的输出值 :

    $$ R_1(j,s)={x|x^{(j)}\leq s}, R_2(j,s)={x|x^{(j)}> s} \ \hat{c}m= \frac{1}{N}\sum\limits{x_i\in R_m(j,s)} y_j, x\in R_m, m=1,2 $$

  3. 对两个子区域调用(1)(2)步骤, 直至满足停止条件

  4. 将输入空间划分为M个区域$R_1, R_2,\dots,R_M$,生成决策树:

    $$ f(x)=\sum_{m=1}^M\hat{c}_mI(x\in R_m) $$

3.3.树的剪枝

决策树损失函数:

树T的叶结点个数为|T|,t是树T的叶结点,该结点有$N_t$个样本点,其中k类的样本点有$N_{tk}$个,$H_t(T)$为叶结点$t$上的经验熵,$\alpha\geqslant 0$为参数,决策树学习的损失函数可以定义为:

$$ C_\alpha(T)=\sum_{i=1}^{|T|}N_tH_t(T)+\alpha|T| $$

$$ H_t(T)=-\sum_k\color{red}\frac{N_{tk}}{N_t}\color{black}\log \frac{N_{tk}}{N_t} $$

$$ C(T)=\sum_{t=1}^{|T|}\color{red}N_tH_t(T)\color{black}=-\sum_{t=1}^{|T|}\sum_{k=1}^K\color{red}N_{tk}\color{black}\log\frac{N_{tk}}{N_t} $$

这时有:

$$ C_\alpha(T)=C(T)+\alpha|T| $$

其中C(T)表示模型对训练数据的误差,|T|表示模型复杂度,参数$\alpha \geqslant 0$控制两者之间的影响。

  • 熵与概率的关系:

输入:生成算法生成的整个树T,参数$\alpha$

输出:修剪后的子树$T_\alpha$

算法流程:

  1. 计算每个结点的经验熵
  2. 递归的从树的叶结点向上回缩 假设一组叶结点回缩到其父结点之前与之后的整体树分别是$T_B$$T_A$,其对应的损失函数分别是$C_\alpha(T_A)$$C_\alpha(T_B)$,如果$C_\alpha(T_A)\leqslant C_\alpha(T_B)$则进行剪枝,即将父结点变为新的叶结点
  3. 返回2,直至不能继续为止,得到损失函数最小的子树$T_\alpha$

4.决策树在sklearn中的类


4.1.分类

  1. DecisionTreeClassifier能够在数据集上进行二分类也可以进行多分类,同其他训练方法一样,fit(x,y)既是训练模型,predict(x_test)既是预测的过程返回分类预测的结果;

4.2.回归

  1.  DecisionTreeRegressor:用来解决回归问题。在分类设置中,拟合方法将数组X和数组y作为参数,只有在这种情况下,y数组预期才是浮点值同其他训练方法一样,fit(x,y)既是训练回归模型,predict(x_test)既是预测的过程返回回归预测的结果。

5.书本案例sklearn实现

统计学习方法学习笔记4——决策树模型_第2张图片

统计学习方法学习笔记4——决策树模型_第3张图片

from sklearn.tree import DecisionTreeClassifier
from sklearn.externals.six import StringIO
import pydot
from sklearn import tree
import numpy as np
import matplotlib.pyplot as plt

def dataset():
    train_data = np.array([[1, 0, 0, 0],
                           [1, 0, 0, 1],
                           [1, 1, 0, 1],
                           [1, 1, 1, 0],
                           [1, 0, 0, 0],
                           [2, 0, 0, 0],
                           [2, 0, 0, 1],
                           [2, 1, 1, 1],
                           [2, 0, 1, 2],
                           [2, 0, 1, 2],
                           [3, 0, 1, 2],
                           [3, 0, 1, 1],
                           [3, 1, 0, 1],
                           [3, 1, 0, 2],
                           [3, 0, 0, 0]])
    train_label = np.array([0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0])
    return train_data, train_label.reshape((-1, 1))


def main(train_data, label_data, test_data):
    dec_model = DecisionTreeClassifier()
    dec_model.fit(train_data, label_data)
    pred = dec_model.predict(test_data)
    
    dot_data = StringIO()
    tree.export_graphviz(dec_model, out_file=dot_data) 
    graph = pydot.graph_from_dot_data(dot_data.getvalue())
    graph[0].write_dot('DecisionTree.dot')
    graph[0].write_png('DecisionTree.png')
    print(pred)
    return pred

# A={青年=1,否=0,是=1,一般=0}
# 输出:1=是,0=否
train_data, train_label = dataset()
test_data = np.array([[1, 0, 1, 0], [2, 1, 0, 1], [3, 0, 1, 0]])
print(train_data.shape, train_label.shape, test_data.shape)
pred = main(train_data, train_label, test_data)

参考:

1.https://github.com/SmirkCao/Lihang/blob/master/CH05/README.md

2.《统计学习方法》 李航

3.http://sklearn.apachecn.org/#/

你可能感兴趣的:(机器学习,机器学习算法,机器学习,统计学习方法)