具体数据如下(如果不能运行,尝试在末尾加回车)
5.1 3.5 1.4 0.2 1
4.9 3 1.4 0.2 1
4.7 3.2 1.3 0.2 1
4.6 3.1 1.5 0.2 1
5 3.6 1.4 0.2 1
5.4 3.9 1.7 0.4 1
4.6 3.4 1.4 0.3 1
5 3.4 1.5 0.2 1
4.4 2.9 1.4 0.2 1
4.9 3.1 1.5 0.1 1
5.4 3.7 1.5 0.2 1
4.8 3.4 1.6 0.2 1
4.8 3 1.4 0.1 1
4.3 3 1.1 0.1 1
5.8 4 1.2 0.2 1
5.7 4.4 1.5 0.4 1
5.4 3.9 1.3 0.4 1
5.1 3.5 1.4 0.3 1
5.7 3.8 1.7 0.3 1
5.1 3.8 1.5 0.3 1
5.4 3.4 1.7 0.2 1
5.1 3.7 1.5 0.4 1
4.6 3.6 1 0.2 1
5.1 3.3 1.7 0.5 1
4.8 3.4 1.9 0.2 1
5 3 1.6 0.2 1
5 3.4 1.6 0.4 1
5.2 3.5 1.5 0.2 1
5.2 3.4 1.4 0.2 1
4.7 3.2 1.6 0.2 1
4.8 3.1 1.6 0.2 1
5.4 3.4 1.5 0.4 1
5.2 4.1 1.5 0.1 1
5.5 4.2 1.4 0.2 1
4.9 3.1 1.5 0.1 1
5 3.2 1.2 0.2 1
5.5 3.5 1.3 0.2 1
4.9 3.1 1.5 0.1 1
4.4 3 1.3 0.2 1
5.1 3.4 1.5 0.2 1
5 3.5 1.3 0.3 1
4.5 2.3 1.3 0.3 1
4.4 3.2 1.3 0.2 1
5 3.5 1.6 0.6 1
5.1 3.8 1.9 0.4 1
4.8 3 1.4 0.3 1
5.1 3.8 1.6 0.2 1
4.6 3.2 1.4 0.2 1
5.3 3.7 1.5 0.2 1
5 3.3 1.4 0.2 1
7 3.2 4.7 1.4 2
6.4 3.2 4.5 1.5 2
6.9 3.1 4.9 1.5 2
5.5 2.3 4 1.3 2
6.5 2.8 4.6 1.5 2
5.7 2.8 4.5 1.3 2
6.3 3.3 4.7 1.6 2
4.9 2.4 3.3 1 2
6.6 2.9 4.6 1.3 2
5.2 2.7 3.9 1.4 2
5 2 3.5 1 2
5.9 3 4.2 1.5 2
6 2.2 4 1 2
6.1 2.9 4.7 1.4 2
5.6 2.9 3.6 1.3 2
6.7 3.1 4.4 1.4 2
5.6 3 4.5 1.5 2
5.8 2.7 4.1 1 2
6.2 2.2 4.5 1.5 2
5.6 2.5 3.9 1.1 2
5.9 3.2 4.8 1.8 2
6.1 2.8 4 1.3 2
6.3 2.5 4.9 1.5 2
6.1 2.8 4.7 1.2 2
6.4 2.9 4.3 1.3 2
6.6 3 4.4 1.4 2
6.8 2.8 4.8 1.4 2
6.7 3 5 1.7 2
6 2.9 4.5 1.5 2
5.7 2.6 3.5 1 2
5.5 2.4 3.8 1.1 2
5.5 2.4 3.7 1 2
5.8 2.7 3.9 1.2 2
6 2.7 5.1 1.6 2
5.4 3 4.5 1.5 2
6 3.4 4.5 1.6 2
6.7 3.1 4.7 1.5 2
6.3 2.3 4.4 1.3 2
5.6 3 4.1 1.3 2
5.5 2.5 4 1.3 2
5.5 2.6 4.4 1.2 2
6.1 3 4.6 1.4 2
5.8 2.6 4 1.2 2
5 2.3 3.3 1 2
5.6 2.7 4.2 1.3 2
5.7 3 4.2 1.2 2
5.7 2.9 4.2 1.3 2
6.2 2.9 4.3 1.3 2
5.1 2.5 3 1.1 2
5.7 2.8 4.1 1.3 2
6.3 3.3 6 2.5 3
5.8 2.7 5.1 1.9 3
7.1 3 5.9 2.1 3
6.3 2.9 5.6 1.8 3
6.5 3 5.8 2.2 3
7.6 3 6.6 2.1 3
4.9 2.5 4.5 1.7 3
7.3 2.9 6.3 1.8 3
6.7 2.5 5.8 1.8 3
7.2 3.6 6.1 2.5 3
6.5 3.2 5.1 2 3
6.4 2.7 5.3 1.9 3
6.8 3 5.5 2.1 3
5.7 2.5 5 2 3
5.8 2.8 5.1 2.4 3
6.4 3.2 5.3 2.3 3
6.5 3 5.5 1.8 3
7.7 3.8 6.7 2.2 3
7.7 2.6 6.9 2.3 3
6 2.2 5 1.5 3
6.9 3.2 5.7 2.3 3
5.6 2.8 4.9 2 3
7.7 2.8 6.7 2 3
6.3 2.7 4.9 1.8 3
6.7 3.3 5.7 2.1 3
7.2 3.2 6 1.8 3
6.2 2.8 4.8 1.8 3
6.1 3 4.9 1.8 3
6.4 2.8 5.6 2.1 3
7.2 3 5.8 1.6 3
7.4 2.8 6.1 1.9 3
7.9 3.8 6.4 2 3
6.4 2.8 5.6 2.2 3
6.3 2.8 5.1 1.5 3
6.1 2.6 5.6 1.4 3
7.7 3 6.1 2.3 3
6.3 3.4 5.6 2.4 3
6.4 3.1 5.5 1.8 3
6 3 4.8 1.8 3
6.9 3.1 5.4 2.1 3
6.7 3.1 5.6 2.4 3
6.9 3.1 5.1 2.3 3
5.8 2.7 5.1 1.9 3
6.8 3.2 5.9 2.3 3
6.7 3.3 5.7 2.5 3
6.7 3 5.2 2.3 3
6.3 2.5 5 1.9 3
6.5 3 5.2 2 3
6.2 3.4 5.4 2.3 3
5.9 3 5.1 1.8 3
import math
import random
def deDv(D, a, v, v1, v2): # 得到Dv
for i in range(len(D[a])): # 属性a列遍历
if D[a][i] >= v: # 找到属性v
for j in range(len(D)): # D属性遍历
v1[j].append(D[j][i]) # D的属性加到v1上
else:
for j in range(len(D)): # D属性遍历
v2[j].append(D[j][i]) # D的属性加到v2上
def Ent(D): # Ent(D) 信息熵
x = list(set(D[4])) # 取得无重复标签的列表
if len(x) == 0: # 数据为空
return 0
x1 = x[0] # D中样本标签第一类
c = 0 # 计数
for j in D[4]: # 遍历标签
if j == x1: # 如果标签一样
c = c + 1 # 计数加一
pk1 = c / len(D[4]) # 第一类标签的频率
if pk1 == 0 or pk1 == 1: # 如果有0(防止log 0)
return 0 # 信息熵为 0
x1 = x[1] # D中样本标签第二类
c = 0 # 计数
for j in D[4]: # 遍历标签
if j == x1: # 如果标签一样
c = c + 1 # 计数加一
pk2 = c / len(D[4]) # 第二类标签的频率
if pk1 + pk2 == 1: # 没有其他标签
return -(pk1 * math.log(pk1, 2) + pk2 * math.log(pk2, 2)) # 只有两种标签
pk3 = 1 - pk1 - pk2
return -(pk1 * math.log(pk1, 2) + pk2 * math.log(pk2, 2) + pk3 * math.log(pk3, 2)) # 三种标签
def Gain(D, i, m, t): # 求信息增益
MaxG = 0 # 储存最大的信息增益
Maxt = 0 # 储存最优的划分值
EntD = Ent(D) # 避免重复计算
sumG = 0 # 划分后信息熵
for j in D[i]: # 遍历所有属性
sumG = 0 # 划分后信息熵
Dv1 = [[], [], [], [], []] # 以j划分,左边的子集
Dv2 = [[], [], [], [], []] # 以j划分,右边的子集
deDv(D, i, j, Dv1, Dv2) # 得到Dv
if len(Dv1[0]) != 0:
sumG = sumG + len(Dv1[0]) / len(D[0]) * Ent(Dv1) # 求划分后Dv1信息熵
if len(Dv2[0]) != 0:
sumG = sumG + len(Dv2[0]) / len(D[0]) * Ent(Dv2) # 求划分后Dv2信息熵
sumG = EntD - sumG # 求划分后信息增益
if MaxG < sumG: # 找最大的信息增益
MaxG = sumG # 最大信息增益
Maxt = j # 最大划分值
m.append(MaxG) # m加上最大信息增益
t.append(Maxt) # t加上最优划分值
def TreeGenerate(D): # 递归生成决策树
y = [] # 生成节点
if len(set(D[4])) == 1: # 如果里面节点类型一样
return D[4][0] # 返回这个值
Gm = [] # 分类的信息增益
Gt = [] # 分类的数值
for i in range(4):
Gain(D, i, Gm, Gt) # 计算所有的信息增益
Gm = Gm.index(max(Gm)) # 得到a*位置的下标(最优划分属性)
Gt = Gt[Gm] # 得到a*最优划分属性的划分值(最优划分值)
y.append(Gm)
y.append(Gt)
Dv1 = [[], [], [], [], []] # 以属性Gm中的Gt划分,左边的子集
Dv2 = [[], [], [], [], []] # 以属性Gm中的Gt划分,右边的子集
deDv(D, Gm, Gt, Dv1, Dv2) # 得到Dv
y1 = [] # 生成分支
if len(Dv1) == 0: # Dv1为空
y1.append(max(D[4], key=D[4].count)) # D中最多的取值
else:
y1.append(TreeGenerate(Dv1[:])) # 递归调用
y.append(y1)
y1 = [] # 生成分支
if len(Dv2) == 0: # Dv2为空
y1.append(max(D[4], key=D[4].count)) # D中最多的取值
else:
y1.append(TreeGenerate(Dv2[:])) # 递归调用
y.append(y1)
return y
def panduan(D, y, i):
if isinstance(y, int): # 如果是数字,说明是结果
return y # 返回结果
x = D[y[0]] # 取分类列
if x[i] >= y[1]:
y1 = y[2][0] # 取分支
return panduan(D, y1, i) # 递归分支
else:
y1 = y[3][0] # 取分支
return panduan(D, y1, i) # 递归分支
def wucha(D, y): # 计算分类误差
c = 0 # 误差数
for i in range(len(D[0])):
if panduan(D, y, i) != D[4][i]: # 决策树判断不等于标签
print(D[0][i], D[1][i], D[2][i], D[3][i], D[4][i], panduan(D, y, i))
c = c + 1 # 误差数加一
return c / len(D[0]) # 返回误差比例
f = open('Iris.txt', 'r') # 读文件
x = [[], [], [], [], []] # 花朵属性,(0,1,2,3),花朵种类
x1 = [[], [], [], [], []]
x2 = [[], [], [], [], []]
y = [] # 分支节点[分支类,[分支节点1],[分支节点2]···] 递归决策树
y1 = []
while 1:
yihang = f.readline() # 读一行
if len(yihang) <= 1: # 读到末尾结束
break
fenkai = yihang.split('\t') # 按\t分开
for i in range(5): # 分开的五个值
x[i].append(eval(fenkai[i])) # 化为数字加到x中
print('数据集===============================================')
print(len(x[0]))
for i in range(len(x)):
print(x[i])
x1 = x[:]
print('全训练决策树===============================================')
y = TreeGenerate(x1[:]) # 开始训练
print(y, '\n递归[属性,分类值,第一类,第二类]')
print('误差率:', wucha(x[:], y[:]) * 100, '%')
print('2/3训练决策树==============================================')
l = list(range(150)) # 得到一个顺序序列
random.shuffle(l) # 打乱序列
x1 = [[], [], [], [], []]
x2 = [[], [], [], [], []]
for i in l[0:100]: # 截取部分训练集
for j in range(len(x)): # D属性遍历
x1[j].append(x[j][i]) # D的属性加到v1上
for i in l[100:150]: # 截取部分训练集
for j in range(len(x)): # D属性遍历
x2[j].append(x[j][i]) # D的属性加到v1上
y1 = TreeGenerate(x1) # 开始训练
print(y1, '\n递归[属性,分类值,第一类,第二类]')
print('误差率:', wucha(x2[:], y1[:]) * 100, )
由于随机选取数据,2/3训练决策树结果可能不一样