【常见决策树算法逻辑理解以及代码实现(5)】CART (代码实现,包含绘图,西瓜书示例)

代码采用面向对象的方式编写,主要的类是 Cart:构造时直接传入数据集和属性集合,然后调用 draw() 即可生成并绘制决策树。

运行结果如下(每个属性的取值保存在 set 中,而 Python 字符串的哈希值在每次运行时会随机化,所以同一节点下各分支的排列顺序可能不同;这不影响结果,生成的决策树是一样的)

全部代码可从以下项目下载:https://gitee.com/TomCoCo/mLearn.git

【常见决策树算法逻辑理解以及代码实现(5)】CART (代码实现,包含绘图,西瓜书示例)_第1张图片

下面是完整代码,带有详细注释,可以直接运行并得到上图的结果。

核心方法 createTree

import math
import matplotlib.pyplot as plt
import copy

# Sample data set (the "watermelon" example this article works through).
# One row per sample; the columns line up with the names in ``A`` below and
# the last column is the class label ('是' / '否').
D = [
    ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
    ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
    ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
    ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
    ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
    ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
    ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],
    ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],
    ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
    ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],
    ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'],
    ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'],
    ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],
    ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],
    ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],
    ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
    ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
]

# Column names for ``D``. The first six are the features; the final entry
# ('好瓜') names the class-label column and is never used as a split.
A = [
    '色泽',
    '根蒂',
    '敲声',
    '纹理',
    '脐部',
    '触感',
    '好瓜',
]

class Cart:
    """Decision-tree builder that picks each split by minimum Gini index.

    Every decision node splits on one attribute and receives one child per
    distinct value that attribute takes in the ORIGINAL data set. Rows are
    plain lists whose last element is the class label.

    Typical use: ``Cart(data, attributes).draw()``.
    """

    # Training samples; the last column of every row is the class label.
    data = None
    # Column names aligned with the data columns (the last names the label).
    attributes = None
    # [{attribute name: column index}, ...] for the feature columns only.
    attributesAndIndex = None
    # {column index: set of values seen in that column}, feature columns only.
    attributesIndexAndValue = None
    # Root Node of the most recently built tree.
    root = None

    def __init__(self, data, attributes):
        """Store the data set and pre-compute the attribute lookup tables.

        Args:
            data: list of rows; each row's last element is the class label.
            attributes: column names, one per data column (label name last).
        """
        self.data = data
        self.attributes = attributes
        self.attributesAndIndex = Cart.getAttributesAndIndex(attributes)
        self.attributesIndexAndValue = Cart.getAttributesAndValue(data, attributes)

    def draw(self):
        """Build the tree from scratch and plot it with ``Tree``."""
        # Always start from ``None``: passing the previous root would make a
        # second draw() attach the new tree underneath the old one.
        self.createTree(None, self.data, self.attributesAndIndex, None)
        tree = Tree(self.root)
        tree.drawTree()

    def createTree(self, node, data, attributesAndIndex, desc):
        """Recursively grow the subtree for ``data`` and hang it under ``node``.

        Args:
            node: parent Node, or ``None`` to create the tree root.
            data: the samples that reached this point of the tree.
            attributesAndIndex: attributes still available for splitting, as
                ``[{name: column index}, ...]``. This is NOT the instance
                attribute of the same name; it shrinks by one entry per level.
            desc: label for the edge from the parent (the attribute value that
                led here), or ``None`` for the root.
        """
        newNode = Node()
        if desc is not None:
            newNode.desc = desc
        if node is None:
            self.root = newNode
        else:
            node.addChild(newNode)

        # All samples share one class: this is a leaf for that class.
        kMap = Cart.getKMap(data)
        if len(kMap) == 1:
            newNode.name = next(iter(kMap))
            return

        # Nothing left to split on, or the samples are indistinguishable on
        # the remaining attributes: leaf labelled with the majority class.
        if Cart.checkDA(data, attributesAndIndex):
            newNode.name = Cart.getMoreType(data)
            return

        bestIndex = Cart.getMinGiniIndexStrict(data, attributesAndIndex)
        newNode.name = self.attributes[bestIndex]

        # Attributes available to the children: everything except the one
        # just used. The entries are never mutated, so a shallow copy is
        # enough and one list can be shared by all children.
        remaining = list(attributesAndIndex)
        for position, item in enumerate(remaining):
            if next(iter(item.values())) == bestIndex:
                remaining.pop(position)
                break

        # Partition the current samples by the chosen attribute, then visit
        # every value that attribute takes in the original data set.
        V = Cart.splitDataByIndex(data, bestIndex)
        for value in self.attributesIndexAndValue[bestIndex]:
            dv = V.get(value)
            if not dv:
                # No sample here has this value: emit a leaf carrying the
                # majority class of the parent's samples.
                newLeaf = Node()
                newLeaf.name = Cart.getMoreType(data)
                newLeaf.desc = value
                newNode.addChild(newLeaf)
            else:
                self.createTree(newNode, dv, remaining, value)

    @staticmethod
    def checkDA(data, attributesAndIndex):
        """Return True when ``data`` cannot be split on the given attributes.

        That is the case when the attribute list is empty, or when every row
        has the same value in each listed column (the class labels may still
        differ).

        Example: with attributes ``[{'根蒂': 1}, {'脐部': 4}]`` and rows that
        are all '稍蜷' in column 1 and all '稍凹' in column 4, the result is
        True even if the rows belong to different classes.
        """
        for item in attributesAndIndex:
            aIndex = next(iter(item.values()))
            reference = None
            for dLine in data:
                if reference is None:
                    reference = dLine[aIndex]
                elif reference != dLine[aIndex]:
                    return False
        return True

    @staticmethod
    def getAttributesAndIndex(attributes):
        """Pair every feature name with its column index, dropping the label.

        ``['色泽', '根蒂', '好瓜']`` -> ``[{'色泽': 0}, {'根蒂': 1}]``
        """
        pairs = [{attribute: index} for index, attribute in enumerate(attributes)]
        return pairs[:-1]

    @staticmethod
    def getAttributesAndValue(data, attributes):
        """Collect, per feature column index, the set of values present in ``data``.

        The last column (the class label) is skipped.
        """
        attributesAndValue = dict()
        featureCount = len(attributes) - 1
        for dLine in data:
            for i in range(featureCount):
                attributesAndValue.setdefault(i, set()).add(dLine[i])
        return attributesAndValue

    @staticmethod
    def getMinGiniIndexStrict(data, attributesAndIndex):
        """Return the column index of the listed attribute with the lowest Gini index.

        Only the columns named in ``attributesAndIndex`` are considered; the
        list is used exactly as given (its last entry is NOT skipped). Ties go
        to the attribute listed first. Progress is reported with ``print``.

        Returns:
            The best column index, or ``None`` if ``attributesAndIndex`` is
            empty.
        """
        minGiniIndex = None
        minIndex = None
        minName = None
        for item in attributesAndIndex:
            aName = next(iter(item.keys()))
            aIndex = next(iter(item.values()))
            giniIndex = Cart.getGiniIndex(data, aIndex)
            if minGiniIndex is None or giniIndex < minGiniIndex:
                minGiniIndex = giniIndex
                minIndex = aIndex
                # Remember the winner's name; the loop variable ends up
                # holding the LAST attribute, not the best one.
                minName = aName
            print("第" , aIndex ,"列的属性",aName,"的基尼指数为:" , giniIndex)
        if minIndex is None:
            return None
        print("第" , minIndex ,"列的属性",minName,"的基尼指数最小为:" , minGiniIndex ,";为最优划分属性")
        return minIndex

    @staticmethod
    def getMinGiniIndex(data, attributes):
        """Return the feature column index with the lowest Gini index.

        ``attributes`` must have one name per data column; its last entry is
        the class label and is not evaluated. Ties go to the lowest index.

        Returns:
            The best column index, or ``None`` if there is no feature column.
        """
        minGiniIndex = None
        minIndex = None
        for i in range(len(attributes) - 1):
            giniIndex = Cart.getGiniIndex(data, i)
            if minGiniIndex is None or giniIndex < minGiniIndex:
                minGiniIndex = giniIndex
                minIndex = i
            print("第" , i ,"列的属性",attributes[i],"的基尼指数为:" , giniIndex)
        if minIndex is None:
            return None
        print("第" , minIndex ,"列的属性",attributes[minIndex],"的基尼指数最小为:" , minGiniIndex ,";为最优划分属性")
        return minIndex

    @staticmethod
    def getGiniIndex(data, attributesIndex):
        """Return the Gini index of splitting ``data`` on column ``attributesIndex``.

        This is the size-weighted sum of the Gini value of each subset
        produced by the split. The last column of each row is the class.
        """
        V = Cart.splitDataByIndex(data, attributesIndex)
        dSize = len(data)
        rs = 0
        for Dv in V.values():
            dvSize = len(Dv)
            dvGini = Cart.getGini(Cart.getKMap(Dv), dvSize)
            rs += (dvSize / dSize) * dvGini
        return rs

    @staticmethod
    def splitDataByIndex(data, attributesIndex):
        """Group the rows of ``data`` by their value in column ``attributesIndex``.

        Returns:
            ``{value: [rows having that value]}``; rows keep their input order.
        """
        V = dict()
        for dLine in data:
            V.setdefault(dLine[attributesIndex], list()).append(dLine)
        return V

    @staticmethod
    def getGini(kMap, dSize):
        """Return the Gini value ``1 - sum(p_k ** 2)`` of a class distribution.

        Args:
            kMap: ``{class: count}``.
            dSize: total number of samples the counts were taken from.
        """
        rs = 0
        for count in kMap.values():
            pk = count / dSize
            rs += pk * pk
        return 1 - rs

    @staticmethod
    def getMoreType(data):
        """Return the most frequent class in ``data``.

        Ties go to the class encountered first; empty data gives ``None``.
        """
        kMap = Cart.getKMap(data)
        # max() returns the first maximal key in insertion order, which is the
        # same tie-break as a strict ">" scan.
        return max(kMap, key=kMap.get, default=None)

    @staticmethod
    def getKMap(data):
        """Count the samples of each class: ``{class: count}``.

        The class is the last element of each row; keys appear in the order
        the classes are first seen.
        """
        kMap = dict()
        for dLine in data:
            k = dLine[-1]
            kMap[k] = kMap.get(k, 0) + 1
        return kMap

############################### 节点类 #####################################
class Node:
    """One node of the decision tree.

    ``name`` is the text shown in the node (an attribute name for a decision
    node, a class label for a leaf). ``desc`` is the text for the edge coming
    from the parent and stays empty on the root. A node with no children is a
    leaf.
    """

    # Class-level defaults; every instance gets its own ``children`` list in
    # __init__, so the list below is never shared between nodes.
    children = []
    desc = ""
    name = "未命名节点"

    def __init__(self):
        # Fresh list per node so siblings never see each other's children.
        self.children = list()

    def addChild(self, node):
        """Attach ``node`` as the last child of this node."""
        kids = self.children
        kids.append(node)

############################### 画树类 #####################################
class Tree:
    """Plots a tree of ``Node`` objects with matplotlib annotations.

    Layout is column-based: within each depth level, nodes are placed left to
    right in the order they are visited, ``step`` units apart, and each level
    sits ``step`` units below its parent.
    """

    root = None
    # Text-box styles for decision and leaf nodes. ``boxstyle`` selects the
    # box shape; ``fc`` is the face (fill) colour, given here as a grey level.
    decisionNode = dict(boxstyle="round4", fc="0.5")
    leafNode = dict(boxstyle="circle", fc="0.5")
    # Arrow style for the edges.
    arrow_args = dict(arrowstyle="<-")
    # Horizontal and vertical spacing between neighbouring nodes.
    step = 3

    # Depth currently being drawn.
    deep = 0
    # Unused by the drawing code; kept so existing attribute access still works.
    nowDeepIndex = 0
    # {depth: number of nodes already placed at that depth}. Created per
    # instance (see __init__ / drawTree) so trees never share layout state.
    deepIndex = None

    def __init__(self, root):
        """Remember the root node and prepare the matplotlib canvas.

        Args:
            root: the root ``Node`` of the tree to draw.
        """
        self.root = root
        self.deep = 0
        self.deepIndex = dict()

        # Fixed coordinate range for the plot.
        plt.xlim(0, 20)
        plt.ylim(-18, 0)
        # Font settings needed to render the Chinese labels.
        plt.rcParams["font.sans-serif"] = ["SimHei"]
        plt.rcParams["axes.unicode_minus"] = False

    def _drawNode(self, x1, y1, x2, y2, text, desc, boxStyle):
        """Draw one boxed node at (x2, y2), its arrow to (x1, y1) and the edge text."""
        plt.annotate(text,
                     xy=(x1, y1),
                     xytext=(x2, y2),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=boxStyle,
                     arrowprops=self.arrow_args)
        # Edge label goes at the midpoint between parent and child.
        plt.text((x1 + x2) / 2, (y1 + y2) / 2, desc)

    def drawLeaf(self, x1, y1, x2, y2, text, desc):
        """Draw a leaf node.

        Args:
            x1, y1: coordinates of the parent (the arrow's other end).
            x2, y2: coordinates of this node's text box.
            text: text shown inside the node.
            desc: text shown on the connecting edge.
        """
        self._drawNode(x1, y1, x2, y2, text, desc, self.leafNode)

    def drawDecision(self, x1, y1, x2, y2, text, desc):
        """Draw a decision node; parameters are the same as for ``drawLeaf``."""
        self._drawNode(x1, y1, x2, y2, text, desc, self.decisionNode)

    def drawRoot(self, text):
        """Draw the root node at the origin, without an incoming arrow."""
        plt.annotate(text,
                     xy=(0, 0),
                     va='center',
                     ha='center',
                     xycoords="data",
                     textcoords='data',
                     bbox=self.decisionNode)

    def drawTree(self):
        """Draw the whole tree and show the figure."""
        # Start from clean layout counters so repeated calls do not shift
        # the nodes further and further to the right.
        self.deep = 0
        self.deepIndex = dict()
        self.draw0(self.root, 0, 0)
        plt.show()

    def draw0(self, node, x, y):
        """Recursively draw ``node`` and its subtree.

        Args:
            node: the ``Node`` to draw.
            x, y: coordinates of the parent node.
        """
        # The x position depends only on how many nodes were already placed
        # at this depth, so every level is laid out starting from x = 0.
        self.deepIndex.setdefault(self.deep, 0)
        x2 = self.deepIndex[self.deep] * self.step
        y2 = y - self.step

        if not node.children:
            self.drawLeaf(x, y, x2, y2, node.name, node.desc)
        elif node.desc:
            self.drawDecision(x, y, x2, y2, node.name, node.desc)
            self.deep += 1
            for child in node.children:
                self.draw0(child, x2, y2)
            self.deep -= 1
        else:
            # An inner node with no edge text is treated as the root. Its
            # children are drawn at the same depth counter, relative to the
            # origin.
            self.drawRoot(node.name)
            for child in node.children:
                self.draw0(child, 0, 0)

        # One more node has been placed at this depth.
        self.deepIndex[self.deep] += 1


# Program entry point: build the decision tree from the sample data and plot it.
cart = Cart(data=D, attributes=A)
cart.draw()


###########测试checkDA方法##########
# D2=[
# ['青绿','稍蜷','浊响','清晰','稍凹','软粘','是'],
# ['乌黑','稍蜷','浊响','稍糊','稍凹','软粘','是'],
# ['乌黑','稍蜷','浊响','清晰','稍凹','硬滑','是'],
# ['乌黑','稍蜷','沉闷','稍糊','稍凹','硬滑','否']]
# AA2 = [{'根蒂':1},{'脐部':4}]
# a = Cart.checkDA(D2,AA2)
# a2 = Cart.checkDA(D,[])
# a3 = Cart.checkDA(D,Cart.getAttributesAndIndex(A))
# print(a,a2,a3)


你可能感兴趣的:(算法,决策树,机器学习)