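This is a worked example of computing information gain, information gain ratio, and the Gini index for a decision tree split.
Related reading: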
Information Gain
http://www.cs.csi.cuny.edu/~imberman/ai/Entropy%20and%20Information%20Gain.htm
Decision Tree
https://blog.csdn.net/Tomcater321/article/details/80699044
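import math
import pandas as pd
from sklearn.model_selection import train_test_split
df_iris = pd.read_csv('./data mining/iris.csv',header=0,names=['sepal_len','sepal_wid','petal_len','petal_wid','class'])
df_iris_train,df_iris_test = train_test_split(df_iris,test_size=0.2)  # random split: the sampled rows differ between runs
print(df_iris_test[0:10])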
# sepal_len sepal_wid petal_len petal_wid class
# 86 6.3 2.3 4.4 1.3 Iris-versicolor
# 85 6.7 3.1 4.7 1.5 Iris-versicolor
# 57 6.6 2.9 4.6 1.3 Iris-versicolor
# 43 5.1 3.8 1.9 0.4 Iris-setosa
# 93 5.6 2.7 4.2 1.3 Iris-versicolor
# 6 5.0 3.4 1.5 0.2 Iris-setosa
# 41 4.4 3.2 1.3 0.2 Iris-setosa
# 79 5.5 2.4 3.8 1.1 Iris-versicolor
# 100 5.8 2.7 5.1 1.9 Iris-virginica
# 14 5.7 4.4 1.5 0.4 Iris-setosa
####################################### 1. Information Gain (ID3) ######################################
H(class) = h(Iris-versicolor) + h(Iris-setosa) + h(Iris-virginica)
= -(5/10 * math.log(5/10) + 4/10 * math.log(4/10) + 1/10 * math.log(1/10))
= 0.9433483923290391
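The entropy above can be reproduced with a few lines of Python. This is a minimal sketch: the helper name `entropy` and the hand-built `labels` list (5 versicolor, 4 setosa, 1 virginica, counted from the ten printed rows) are assumptions for illustration. `math.log` is the natural logarithm, so the result is in nats.

```python
import math
from collections import Counter

def entropy(labels):
    """Shannon entropy of a label list, in nats (natural log)."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())

# Class labels of the ten printed test rows.
labels = ["Iris-versicolor"] * 5 + ["Iris-setosa"] * 4 + ["Iris-virginica"]
h_class = entropy(labels)
print(h_class)
```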
split = 5
H(class|sepal_len) = p(sepal_len>5)*H(class|sepal_len>5) + p(sepal_len<=5)*H(class|sepal_len<=5)
= 8/10 * 0.9002560512685369 + 2/10 * 0
= 0.7202048410
# 86 6.3 2.3 4.4 1.3 Iris-versicolor
# 85 6.7 3.1 4.7 1.5 Iris-versicolor
# 57 6.6 2.9 4.6 1.3 Iris-versicolor
# 43 5.1 3.8 1.9 0.4 Iris-setosa
# 93 5.6 2.7 4.2 1.3 Iris-versicolor
# 79 5.5 2.4 3.8 1.1 Iris-versicolor
# 100 5.8 2.7 5.1 1.9 Iris-virginica
# 14 5.7 4.4 1.5 0.4 Iris-setosa
H(class|sepal_len>5) = h(Iris-versicolor) + h(Iris-setosa) + h(Iris-virginica)
= -(5/8 * math.log(5/8) + 2/8 * math.log(2/8) + 1/8 * math.log(1/8))
= 0.9002560512685369
# 6 5.0 3.4 1.5 0.2 Iris-setosa
# 41 4.4 3.2 1.3 0.2 Iris-setosa
H(class|sepal_len<=5) = h(Iris-setosa)
= -(2/2 * math.log(2/2))
= 0
D_kl = 0.9433483923290391 - 0.7202048410
= 0.2231435513
split = 6
H(class|sepal_len) = p(sepal_len>6)*H(class|sepal_len>6) + p(sepal_len<=6)*H(class|sepal_len<=6)
= 0 + 7/10*0.9556998911125343
= 0.668989923778774
# 86 6.3 2.3 4.4 1.3 Iris-versicolor
# 85 6.7 3.1 4.7 1.5 Iris-versicolor
# 57 6.6 2.9 4.6 1.3 Iris-versicolor
H(class|sepal_len>6) = 0
# 43 5.1 3.8 1.9 0.4 Iris-setosa
# 93 5.6 2.7 4.2 1.3 Iris-versicolor
# 6 5.0 3.4 1.5 0.2 Iris-setosa
# 41 4.4 3.2 1.3 0.2 Iris-setosa
# 79 5.5 2.4 3.8 1.1 Iris-versicolor
# 100 5.8 2.7 5.1 1.9 Iris-virginica
# 14 5.7 4.4 1.5 0.4 Iris-setosa
H(class|sepal_len<=6) = h(Iris-setosa) + h(Iris-versicolor) + h(Iris-virginica)
= -( 4/7*math.log(4/7) + 2/7*math.log(2/7) + 1/7*math.log(1/7) )
= 0.9556998911125343
D_kl = 0.9433483923290391 - 0.668989923778774
= 0.27435846855026513
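As a cross-check on the hand calculation, the information gain of a threshold split can be recomputed directly from the ten printed rows. This is a sketch under stated assumptions: `ROWS` is typed in by hand from the printed sample (only `sepal_len` and `class`), and the helper names `entropy` and `info_gain` are illustrative, not part of the original script.

```python
import math
from collections import Counter

# (sepal_len, class) for the ten printed test rows.
ROWS = [
    (6.3, "Iris-versicolor"), (6.7, "Iris-versicolor"), (6.6, "Iris-versicolor"),
    (5.1, "Iris-setosa"), (5.6, "Iris-versicolor"), (5.0, "Iris-setosa"),
    (4.4, "Iris-setosa"), (5.5, "Iris-versicolor"), (5.8, "Iris-virginica"),
    (5.7, "Iris-setosa"),
]

def entropy(labels):
    """Shannon entropy of a label list, in nats."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())

def info_gain(rows, threshold):
    """Entropy of the class minus the weighted entropy after splitting at threshold."""
    labels = [cls for _, cls in rows]
    left = [cls for x, cls in rows if x <= threshold]
    right = [cls for x, cls in rows if x > threshold]
    conditional = sum(len(part) / len(rows) * entropy(part)
                      for part in (left, right) if part)
    return entropy(labels) - conditional

for t in (5, 6):
    print(f"split = {t}: information gain = {info_gain(ROWS, t)}")
```

Filtering on `x <= threshold` and `x > threshold` makes the membership of each branch explicit, which is the step most easily miscounted by hand.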
####################################### 2. Information Gain Ratio ######################################
H(split = 5) = H(2/10,8/10)
= -(2/10*math.log(2/10) + 8/10*math.log(8/10))
= 0.5004024235381879
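gr(split = 5) = D_kl(split = 5)/H(split = 5)
(D_kl(split = 5) = 0.9433483923290391 - 8/10*0.9002560512685369 = 0.2231435513, since both rows with sepal_len<=5 are Iris-setosa)
= 0.2231435513/0.5004024235381879
= 0.4459281987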
H(split = 6) = H(3/10,7/10)
= -(3/10*math.log(3/10) + 7/10*math.log(7/10))
= 0.6108643020548935
gr(split = 6) = D_kl(split = 6)/H(split = 6)
= 0.27435846855026513/0.6108643020548935
= 0.4491316117627884
gr(split = 6) > gr(split = 5), so split = 6 is the better split by gain ratio (larger is better)
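The gain ratio divides the information gain by the entropy of the split itself (the split information). The sketch below recomputes it from the ten printed rows; `ROWS` is hand-typed from the sample, and `entropy` and `gain_ratio` are hypothetical helper names. Returning 0.0 when every row lands on one side is an assumption made here to avoid dividing by zero.

```python
import math
from collections import Counter

# (sepal_len, class) for the ten printed test rows.
ROWS = [
    (6.3, "Iris-versicolor"), (6.7, "Iris-versicolor"), (6.6, "Iris-versicolor"),
    (5.1, "Iris-setosa"), (5.6, "Iris-versicolor"), (5.0, "Iris-setosa"),
    (4.4, "Iris-setosa"), (5.5, "Iris-versicolor"), (5.8, "Iris-virginica"),
    (5.7, "Iris-setosa"),
]

def entropy(values):
    """Shannon entropy of a list of hashable values, in nats."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())

def gain_ratio(rows, threshold):
    """Information gain of the threshold split divided by its split information."""
    labels = [cls for _, cls in rows]
    # Branch of each row: True if sepal_len > threshold, False otherwise.
    sides = [x > threshold for x, _ in rows]
    conditional = 0.0
    for side in set(sides):
        part = [cls for cls, s in zip(labels, sides) if s == side]
        conditional += len(part) / len(rows) * entropy(part)
    gain = entropy(labels) - conditional
    split_info = entropy(sides)  # entropy of the branch sizes, e.g. H(2/10, 8/10)
    return gain / split_info if split_info > 0 else 0.0

for t in (5, 6):
    print(f"split = {t}: gain ratio = {gain_ratio(ROWS, t)}")
```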
####################################### 3. Gini Index ######################################
Gini(split = 5) = p(sepal_len>5)*Gini(sepal_len>5) + p(sepal_len<=5)*Gini(sepal_len<=5)
= 8/10*0.53125 + 2/10*0
= 0.425
# 86 6.3 2.3 4.4 1.3 Iris-versicolor
# 85 6.7 3.1 4.7 1.5 Iris-versicolor
# 57 6.6 2.9 4.6 1.3 Iris-versicolor
# 43 5.1 3.8 1.9 0.4 Iris-setosa
# 93 5.6 2.7 4.2 1.3 Iris-versicolor
# 79 5.5 2.4 3.8 1.1 Iris-versicolor
# 100 5.8 2.7 5.1 1.9 Iris-virginica
# 14 5.7 4.4 1.5 0.4 Iris-setosa
gini(sepal_len>5) = 1-p(Iris-versicolor)^2-p(Iris-setosa)^2-p(Iris-virginica)^2
= 1 - (5/8)^2 - (2/8)^2 - (1/8)^2
= 0.53125
# 6 5.0 3.4 1.5 0.2 Iris-setosa
# 41 4.4 3.2 1.3 0.2 Iris-setosa
gini(sepal_len<=5) = 1 - p(Iris-setosa)^2
= 1 - (2/2)^2
= 0
Gini(split = 6) = p(sepal_len>6)*Gini(sepal_len>6) + p(sepal_len<=6)*Gini(sepal_len<=6)
= 0 + 7/10*0.5714285714285714
= 0.4
# 86 6.3 2.3 4.4 1.3 Iris-versicolor
# 85 6.7 3.1 4.7 1.5 Iris-versicolor
# 57 6.6 2.9 4.6 1.3 Iris-versicolor
gini(sepal_len>6) = 0
# 43 5.1 3.8 1.9 0.4 Iris-setosa
# 93 5.6 2.7 4.2 1.3 Iris-versicolor
# 6 5.0 3.4 1.5 0.2 Iris-setosa
# 41 4.4 3.2 1.3 0.2 Iris-setosa
# 79 5.5 2.4 3.8 1.1 Iris-versicolor
# 100 5.8 2.7 5.1 1.9 Iris-virginica
# 14 5.7 4.4 1.5 0.4 Iris-setosa
gini(sepal_len<=6) = 1 - p(Iris-setosa)^2 - p(Iris-versicolor)^2 - p(Iris-virginica)^2
= 1 - (4/7)^2 - (2/7)^2 - (1/7)^2
= 0.5714285714285714
Gini(split = 6) < Gini(split = 5), so split = 6 is the better split by Gini index (smaller is better)
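The weighted Gini index of each candidate split can be recomputed from the ten printed rows in the same way. This is a minimal sketch: `ROWS` is hand-typed from the sample, and `gini` and `gini_split` are illustrative helper names rather than anything defined in the original script.

```python
from collections import Counter

# (sepal_len, class) for the ten printed test rows.
ROWS = [
    (6.3, "Iris-versicolor"), (6.7, "Iris-versicolor"), (6.6, "Iris-versicolor"),
    (5.1, "Iris-setosa"), (5.6, "Iris-versicolor"), (5.0, "Iris-setosa"),
    (4.4, "Iris-setosa"), (5.5, "Iris-versicolor"), (5.8, "Iris-virginica"),
    (5.7, "Iris-setosa"),
]

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def gini_split(rows, threshold):
    """Size-weighted Gini impurity of the two branches of a threshold split."""
    left = [cls for x, cls in rows if x <= threshold]
    right = [cls for x, cls in rows if x > threshold]
    return sum(len(part) / len(rows) * gini(part)
               for part in (left, right) if part)

for t in (5, 6):
    print(f"split = {t}: Gini = {gini_split(ROWS, t)}")
```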