决策树

2019.5.11
基本流程

决策树(decision tree)是一类常见的机器学习方法,是基于树的结构来进行决策的,在一系列给定条件下,对各个属性进行测试,分流到不同的分支上,最后分流到决策树的叶节点上得出最终结果,其基本流程遵循分而治之策略。

如下图就是一个西瓜问题的决策树
西瓜问题的一颗决策树
决策树学习本质上是从训练集中归纳出一组分类规则,与训练集数据不相矛盾的决策树(即能对训练集正确分类的决策树)可能有多个,也可能一个也没有。我们需要的是一个与训练数据矛盾较小的决策树,同时具有很好的泛化能力。

决策树创建分支的伪代码函数createBranch()如下所示:

检测数据集中的每个子项是否属于同一类:
  If so return 类标签
  Else 
    寻找划分数据集的最好特征
    划分数据集
    创建分支结点
      for 每个划分的子集
        调用函数createBranch并增加返回结果到分支结点中
    return 分支结点

常见的决策树算法有很多,最著名的代表是ID3[Quinkan, 1979, 1986]、C4.5[Quinlan, 1993]和CART[Breiman et al., 1984]。

ID3

由决策树创建分支的流程可以看出,决策树学习的关键是如何选择划分数据集的最好特征。一般而言,我们希望决策树的分支结点所包含的样本尽可能属于同一类别,即结点的纯度越来越高。ID3算法的核心是在决策树的各个节点上应用信息增益准则选择划分数据集的最好特征,递归地构建决策树,ID3相当于用极大似然法进行概率模型的选择。
在信息论或数字图像处理等课程中应该都有学习过信息熵的概念,假定当前样本集合中第类样本所占的比例为,则D的信息熵定义为
的值越小,则的纯度越高。计算信息熵的代码如下

from math import log
def ent(data):
    num = len(data)
    label = {}
    for e in data:
        elabel = e[-1] #最后一项是标签
        if elabel not in label.keys():
            label[elabel] = 0
        label[elabel] += 1
    dataent= 0.0
    for key in label:
        prop = float(label[key]) / num
        dataent -= prop * log(prob, 2)
    return dataent

为了测试以上代码,声明一个创建数据集的函数以便后续使用

def createDataSet():
    dataSet = [[1, 1, 'yes'],
              [1, 1, 'yes'],
              [1, 0, 'no'],
              [0, 1, 'no'],
              [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

测试以上代码

mydata,labels = createDataSet()
ent(mydata)

输出0.9709505944546686
修改mydata的类别,增加一个分类即增加了其他的信息

mydata[0][-1]='maybe'
ent(mydata)

输出1.3709505944546687
假定离散属性有个可能的取值{},若使用对进行划分,则会产生个分支结点,其中第个分支结点包含了中所有在属性上取值为的样本,记为,考虑到不同分支结点所含样本数不同,赋予权重,于是可计算出用属性对划分所获得的信息增益(在划分数据集前后信息发生的变化)一般来说,信息增益越大,则意味着使用属性来进行划分所获得的纯度提升越大,因此,我们可用信息增益来选择决策树的划分属性。即选择来作为划分数据集的最好特征。选择出该特征的代码如下

def split(data):
    num = len(data[0]) - 1 #减去标签列,剩余的特征数量
    dataent = ent(data) #该数据集的信息熵
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(num):
        featList = [e[i] for e in data] #某一特征列
        uniqueVals = set(featList) #获取唯一值,从列表中创建集合是python中得到列表唯一元素值的最快方法
        newent = 0.0
        for e in uniqueVals:
            subData = splitfeat(data, i, e) #第i个特征为e的子集
            w = len(subData) / float(len(data)) #分支结点的权重
            newent += w * ent(subData)
        infoGain = dataent - newent
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
def splitfeat(data, feat, value):
    newdata = []
    for e in data:
        if e[feat] == value:
            reducede = e[:feat]
            reducede.extend(e[feat + 1:])
            newdata.append(reducede)
    return newdata

在实际中,有可能所有特征都相同的实例标签并不相同,那么这时要如何确定标签呢?
我们可以采取投票的方式,将多数实例对应的标签作为最终结果,代码如下

import operator
def majorityclass(classList):
    classCount = {}
    for e in classList:
        if e not in classCount.keys():
            classCount[e] = 0
        classCount[e] += 1
    sortedClassCount = sorted(classCount.iteriterms(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

做好了准备工作,下面就可以开始创建树了

def createTree(data, labels):
    classList = [e[-1] for e in data]
    if classList.count(classList[0]) == len(classList):
        return classList[0] #类别相同时停止继续划分
    if len(data[0]) == 1:
        return majorityclass(data) #如果已经没有特征了,选择多数实例对应的标签
    bestFeat = featureChoose(data) #选取最佳的划分特征
    bestFeatLabel = labels[bestFeat] #最佳特征
    decisionTree = {bestFeatLabel:{}} 
    del(labels[bestFeat]) #从待选特征列表里删除最佳特征
    featValues = [e[bestFeat] for e in data] #最佳特征列
    uniqueValue = set(featValues) #最佳特征可能的取值
    for value in uniqueValue:
        subLabels = labels[:] #python中函数参数是列表类型时,参数是按照引用方式传递的,为了保证每次调用creatTree函数时不改变原始列表内容,使用新变量代替原始列表
        decisionTree[bestFeatLabel][value] = createTree(splitfeat(data, bestFeat, value), subLabels) #删除最佳特征列和新的可选特征列表
    return decisionTree

还是用之前的数据集,结果如下

data, labels = createDataSet()
data[0][-1] = 'maybe'
decisionTree = createTree(data, labels)
decisionTree

{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'maybe'}}}}
可以使用Matplotlib注解绘制树形图,因为这不是我们关注的重点,不做过多介绍。

除了ID3算法,还有引入信息增益比的C4.5算法,以及引入剪枝的CART算法,后续有时间再进行介绍~
Note
  1. 对于TypeError: unhashable type: 'list'是由于list的内容是可变的,不可hash,因此不能作为dict的keys,即使list中只有单一变量也不行。
  2. Python 字典(Dictionary) items() 函数以列表返回可遍历的(键, 值) 元组数组,用法为dict.items()。
  3. sorted函数:

sort 与 sorted 区别:
sort 是应用在 list 上的方法,sorted 可以对所有可迭代的对象进行排序操作。
list 的 sort 方法返回的是对已经存在的列表进行操作,无返回值,而内建函数 sorted 方法返回的是一个新的 list,而不是在原来的基础上进行的操作。

用法:sorted(iterable, cmp=None, key=None, reverse=False)
iterable -- 可迭代对象。
cmp -- 比较的函数,这个具有两个参数,参数的值都是从可迭代对象中取出,此函数必须遵守的规则为,大于则返回1,小于则返回-1,等于则返回0。
key -- 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。
reverse -- 排序规则,reverse = True 降序 , reverse = False 升序(默认)。

reference

1.机器学习 周志华
2.统计学习方法 李航
3.机器学习实战 Peter Harrington
4.unhashable type:'list'什么意思
5.Python 字典(Dictionary) items()方法
6.Python sorted() 函数

你可能感兴趣的:(决策树)