机器学习十大算法之决策树---ID3算法

对于决策树来说,主要有两种算法:ID3算法C4.5算法。C4.5算法是对ID3的改进。

 

Contents

 

     1. 决策树的基本认识

     2. ID3算法介绍

     3. 信息熵与信息增益

     4. ID3算法的Python实现

 

 

1. 决策树的基本认识

 

   决策树是一种依托决策而建立起来的一种树。在机器学习中,决策树是一种预测模型,代表的是一种对象属性与对象值之间的一种映射关系,每一个节点代表某个对象,树中的每一个分叉路径代表某个可能的属性值,而每一个叶子节点则对应从根节点到该叶子节点所经历的路径所表示的对象的值。决策树仅有单一输出,如果有多个输出,可以分别建立独立的决策树以处理不同的输出。

 

 

2. ID3算法介绍

 

   ID3算法是决策树的一种,它是基于奥卡姆剃刀原理的,即用尽量用较少的东西做更多的事。ID3算法Iterative Dichotomiser 3迭代二叉树3代,是Ross Quinlan发明的一种决策树算法,这个算法的基础就是上面提到的奥卡姆剃刀原理,越是小型的决策树越优于大的决策树,尽管如此,也不总是生成最小的树型结构,而是一个启发式算法

 

   在信息论中,期望信息越小,那么信息增益就越大,从而纯度就越高。ID3算法的核心思想就是以信息增益来度量属性的选择,选择分裂后信息增益最大的属性进行分裂。该算法采用自顶向下的贪婪搜索遍历可能的决策空间。

  • ID3算法要求特征必须离散化;
  • 信息增益可以用熵,而不是GINI系数来计算;
  • 选取信息增益最大特征,作为树的根节点;

 

3. 信息熵与信息增益

 

   在信息增益中,重要性的衡量标准就是看特征能够为分类系统带来多少信息,带来的信息越多,该特征越重要。

 

   是对不确定性的度量。香农引入了信息熵,将其定义为离散随机事件出现的概率,一个系统越是有序,信息熵就越低。所以信息熵可以被认为是系统有序化程度的一个度量。

   

 

   意思是一个变量的变化情况可能越多,那么它携带的信息量就越大。

       

   信息增益

 

   信息增益是针对一个一个特征而言的,就是看一个特征,系统有它和没有它时的信息量各是多少,两者的差值就是这个特征给系统带来的信息量,即信息增益

 

   

 

   其中为全部样本集合,是属性所有取值的集合,的其中一个属性值,中属性值为的样例集合,中所含样例数。

 

   在决策树的每一个非叶子结点划分之前,先计算每一个属性所带来的信息增益,选择最大信息增益的属性来划分,因为信息增益越大,区分样本的能力就越强,越具有代表性,很显然这是一种自顶向下的贪心策略。以上就是ID3算法的核心思想。

 

 

4. ID3算法的Python实现

 

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/6/21 8:33
# @Author  : Julia
# @Site    : 
# @File    : ID3.py
# @Software: PyCharm

import math
import operator


def calcShannonEnt(dataset):
    numEntries = len(dataset)
    labelCounts = {}
    for featVec in dataset:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1

    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * math.log(prob, 2)
    return shannonEnt


def CreateDataSet():
    '''dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['outlook', 'temperature','humidity','false']
    return dataset, labels'''


    lines_set = open('Dataset.txt').readlines()
    labelLine = lines_set[2];
    labels = labelLine.strip().split()
    lines_set = lines_set[4:11]
    dataSet = [];
    for line in lines_set:
        data = line.split();
        dataSet.append(data);
    return dataSet, labels

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)

    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numberFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0;
    bestFeature = -1;
    for i in range(numberFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] = 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


myDat, labels = CreateDataSet()
myTree = createTree(myDat, labels)
print myTree

运行结果如下:

{'outlook': {'overcast': 'Y', 'sunny': 'N', 'rain': {'windy': {'false': 'Y', 'true': 'N'}}}}

训练集和测试集

训练集:

    outlook    temperature    humidity    windy 
    ---------------------------------------------------------
    sunny     hot             high           false          N
    sunny     hot             high           true          N
    overcast  hot             high           false         Y
    rain       mild           high           false          Y
    rain        cool           normal       false          Y
    rain        cool           normal       true           N
   overcast  cool           normal       true          Y

测试集
 outlook    temperature    humidity    windy 
    ---------------------------------------------------------      
    sunny       mild           high           false          
    sunny       cool           normal       false         
    rain           mild           normal       false        
    sunny        mild           normal       true          
    overcast    mild            high           true          
    overcast    hot             normal      false         
    rain           mild           high           true 

你可能感兴趣的:(机器学习)