# Decision tree (ID3 / C4.5) — straight to the code.
# For the detailed theory, see Chapter 5 (decision trees) of
# "Statistical Learning Methods" (Li Hang).
import numpy as np
class DecisionTree(object):
    """Categorical decision-tree classifier using ID3 or C4.5 splitting.

    tree_type 'ID3' selects splits by information gain; 'C4.5' by
    information-gain ratio. Any other value falls back to plain
    information gain (the original code crashed on unknown types).
    """

    def __init__(self, tree_type):
        # Root node dict, filled in by build_ID3.
        self.Tree = None
        self.tree_type = tree_type

    def build_ID3(self, D, A, e):
        """Recursively build a decision tree.

        :param D: training set, 2-D array; the last column is the class label
        :param A: feature-name sequence aligned with D's feature columns
        :param e: threshold — stop splitting when the best gain is below it
        :return: tree node as a dict ('is_leaf', and either 'label' or
                 'Ag_name'/'Ag_index'/'children')
        """
        n = D.shape[0]
        tree = {'is_leaf': False}
        X = D[:, :-1]
        Y = D[:, -1]
        # Termination: all samples share one label, or no features remain.
        if len(np.unique(Y)) == 1 or len(A) == 0:
            tree['is_leaf'] = True
            # Label the leaf with the majority class in D.
            tree['label'] = self.majority_vote(Y)
            return tree
        # Entropy of the label distribution before any split.
        origin_entropy = self.entropy(Y)
        # Gain (ID3) or gain ratio (C4.5) of splitting on each feature.
        gains = []
        for i in range(len(A)):
            uniques, counts = np.unique(X[:, i], return_counts=True)
            # Conditional entropy of the labels after splitting on feature i.
            cond_entropy = 0.0
            for value, count in zip(uniques, counts):
                cond_entropy += (count / n) * self.entropy(Y[X[:, i] == value])
            gain = origin_entropy - cond_entropy
            if self.tree_type == 'C4.5':
                # Gain ratio = gain / split information. A single-valued
                # feature has zero split information AND zero gain, so the
                # ratio is defined as 0 here instead of dividing by zero
                # (the original produced NaN in that case).
                h = self.entropy(X[:, i])
                gains.append(gain / h if h > 0 else 0.0)
            else:
                # 'ID3' and any unrecognized tree_type use plain gain.
                gains.append(gain)
        # Best gain below the threshold: splitting is not worthwhile.
        if max(gains) < e:
            tree['is_leaf'] = True
            # Label the leaf with the majority class in D.
            tree['label'] = self.majority_vote(Y)
            return tree
        col = int(np.argmax(gains))
        Ag = A[col]  # feature with the largest gain — split on it
        tree['Ag_name'] = Ag
        tree['Ag_index'] = col
        tree['children'] = {}
        for value in np.unique(X[:, col]):
            mask = X[:, col] == value  # rows taking this feature value
            # Child training set: matching rows with the chosen feature
            # column removed (horizontal re-join of the remaining columns).
            Di = np.hstack((D[mask, :col], D[mask, col + 1:]))
            # Remaining feature set A - {Ag}.
            A_rest = np.hstack((A[:col], A[col + 1:]))
            tree['children'][value] = self.build_ID3(Di, A_rest, e)
        return tree

    def majority_vote(self, targets):
        """Return the most frequent value in targets (None when empty)."""
        if len(targets) == 0:
            return None
        uniques, counts = np.unique(targets, return_counts=True)
        return uniques[np.argmax(counts)]

    def entropy(self, D):
        """Shannon entropy (base 2) of the value distribution in 1-D array D."""
        # np.unique returns the distinct values and each value's count.
        _, C = np.unique(D, return_counts=True)
        p = C / D.shape[0]
        return -(p * np.log2(p)).sum()

    def predict(self, tree, data):
        """Classify one sample by walking the tree.

        :param data: feature vector aligned with the feature set the tree
            was built from. A feature value never seen at the corresponding
            node during training raises KeyError.
        :return: predicted class label
        """
        if tree['is_leaf']:
            return tree['label']
        idx = tree['Ag_index']
        # Descend into the branch for this sample's value of the split
        # feature, dropping that feature to mirror the column removal
        # performed during building.
        return self.predict(tree['children'][data[idx]],
                            np.hstack((data[:idx], data[idx + 1:])))
if __name__ == '__main__':
    # Loan-application toy data set from "Statistical Learning Methods",
    # Chapter 5. Columns: age, has-job, owns-house, credit rating, label.
    data = np.array([['青年', '否', '否', '一般', '否'],
                     ['青年', '否', '否', '好', '否'],
                     ['青年', '是', '否', '好', '是'],
                     ['青年', '是', '是', '一般', '是'],
                     ['青年', '否', '否', '一般', '否'],
                     ['中年', '否', '否', '一般', '否'],
                     ['中年', '否', '否', '好', '否'],
                     ['中年', '是', '是', '好', '是'],
                     ['中年', '否', '是', '非常好', '是'],
                     ['中年', '否', '是', '非常好', '是'],
                     ['老年', '否', '是', '非常好', '是'],
                     ['老年', '否', '是', '好', '是'],
                     ['老年', '是', '否', '好', '是'],
                     ['老年', '是', '否', '非常好', '是'],
                     ['老年', '否', '否', '一般', '否']])
    A = ['年龄', '有工作', '有自己房子', '信贷情况']
    tree = DecisionTree('C4.5')
    tree.Tree = tree.build_ID3(data, A, 0)
    print(tree.Tree)
    print(tree.predict(tree.Tree, ['青年', '否', '是', '非常好']))