机器学习之决策树算法(ID3、C4.5、CART)

一、引言

决策树学习采用的是自顶向下的递归方法,其基本思想是以信息熵为度量构造一颗熵值下降最快的树,到叶子节点处,熵值为0。其具有可读性、分类速度快的优点,是一种有监督学习。最早提及决策树思想的是Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法,以及Breiman等人在1984年提出的CART算法。本篇文章主要介绍决策树的基本概念,以及上面这3种常见决策树算法(ID3、C4.5、CART)原理及其代码实现。

二、决策树(ID3、C4.5和CART算法)

2.1、决策树是什么


下面主要讨论用与分类的决策树。决策树呈树形结构,在分类问题中,表示基于特征对实例进行分类的过程。学习时,利用训练数据,根据损失函数最小化的原则建立决策树模型;预测时,对新的数据,利用决策模型进行分类。

决策树的分类:决策树可以分为两类,主要取决于它目标变量的类型。

  • 离散性决策树:离散性决策树,其目标变量是离散的,如性别:男或女等;
  • 连续性决策树:连续性决策树,其目标变量是连续的,如工资、价格、年龄等;

        决策树相关的重要概念:

(1)根结点(Root Node):它表示整个样本集合,并且该节点可以进一步划分成两个或多个子集。

(2)拆分(Splitting):表示将一个结点拆分成多个子集的过程。

(3)决策结点(Decision Node):当一个子结点进一步被拆分成多个子节点时,这个子节点就叫做决策结点。

(4)叶子结点(Leaf/Terminal Node):无法再拆分的结点被称为叶子结点。

(5)剪枝(Pruning):移除决策树中子结点的过程就叫做剪枝,跟拆分过程相反。

(6)分支/子树(Branch/Sub-Tree):一棵决策树的一部分就叫做分支或子树。

(7)父结点和子结点(Paren and Child Node):一个结点被拆分成多个子节点,这个结点就叫做父节点;其拆分后的子结点也叫做子结点。

机器学习之决策树算法(ID3、C4.5、CART)_第1张图片

2.2、决策树的构造过程


决策树的构造过程一般分为3个部分,分别是特征选择、决策树生产和决策树裁剪。

(1)特征选择:

特征选择表示从众多的特征中选择一个特征作为当前节点分裂的标准,如何选择特征有不同的量化评估方法,从而衍生出不同的决策树,如ID3(通过信息增益选择特征)、C4.5(通过信息增益比选择特征)、CART(通过Gini指数选择特征)等。

目的(准则):使用某特征对数据集划分之后,各数据子集的纯度要比划分钱的数据集D的纯度高(也就是不确定性要比划分前数据集D的不确定性低)

(2)决策树的生成

根据选择的特征评估标准,从上至下递归地生成子节点,直到数据集不可分则停止决策树停止生长。这个过程实际上就是使用满足划分准则的特征不断的将数据集划分成纯度更高,不确定行更小的子集的过程。对于当前数据集的每一次划分,都希望根据某个特征划分之后的各个子集的纯度更高,不确定性更小。

(3)决策树的裁剪

决策树容易过拟合,一般需要剪枝来缩小树结构规模、缓解过拟合。

2.3、决策树的优缺点


决策树的优点:

(1)具有可读性,如果给定一个模型,那么过呢据所产生的决策树很容易推理出相应的逻辑表达。

(2)分类速度快,能在相对短的时间内能够对大型数据源做出可行且效果良好的结果。

决策树的缺点:

(1)对未知的测试数据未必有好的分类、泛化能力,即可能发生过拟合现象,此时可采用剪枝或随机森林。

2.4、ID3算法原理与python代码实现


ID3算法的核心是在决策树各个节点上应用信息增益准则选择特征递归地构建决策树。

2.4.1信息增益

在《最大熵模型学习》一文中,我们提到过熵和条件熵的概念,下面我们在总结一遍。

最大熵模型学习

 

(1)熵

在信息论中,熵(entropy)是随机变量不确定性的度量,也就是熵越大,则随机变量的不确定性越大。设X是一个取有限个值得离散随机变量,其概率分布为:

则随机变量X的熵定义为:

机器学习之决策树算法(ID3、C4.5、CART)_第2张图片

(2)条件熵

设有随机变量(X, Y),其联合概率分布为:

条件熵H(Y|X)表示在已知随机变量X的条件下,随机变量Y的不确定性。随机变量X给定的条件下随机变量Y的条件熵H(Y|X),定义为X给定条件下Y的条件概率分布的熵对X的数学期望:

机器学习之决策树算法(ID3、C4.5、CART)_第3张图片

当熵和条件熵中的概率由数据估计得到时(如极大似然估计),所对应的熵与条件熵分别称为经验熵和经验条件熵。

(3)信息增益

定义:信息增益表示由于得知特征A的信息后儿时的数据集D的分类不确定性减少的程度,定义为:

Gain(D,A) = H(D) – H(D|A)

             即集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(H|A)之差。

理解:选择划分后信息增益大的作为划分特征,说明使用该特征后划分得到的子集纯度越高,即不确定性越小。因此我们总是选择当前使得信息增益最大的特征来划分数据集。

缺点:信息增益偏向取值较多的特征(原因:当特征的取值较多时,根据此特征划分更容易得到纯度更高的子集,因此划分后的熵更低,即不确定性更低,因此信息增益更大)

2.4.2 ID3算法

输入:训练数据集D,特征集A,阈值ε;

输出:决策树T.

Step1:若D中所有实例属于同一类,则T为单结点树,并将类作为该节点的类标记,返回T;

Step2:若A=Ø,则T为单结点树,并将D中实例数最大的类作为该节点的类标记,返回T;

Step3:否则,2.1.1(3)计算A中个特征对D的信息增益,选择信息增益最大的特征

Step4:如果的信息增益小于阈值ε,则T为单节点树,并将D中实例数最大的类作为该节点的类标记,返回T

Step5:否则,对的每一种可能值,依将D分割为若干非空子集,将中实例数最大的类作为标记,构建子结点,由结点及其子树构成树T,返回T;

 

Step6:对第i个子节点,以为训练集,以为特征集合,递归调用Step1~step5,得到子树,返回

2.4.3 python代码实现

接下来我们通过下面这组数据作为测试样本

序号 不浮出水面是否可以生存 是否有脚蹼 是否属于鱼类
1
2
3
4
5
文件名:id3.py
  1. # -*- coding: utf-8 -*-
  2. from math import log
  3. import operator
  4. import tree_plotter
  5.  
  6.  
  7. def create_data_set():
  8. """
  9. 创建样本数据
  10. :return:
  11. """
  12. data_set = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  13. labels = ['no surfacing', 'flippers']
  14. return data_set, labels
  15.  
  16.  
  17. def calc_shannon_ent(data_set):
  18. """
  19. 计算信息熵
  20. :param data_set: 如: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  21. :return:
  22. """
  23. num = len(data_set) # n rows
  24. # 为所有的分类类目创建字典
  25. label_counts = {}
  26. for feat_vec in data_set:
  27. current_label = feat_vec[-1] # 取得最后一列数据
  28. if current_label not in label_counts.keys():
  29. label_counts[current_label] = 0
  30. label_counts[current_label] += 1
  31.  
  32. # 计算香浓熵
  33. shannon_ent = 0.0
  34. for key in label_counts:
  35. prob = float(label_counts[key]) / num
  36. shannon_ent = shannon_ent - prob * log(prob, 2)
  37. return shannon_ent
  38.  
  39.  
  40. def split_data_set(data_set, axis, value):
  41. """
  42. 返回特征值等于value的子数据集,切该数据集不包含列(特征)axis
  43. :param data_set: 待划分的数据集
  44. :param axis: 特征索引
  45. :param value: 分类值
  46. :return:
  47. """
  48. ret_data_set = []
  49. for feat_vec in data_set:
  50. if feat_vec[axis] == value:
  51. reduce_feat_vec = feat_vec[:axis]
  52. reduce_feat_vec.extend(feat_vec[axis + 1:])
  53. ret_data_set.append(reduce_feat_vec)
  54. return ret_data_set
  55.  
  56.  
  57. def choose_best_feature_to_split(data_set):
  58. """
  59. 按照最大信息增益划分数据
  60. :param data_set: 样本数据,如: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  61. :return:
  62. """
  63. num_feature = len(data_set[0]) - 1 # 特征个数,如:不浮出水面是否可以生存 和是否有脚蹼
  64. base_entropy = calc_shannon_ent(data_set) # 经验熵H(D)
  65. best_info_gain = 0
  66. best_feature_idx = -1
  67. for feature_idx in range(num_feature):
  68. feature_val_list = [number[feature_idx] for number in data_set] # 得到某个特征下所有值(某列)
  69. unique_feature_val_list = set(feature_val_list) # 获取无重复的属性特征值
  70. new_entropy = 0
  71. for feature_val in unique_feature_val_list:
  72. sub_data_set = split_data_set(data_set, feature_idx, feature_val)
  73. prob = len(sub_data_set) / float(len(data_set)) # 即p(t)
  74. new_entropy += prob * calc_shannon_ent(sub_data_set) #对各子集香农熵求和
  75. info_gain = base_entropy - new_entropy # 计算信息增益,g(D,A)=H(D)-H(D|A)
  76. # 最大信息增益
  77. if info_gain > best_info_gain:
  78. best_info_gain = info_gain
  79. best_feature_idx = feature_idx
  80.  
  81. return best_feature_idx
  82.  
  83.  
  84. def majority_cnt(class_list):
  85. """
  86. 统计每个类别出现的次数,并按大到小排序,返回出现次数最大的类别标签
  87. :param class_list: 类数组
  88. :return:
  89. """
  90. class_count = {}
  91. for vote in class_list:
  92. if vote not in class_count.keys():
  93. class_count[vote] = 0
  94. class_count[vote] += 1
  95. sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reversed=True)
  96. print sorted_class_count[0][0]
  97. return sorted_class_count[0][0]
  98.  
  99.  
  100. def create_tree(data_set, labels):
  101. """
  102. 构建决策树
  103. :param data_set: 数据集合,如: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  104. :param labels: 标签数组,如:['no surfacing', 'flippers']
  105. :return:
  106. """
  107. class_list = [sample[-1] for sample in data_set] # ['yes', 'yes', 'no', 'no', 'no']
  108. # 类别相同,停止划分
  109. if class_list.count(class_list[-1]) == len(class_list):
  110. return class_list[-1]
  111. # 长度为1,返回出现次数最多的类别
  112. if len(class_list[0]) == 1:
  113. return majority_cnt((class_list))
  114. # 按照信息增益最高选取分类特征属性
  115. best_feature_idx = choose_best_feature_to_split(data_set) # 返回分类的特征的数组索引
  116. best_feat_label = labels[best_feature_idx] # 该特征的label
  117. my_tree = {best_feat_label: {}} # 构建树的字典
  118. del (labels[best_feature_idx]) # 从labels的list中删除该label,相当于待划分的子标签集
  119. feature_values = [example[best_feature_idx] for example in data_set]
  120. unique_feature_values = set(feature_values)
  121. for feature_value in unique_feature_values:
  122. sub_labels = labels[:] # 子集合
  123. # 构建数据的子集合,并进行递归
  124. sub_data_set = split_data_set(data_set, best_feature_idx, feature_value) # 待划分的子数据集
  125. my_tree[best_feat_label][feature_value] = create_tree(sub_data_set, sub_labels)
  126. return my_tree
  127.  
  128.  
  129. def classify(input_tree, feat_labels, test_vec):
  130. """
  131. 决策树分类
  132. :param input_tree: 决策树
  133. :param feat_labels: 特征标签
  134. :param test_vec: 测试的数据
  135. :return:
  136. """
  137. first_str = list(input_tree.keys())[0] # 获取树的第一特征属性
  138. second_dict = input_tree[first_str] # 树的分子,子集合Dict
  139. feat_index = feat_labels.index(first_str) # 获取决策树第一层在feat_labels中的位置
  140. for key in second_dict.keys():
  141. if test_vec[feat_index] == key:
  142. if type(second_dict[key]).__name__ == 'dict':
  143. class_label = classify(second_dict[key], feat_labels, test_vec)
  144. else:
  145. class_label = second_dict[key]
  146. return class_label
  147.  
  148.  
  149. data_set, labels = create_data_set()
  150. decision_tree = create_tree(data_set, labels)
  151. print "决策树:", decision_tree
  152. data_set, labels = create_data_set()
  153. print "(1)不浮出水面可以生存,无脚蹼:", classify(decision_tree, labels, [1, 0])
  154. print "(2)不浮出水面可以生存,有脚蹼:", classify(decision_tree, labels, [1, 1])
  155. print "(3)不浮出水面可以不能生存,无脚蹼:", classify(decision_tree, labels, [0, 0])
  156. tree_plotter.create_plot(decision_tree)

画图程序,tree_plotter.py:

  1. import matplotlib.pyplot as plt
  2.  
  3. decision_node = dict(boxstyle="sawtooth", fc="0.8")
  4. leaf_node = dict(boxstyle="round4", fc="0.8")
  5. arrow_args = dict(arrowstyle="<-")
  6.  
  7.  
  8. def plot_node(node_txt, center_pt, parent_pt, node_type):
  9. create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', \
  10. xytext=center_pt, textcoords='axes fraction', \
  11. va="center", ha="center", bbox=node_type, arrowprops=arrow_args)
  12.  
  13.  
  14. def get_num_leafs(my_tree):
  15. num_leafs = 0
  16. first_str = list(my_tree.keys())[0]
  17. second_dict = my_tree[first_str]
  18. for key in second_dict.keys():
  19. if type(second_dict[key]).__name__ == 'dict':
  20. num_leafs += get_num_leafs(second_dict[key])
  21. else:
  22. num_leafs += 1
  23. return num_leafs
  24.  
  25.  
  26. def get_tree_depth(my_tree):
  27. max_depth = 0
  28. first_str = list(my_tree.keys())[0]
  29. second_dict = my_tree[first_str]
  30. for key in second_dict.keys():
  31. if type(second_dict[key]).__name__ == 'dict':
  32. thisDepth = get_tree_depth(second_dict[key]) + 1
  33. else:
  34. thisDepth = 1
  35. if thisDepth > max_depth:
  36. max_depth = thisDepth
  37. return max_depth
  38.  
  39.  
  40. def plot_mid_text(cntr_pt, parent_pt, txt_string):
  41. x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
  42. y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
  43. create_plot.ax1.text(x_mid, y_mid, txt_string)
  44.  
  45.  
  46. def plot_tree(my_tree, parent_pt, node_txt):
  47. num_leafs = get_num_leafs(my_tree)
  48. depth = get_tree_depth(my_tree)
  49. first_str = list(my_tree.keys())[0]
  50. cntr_pt = (plot_tree.x_off + (1.0 + float(num_leafs)) / 2.0 / plot_tree.total_w, plot_tree.y_off)
  51. plot_mid_text(cntr_pt, parent_pt, node_txt)
  52. plot_node(first_str, cntr_pt, parent_pt, decision_node)
  53. second_dict = my_tree[first_str]
  54. plot_tree.y_off = plot_tree.y_off - 1.0 / plot_tree.total_d
  55. for key in second_dict.keys():
  56. if type(second_dict[key]).__name__ == 'dict':
  57. plot_tree(second_dict[key], cntr_pt, str(key))
  58. else:
  59. plot_tree.x_off = plot_tree.x_off + 1.0 / plot_tree.total_w
  60. plot_node(second_dict[key], (plot_tree.x_off, plot_tree.y_off), cntr_pt, leaf_node)
  61. plot_mid_text((plot_tree.x_off, plot_tree.y_off), cntr_pt, str(key))
  62. plot_tree.y_off = plot_tree.y_off + 1.0 / plot_tree.total_d
  63.  
  64.  
  65. def create_plot(in_tree):
  66. fig = plt.figure(1, facecolor='white')
  67. fig.clf()
  68. axprops = dict(xticks=[], yticks=[])
  69. create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
  70. plot_tree.total_w = float(get_num_leafs(in_tree))
  71. plot_tree.total_d = float(get_tree_depth(in_tree))
  72. plot_tree.x_off = -0.5 / plot_tree.total_w
  73. plot_tree.y_off = 1.0
  74. plot_tree(in_tree, (0.5, 1.0), '')
  75. plt.show()

输出结果如下:

决策树: {‘no surfacing’: {0: ‘no’, 1: {‘flippers’: {0: ‘no’, 1: ‘yes’}}}}
(1)不浮出水面可以生存,无脚蹼: no
(2)不浮出水面可以生存,有脚蹼: yes
(3)不浮出水面可以不能生存,无脚蹼: no

最终我们得到决策树如下:

机器学习之决策树算法(ID3、C4.5、CART)_第4张图片

2.5、C4.5算法原理与python代码实现


C4.5算法与ID3算法很相似,C4.5算法是对ID3算法做了改进,在生成决策树过程中采用信息增益比来选择特征。

2.5.1 信息增益比

我们知道信息增益会偏向取值较多的特征,使用信息增益比可以对这一问题进行校正。

定义:特征A对训练数据集D的信息增益比GainRatio(D,A)定义为其信息增益Gain(D,A)与训练数据集D的经验熵H(D)之比:

机器学习之决策树算法(ID3、C4.5、CART)_第5张图片

2.5.2 C4.5算法

C4.5算法过程跟ID3算法一样,只是选择特征的方法由信息增益改成信息增益比。

2.5.3 python代码实现

我们还是采用2.1.3中的实例,C4.5算法跟ID3算法,不同的地方只是特征选择方法,即:

choose_best_feature_to_split方法。
  1. def choose_best_feature_to_split(data_set):
  2. """
  3. 按照最大信息增益比划分数据
  4. :param data_set: 样本数据,如: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
  5. :return:
  6. """
  7. num_feature = len(data_set[0]) - 1 # 特征个数,如:不浮出水面是否可以生存 和是否有脚蹼
  8. base_entropy = calc_shannon_ent(data_set) # 经验熵H(D)
  9. best_info_gain_ratio = 0.0
  10. best_feature_idx = -1
  11. for feature_idx in range(num_feature):
  12. feature_val_list = [number[feature_idx] for number in data_set] # 得到某个特征下所有值(某列)
  13. unique_feature_val_list = set(feature_val_list) # 获取无重复的属性特征值
  14. new_entropy = 0
  15. split_info = 0.0
  16. for value in unique_feature_val_list:
  17. sub_data_set = split_data_set(data_set, feature_idx, value)
  18. prob = len(sub_data_set) / float(len(data_set)) # 即p(t)
  19. new_entropy += prob * calc_shannon_ent(sub_data_set) # 对各子集香农熵求和
  20. split_info += -prob * log(prob, 2)
  21. info_gain = base_entropy - new_entropy # 计算信息增益,g(D,A)=H(D)-H(D|A)
  22. if split_info == 0: # fix the overflow bug
  23. continue
  24. info_gain_ratio = info_gain / split_info
  25. # 最大信息增益比
  26. if info_gain_ratio > best_info_gain_ratio:
  27. best_info_gain_ratio = info_gain_ratio
  28. best_feature_idx = feature_idx
  29.  
  30. return best_feature_idx

效果跟ID3算法一样,这里就不重复。

2.6、CART算法原理与python代码实现


2.6.1 Gini指数

分类问题中,假设有K个类,样本点属于第k类的概率为,则概率分布的基尼指数定义为:

备注:表示选中的样本属于k类别的概率,则这个样本被分错的概率为

对于给定的样本集合D,其基尼指数为:

机器学习之决策树算法(ID3、C4.5、CART)_第6张图片

备注:这里是D中属于第k类的样本自己,K是类的个数。

如果样本集合D根据特征A是否取某一可能值a被分割成D1和D2两部分,即:

则在特征A的条件下,集合D的基尼指数定义为:

机器学习之决策树算法(ID3、C4.5、CART)_第7张图片

基尼指数Gini(D)表示集合D的不确定性,基尼指数Gini(D,A)表示经A=a分割后集合D的不确定性。基尼指数值越大,样本集合的不确定性也就越大,这一点跟熵相似。

下面举一个例子来说明上面的公式:

如下,是一个包含30个学生的样本,其包含三种特征,分别是:性别(男/女)、班级(IX/X)和高度(5到6ft)。其中30个学生里面有15个学生喜欢在闲暇时间玩板球。那么要如何选择第一个要划分的特征呢,我们通过上面的公式来进行计算。

机器学习之决策树算法(ID3、C4.5、CART)_第8张图片

如下,可以Gini(D,Gender)最小,所以选择性别作为最优特征。

机器学习之决策树算法(ID3、C4.5、CART)_第9张图片

2.6.2 CART算法

输入:训练数据集D,停止计算的条件

输出:CART决策树

根据训练数据集,从根结点开始,递归地对每个结点进行以下操作,构建二叉树:

Step1:设结点的训练数据集为D,计算现有特征对该数据集的基尼指数。此时,对每一个特征A,对其可能取的每个值a,根据样本点A=a的测试为“是”或“否”将D分割为D1和D2两部分,利用上式Gini(D,A)来计算A=a时的基尼指数。

Step2:在所有可能的特征A以及他们所有可能的切分点a中,选择基尼指数最小的特征及其对应可能的切分点作为最有特征与最优切分点。依最优特征与最有切分点,从现结点生成两个子节点,将训练数据集依特征分配到两个子节点中去。

Step3:对两个子结点递归地调用Step1、Step2,直至满足条件。

Step4:生成CART决策树

算法停止计算的条件是节点中的样本个数小于预定阈值,或样本集的基尼指数小于预定阈值,或者没有更多特征。

2.6.3 python代码实现

cart.py:

  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3.  
  4.  
  5. class Tree(object):
  6. def __init__(self, value=None, true_branch=None, false_branch=None, results=None, col=-1, summary=None, data=None):
  7. self.value = value
  8. self.true_branch = true_branch
  9. self.false_branch = false_branch
  10. self.results = results
  11. self.col = col
  12. self.summary = summary
  13. self.data = data
  14.  
  15. def __str__(self):
  16. print(self.col, self.value)
  17. print(self.results)
  18. print(self.summary)
  19. return ""
  20.  
  21.  
  22. def split_datas(rows, value, column):
  23. """
  24. 根据条件分离数据集
  25. :param rows:
  26. :param value:
  27. :param column:
  28. :return: (list1, list2)
  29. """
  30. list1 = []
  31. list2 = []
  32. if isinstance(value, int) or isinstance(value, float):
  33. for row in rows:
  34. if row[column] >= value:
  35. list1.append(row)
  36. else:
  37. list2.append(row)
  38. else:
  39. for row in rows:
  40. if row[column] == value:
  41. list1.append(row)
  42. else:
  43. list2.append(row)
  44.  
  45. return list1, list2
  46.  
  47.  
  48. def calculate_diff_count(data_set):
  49. """
  50. 分类统计data_set中每个类别的数量
  51. :param datas:如:[[5.1, 3.5, 1.4, 0.2, 'setosa'], [4.9, 3, 1.4, 0.2, 'setosa'],....]
  52. :return: 如:{'setosa': 50, 'versicolor': 50, 'virginica': 50}
  53. """
  54. results = {}
  55. for data in data_set:
  56. # 数据的最后一列data[-1]是类别
  57. if data[-1] not in results:
  58. results.setdefault(data[-1], 1)
  59. else:
  60. results[data[-1]] += 1
  61. return results
  62.  
  63.  
  64. def gini(data_set):
  65. """
  66. 计算gini的值,即Gini(p)
  67. :param data_set: 如:[[5.1, 3.5, 1.4, 0.2, 'setosa'], [4.9, 3, 1.4, 0.2, 'setosa'],....]
  68. :return:
  69. """
  70. length = len(data_set)
  71. category_2_cnt = calculate_diff_count(data_set)
  72. sum = 0.0
  73. for category in category_2_cnt:
  74. sum += pow(float(category_2_cnt[category]) / length, 2)
  75. return 1 - sum
  76.  
  77.  
  78. def build_decision_tree(data_set, evaluation_function=gini):
  79. """
  80. 递归建立决策树,当gain=0时,停止回归
  81. :param data_set: 如:[[5.1, 3.5, 1.4, 0.2, 'setosa'], [4.9, 3, 1.4, 0.2, 'setosa'],....]
  82. :param evaluation_function:
  83. :return:
  84. """
  85. current_gain = evaluation_function(data_set)
  86. column_length = len(data_set[0])
  87. rows_length = len(data_set)
  88.  
  89. best_gain = 0.0
  90. best_value = None
  91. best_set = None
  92.  
  93. # choose the best gain
  94. for feature_idx in range(column_length - 1):
  95. feature_value_set = set(row[feature_idx] for row in data_set)
  96. for feature_value in feature_value_set:
  97. sub_data_set1, sub_data_set2 = split_datas(data_set, feature_value, feature_idx)
  98. p = float(len(sub_data_set1)) / rows_length
  99. # Gini(D,A)表示在特征A的条件下集合D的基尼指数,gini_d_a越小,样本集合不确定性越小
  100. # 我们的目的是找到另gini_d_a最小的特征,及gain最大的特征
  101. gini_d_a = p * evaluation_function(sub_data_set1) + (1 - p) * evaluation_function(sub_data_set2)
  102. gain = current_gain - gini_d_a
  103. if gain > best_gain:
  104. best_gain = gain
  105. best_value = (feature_idx, feature_value)
  106. best_set = (sub_data_set1, sub_data_set2)
  107. dc_y = {'impurity': '%.3f' % current_gain, 'sample': '%d' % rows_length}
  108.  
  109. # stop or not stop
  110. if best_gain > 0:
  111. true_branch = build_decision_tree(best_set[0], evaluation_function)
  112. false_branch = build_decision_tree(best_set[1], evaluation_function)
  113. return Tree(col=best_value[0], value=best_value[1], true_branch=true_branch, false_branch=false_branch, summary=dc_y)
  114. else:
  115. return Tree(results=calculate_diff_count(data_set), summary=dc_y, data=data_set)
  116.  
  117.  
  118. def prune(tree, mini_gain, evaluation_function=gini):
  119. """
  120. 裁剪
  121. :param tree:
  122. :param mini_gain:
  123. :param evaluation_function:
  124. :return:
  125. """
  126. if tree.true_branch.results == None:
  127. prune(tree.true_branch, mini_gain, evaluation_function)
  128. if tree.false_branch.results == None:
  129. prune(tree.false_branch, mini_gain, evaluation_function)
  130.  
  131. if tree.true_branch.results != None and tree.false_branch.results != None:
  132. len1 = len(tree.true_branch.data)
  133. len2 = len(tree.false_branch.data)
  134. len3 = len(tree.true_branch.data + tree.false_branch.data)
  135.  
  136. p = float(len1) / (len1 + len2)
  137.  
  138. gain = evaluation_function(tree.true_branch.data + tree.false_branch.data) \
  139. - p * evaluation_function(tree.true_branch.data)\
  140. - (1 - p) * evaluation_function(tree.false_branch.data)
  141.  
  142. if gain < mini_gain:
  143. # 当节点的gain小于给定的 mini Gain时则合并这两个节点
  144. tree.data = tree.true_branch.data + tree.false_branch.data
  145. tree.results = calculate_diff_count(tree.data)
  146. tree.true_branch = None
  147. tree.false_branch = None
  148.  
  149.  
  150. def classify(data, tree):
  151. """
  152. 分类
  153. :param data:
  154. :param tree:
  155. :return:
  156. """
  157. if tree.results != None:
  158. return tree.results
  159. else:
  160. branch = None
  161. v = data[tree.col]
  162. if isinstance(v, int) or isinstance(v, float):
  163. if v >= tree.value:
  164. branch = tree.true_branch
  165. else:
  166. branch = tree.false_branch
  167. else:
  168. if v == tree.value:
  169. branch = tree.true_branch
  170. else:
  171. branch = tree.false_branch
  172. return classify(data, branch)
  173.  
  174.  
  175. def load_csv():
  176. def convert_types(s):
  177. s = s.strip()
  178. try:
  179. return float(s) if '.' in s else int(s)
  180. except ValueError:
  181. return s
  182. data = np.loadtxt("datas.csv", dtype="str", delimiter=",")
  183. data = data[1:, :]
  184. data_set = ([[convert_types(item) for item in row] for row in data])
  185. return data_set
  186.  
  187.  
  188. if __name__ == '__main__':
  189. data_set = load_csv()
  190. print data_set
  191. decistion_tree = build_decision_tree(data_set, evaluation_function=gini)
  192. print decistion_tree.results
  193. # prune(decistion_tree, 0.4)
  194. print classify([5.1,3.5,1.4,0.2], decistion_tree) # setosa
  195. print classify([6.8,2.8,4.8,1.4], decistion_tree) # versicolor
  196. print classify([6.8,3.2,5.9,2.3], decistion_tree) # virginica

输出结果:

{‘setosa’: 50}
{‘versicolor’: 47}
{‘virginica’: 43}

如果想进一步学习机器学习的其他算法,可以参考我写的其他文章:

  •  《最大熵模型学习》
  • 《支持向量机SVM》
  • 《随机森林和adaboost》
  • 《常见决策树算法ID3、cart、C4.5》
  • 《聚类算法》
  • 《逻辑回归算法》
  • 《线性回归算法》

参考:

[1] 机器学习之决策树(ID3)算法与Python实现 https://blog.csdn.net/moxigandashu/article/details/71305273

[2] 决策树–信息增益,信息增益比,Geni指数的理解 https://www.cnblogs.com/muzixi/p/6566803.html

[3] https://github.com/RRdmlearning/Decision-Tree

[4] https://www.analyticsvidhya.com/blog/2016/04/complete-tutorial-tree-based-modeling-scratch-in-python/

你可能感兴趣的:(机器学习)