决策树模型学习笔记(案例分析、推算过程、python代码)

文章目录

  • 1 什么是决策树
  • 2 基尼指数
  • 3 实例分析
    • 3.1 手工计算
    • 3.2 代码实现

1 什么是决策树

决策树(decision tree)是一类常见的机器学习方法.以二分类任务为例,我们希望从给定训练数据集学得一个模型用以对新示例进行分类,这个把样本分类的任务,可看作对 “当前样本属于正类吗?” 这个问题的“决策”或“判定〞过程.顾名思义,决策树是基于树结构来进行决策的。
构建决策树常用的方法有:信息增益法基尼指数法(CART决策树),此次我主要学习了通过计算数据集的基尼指数来构建决策树。

2 基尼指数

CART决策树使用基尼指数来选择划分属性,采用如下计算公式,数据集D的纯度越高则基尼指数越小:
决策树模型学习笔记(案例分析、推算过程、python代码)_第1张图片
直观地来说Gini(D)反映了数据集D中随机抽取的两个样本,其类别标记不一致的概率。因此,Gini(D)越小,则数据集D的纯度越高。则被选中的该属性的加权基尼指数定义为:
决策树模型学习笔记(案例分析、推算过程、python代码)_第2张图片

3 实例分析

以下面这个数据集来简单的分析和进一步的了解决策树,如下图:
决策树模型学习笔记(案例分析、推算过程、python代码)_第3张图片
使用决策树来判断最后一条数据是属于什么类型,冰川水 还是 湖泊水?
因为数据量比较小,此处我先用手算来构建决策树,再进行代码演示

3.1 手工计算

决策树模型学习笔记(案例分析、推算过程、python代码)_第4张图片
决策树模型学习笔记(案例分析、推算过程、python代码)_第5张图片决策树模型学习笔记(案例分析、推算过程、python代码)_第6张图片
答案湖泊水

3.2 代码实现

import operator
from math import pow
def cal_gini_index(data):
    total_sample=len(data)
    if total_sample==0:
        return 0
    label_counts=label_unique_cnt(data)
    gini=0
    for label in label_counts:
        gini=gini+pow(label_counts[label],2)
    gini=1-float(gini)/pow(total_sample,2)
    return gini

def label_unique_cnt(data):
    label_unique_cnt={}
    for x in data:
        label=x[len(x)-1]
        if label not in label_unique_cnt:
            label_unique_cnt[label]=0
        label_unique_cnt[label]+=1
    return label_unique_cnt


def createDataSet1():    # 创造示例数据
    dataSet =[[0, 1, 1, 1, '冰川水'],
               [1, 0, 1, 1, '冰川水'],
               [0, 1, 0, 0, '冰川水'],
               [1, 1, 0, 0, '冰川水'],
               [0, 0, 0, 0, '湖泊水'],
               [1, 0, 0, 0, '湖泊水'],
               [0, 1, 1, 0, '湖泊水'],
               [1, 0, 1, 0, '湖泊水'], ]
        #特征
    labels=['Ca浓度', 'Mg浓度', 'Na浓度', 'Cl浓度']
    return dataSet,labels

def getBestFeature(data):
    label_num=len(data[0])-1
    bestGini=0
    currentgini=cal_gini_index(data)
    for index in range(0,label_num):
        newgini=0
        sample_label_num=[example[index] for example in data]
        valueset=set(sample_label_num)
        for value in valueset:
            subdata=split_tree(data,index,value)
            newgini=newgini+len(subdata)*cal_gini_index(subdata)/len(data)
        gaingini=currentgini-newgini
        if gaingini>bestGini:
            bestGini=gaingini
            bestFeatureIndex=index
    return bestFeatureIndex

def majorcnt(classlist):
    classValue=set(classlist)
    bestValueNum=0
    bestClassValue=None
    for value in classValue:
        valueNum=classlist.count(value)
        if valueNum>=bestValueNum:
            bestValueNum=valueNum
            bestClassValue=value
    #print(bestClassValue)
    return bestClassValue


def build_tree(data,label):
    datalabel=label[:]
    classlist=[example[-1] for example in data]
    #print(classlist)
    if classlist.count(classlist[0])==len(classlist):
        return classlist[0]
    if len(data[0])==1:
        return majorcnt(classlist)
    bestFeature=getBestFeature(data)
    #print(bestFeature)
    bestFeatureLabel=datalabel[bestFeature]
    mytree={bestFeatureLabel:{}}
    del(datalabel[bestFeature])
    bestFeatureValue=[sample[bestFeature] for sample in data]
    bestFeatureValueSet=set(bestFeatureValue)
    for value in bestFeatureValueSet:
        sublabel=datalabel[:]
        mytree[bestFeatureLabel][value]=build_tree(split_tree(data,bestFeature,value),sublabel)
    return mytree


def split_tree(data,axis,value):
    subdata=[]
    for sample in data:
        if sample[axis]==value:
            subdata1=sample[:axis]
            subdata1.extend(sample[axis+1:])
            subdata.append(subdata1)
    return subdata

def makeClassDecision(sample,tree,labels):
    keyoftree="".join(tree.keys())
    print(str(keyoftree))
    indexoffea=labels.index(keyoftree)
    tree=tree[keyoftree]
    #print(tree)
    if tree[sample[indexoffea]]=='冰川水'or tree[sample[indexoffea]]=='湖泊水':
        return tree[sample[indexoffea]]
    else:
        tree=tree[sample[indexoffea]]
        return makeClassDecision(sample,tree,labels)



if __name__=='__main__':
    dataSet, labels=createDataSet1()  # 创造示列数据
    #print(labels)
    mytree=build_tree(dataSet,labels)
    #print(labels)
    print(mytree)  # 输出决策树模型结果
    sample= [0,1,1,0]
    #print(labels[3])
    t=makeClassDecision(sample,mytree,labels)
    print(t)

运行结果:

决策树从上到下节点为:
{'Cl浓度': {0: {'Mg浓度': {0: '湖泊水', 1: {'Na浓度': {0: '冰川水', 1: '湖泊水'}}}}, 1: '冰川水'}}
Cl浓度
Mg浓度
Na浓度
测试数据为:
湖泊水

你可能感兴趣的:(决策树,python,学习,数据挖掘)