决策树是一类常见的机器学习方法。
-------------------------------------------------------------------------------------------------
-训练集D,属性集A-
.if训练集中样本全属于同一类别
. 标记为叶结点,类别标记为此类别 return
.if A=Φ or D中样本在A上取值全相同
. 标记为叶结点,类别标记为D中样本数最多的类 return
.从A中选择最优划分属性a
.for 属性值 in 属性a
. 生成一个分支,划分子集Dv
. if Dv=Φ
. 将此分支标记为叶结点,类别标记为D中样本数最多的类 return
. else
. 以Dv为训练集,A\{a}为属性集构建树
---------------------------------------------------------------------------------------------------
决策树的构建是递归的。
决策树构建算法,可以看出决策树的关键是最优划分属性a的选择。随着划分的进行,我们希望样本尽可能全为一类,即纯度越来越高。
ID3算法就是选择最优划分属性的一种方法。
ID3算法以信息增益为准则来选择划分属性。
信息熵是衡量样本集合纯度的一种指标。
设当前样本有 N N N类,其中第 k k k类样本比例为 p k p_k pk,则集合 D D D的信息熵
E n t ( D ) = − ∑ p k l o g 2 p k k = 1 , 2 , . . . , N Ent(D)=-\sum{p_klog_2p_k} k=1,2,...,N Ent(D)=−∑pklog2pk k=1,2,...,N
信息熵 E n t Ent Ent越小, D D D的纯度越高。
设离散属性 a a a的可能取值 { a 1 , a 2 , . . . , a V } \{a^1,a^2,...,a^V\} {a1,a2,...,aV},以属性 a a a划分集合 D D D,得到 V V V个样本子集,记为 D V D^V DV,则信息增益
G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 V ∣ D V ∣ ∣ D ∣ E n t ( D V ) . Gain(D,a)=Ent(D)-\sum^V_{v=1}{\frac{|D^V|}{|D|}Ent(D^V)}. Gain(D,a)=Ent(D)−v=1∑V∣D∣∣DV∣Ent(DV).
一般来说,信息增益越大,使用属性 a a a划分后纯度提升越大。
下面以西瓜书76页数据集为例用Python实现决策树ID3算法。
源码及数据集下载:决策树ID3算法
import numpy as np
import pandas as pd
import copy
def majorType(data): # 集合中样本数最多的类
dataType_count = data['好瓜'].value_counts() # 统计集合中各类样本的数量
return dataType_count[dataType_count.values == max(dataType_count)].index[0] # 返回最大样本类
def Get_Ent(data): # 信息熵
dataType = data['好瓜']
dataType_count = dataType.value_counts()
pk = dataType_count.values / dataType.count()
Ent = 0
for k in pk:
Ent += k * np.log2(k)
return -Ent
def Get_Gain(data, attribute_set): # 所有属性的信息增益
Ent_D = Get_Ent(data)
Gain_all = dict()
for attribute in attribute_set: # 求各个属性的信息增益,以 属性:增益值 键值对形式存放在字典中
Gain = 0
dataType = data[attribute]
attri_value = attribute_values[attribute]
for v in attri_value:
data_v_index = dataType[dataType.values == v].index # 属性attribute值为v的样例index
data_v = data.reindex(index=data_v_index) # 不同属性值的样本子集
Ent_v = Get_Ent(data_v)
Gain += len(data_v_index) * Ent_v
Gain = Ent_D - Gain / len(dataType)
Gain_all[attribute] = Gain
return Gain_all
def TreeGenerate(data, attribute_set, Tree): # 构建决策树
dataType = data['好瓜']
if dataType.value_counts()[0] == dataType.count(): # 如果样本全属同一类别,则标记为叶结点
Tree = dataType.values[0]
return Tree
if len(attribute_set) == 0: # 如果属性集为空,无法继续划分,标记为叶结点,其类别为当前样本中样本数最多的类
Tree = majorType(data)
return Tree
Flag = True # 样本是否取值全相等
for attribute in attribute_set: # 判断是否样本取值全相等,不相等标记为False
if Flag == False:
break
adata = data[attribute]
Flag = Flag & adata.value_counts()[0] == adata.count()
if Flag == True: # 如果样本取值全相等,则标记为叶结点,类别为当前样本中样本数最多的类
Tree = majorType(data)
return Tree
Gain = Get_Gain(data, attribute_set)
best_attribute = max(Gain, key=Gain.get) # 求信息增益,得到当前信息增益最大的属性
Tree[best_attribute] = dict() # 以其为根节点构建子树
dataType = data[best_attribute]
attr_value = attribute_values[best_attribute] # 属性可取值
newattribute_set = copy.deepcopy(attribute_set)
newattribute_set.remove(best_attribute) # 删除以选属性,获得新的属性集
for v in attr_value: # 遍历最优选择的属性可取的属性值
data_v_index = dataType[dataType.values == v].index # 属性attribute值为v的样例index
data_v = data.reindex(index=data_v_index) # 子集Dv
if data_v['好瓜'].count() == 0: # 如果子集为空,标记为叶结点,其类标记为当前样本中样本最多的类
Tree[best_attribute][v] = majorType(data)
else: # 构建子树
Tree[best_attribute][v] = dict()
Tree[best_attribute][v] = TreeGenerate(data_v, newattribute_set, Tree[best_attribute][v])
return Tree
def show(Tree, i): # 输出决策树;i:当前所处树的层数
if type(Tree) is not dict: # 如果是叶子结点则输出值
print('\t' * i + '|' + Tree)
i -= 1
return i
root = list(Tree.keys())[0]
print('\t' * i + '|' + root)
i += 1
Child_Tree = Tree[root]
for node in Child_Tree:
print('|\t' * i + '|' + node)
i += 1
i = show(Child_Tree[node], i)
i -= 2
return i
def test(Tree, test_data): # 测试
if type(Tree) is not dict:
return Tree
root = list(Tree.keys())[0]
Child_Tree = Tree[root][test_data[root]]
return test(Child_Tree, test_data)
def main():
Tree = dict()
Tree = TreeGenerate(data_all, attribute_set_all, Tree)
show(Tree, 0)
test_data = pd.Series(['乌黑', '稍蜷', '沉闷', '清晰', '平坦', '硬滑'], index=['色泽', '根蒂', '敲声', '纹理', '脐部', '触感'])
res = test(Tree, test_data)
print(res)
if __name__ == '__main__':
data_all = pd.read_excel('watermelon20.xlsx', dtype=str) # 读取数据集
attribute_set_all = list(np.array(data_all.columns)[1:-1]) # 构建属性集
attribute_values = dict()
for attribute in attribute_set_all: # 各个属性的可取值
values = list(set(data_all[attribute].value_counts().index))
attribute_values[attribute] = values
main()