之前的文章已经解决了数据预处理的问题。从这里开始,就要开始创建决策树了。
首先可以使用之前用Java实现的ID3算法进行修改。 之前的算法是基于Weka自带的数据进行的,跟这里的格式不太兼容。基本上需要把String改成Double就好了~
现在先尝试手动的创建模型,保证待会我们写出来的代码确实是正确的。
关于决策树模型以及ID3算法,具体的概念以及思路就不在这里重复写了,可以参考《数据挖掘导论》相关章节。
之前已经处理好的dataMatrix可以下载附件之中的train-matrix.csv. 然后直接使用Excel完成最简单的统计功能。
比如,第一步我需要统计Suvived之中1跟0的个数:
即: Survived=0 有549条记录, Survived=1 有342条记录。
可以使用如下代码计算熵:[代码来源:http://commons.apache.org/proper/commons-math/jacoco/org.apache.commons.math3.stat.inference/GTest.java.html 根据里面的entropy进行修改]
public static double entropy(final int[] k) { double h = 0d; double sum_k = 0d; for (int i = 0; i < k.length; i++) { sum_k += (double) k[i]; } for (int i = 0; i < k.length; i++) { if (k[i] != 0) { final double p_i = (double) k[i] / sum_k; h += p_i * FastMath.log(p_i); } } return -h; }
对于Survived, Entropy=0.9607
接下来就应该逐个的计算各个属性对应的熵以及对应的信息增益(Info Gain) 了
以PClass为例:
0.9607-(80+136)/891. * 0.9509 - (97+87)/891. * 0.9978 - (372+119)/891.*0.7989因此,Pclass属性的信息增益为
=0.0229
一次类推,计算Sex,Age,SibSp,Embarked对应的信息增益,结果如下:
Pclass:0.0838310452960116
Sex:0.2176601066606142
Age:0.010620040421108423
SibSp:0.022557964533659103
Embarked:0.024047090707960517
最终,选择Sex作为根节点。
我们看看Sex的数据情况吧,
不得不说一句:女性的存活几率要比男性大得多啊!
接下来,计算第二层。 我们先计算Sex=1(male) 的情况
此时的Entropy=entropy(468,109)=0.6992 Sum = 468+109=577
详细情况:
0.6992 - 0.9567 * (45 + 77) / 577.0 - 0.628 * (91+17) / 577.0 - 0.5722 * (300 + 47) / 577.0
= 0.0352
后面的就不手动进行了~~~
[不得不再吐槽一下,Pclass=1的时候,生存的几率真的是非常非常大啊!不知道是不是当时的有钱人离救生艇比较近?]
具体的ID3分类器,可以参看我写的代码:https://gitcafe.com/rangerwolf/Kaggle-Titanic/blob/master/src/main/java/classifier/ID3Classifier.java。
运行test.MyID3.java即可得到结果。
将整个树json的格式输出出来:
{ "attribute": "Sex", "options": { "2.0": { "attribute": "Pclass", "options": { "3.0": { "attribute": "SibSp", "options": { "0.0": { "attribute": "Age", "options": { }, "subLeafs": { } }, "1.0": { "attribute": "Age", "options": { }, "subLeafs": { "3.0": { "count": 5, "outputValue": 0.0, "option": 3.0 } } } }, "subLeafs": { } }, "2.0": { "attribute": "Age", "options": { "2.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "2.0": { "count": 3, "outputValue": 1.0, "option": 2.0 } } } }, "subLeafs": { "1.0": { "count": 10, "outputValue": 1.0, "option": 1.0 } } }, "1.0": { "attribute": "Age", "options": { "3.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "1.0": { "count": 12, "outputValue": 1.0, "option": 1.0 } } }, "2.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "0.0": { "count": 34, "outputValue": 1.0, "option": 0.0 }, "2.0": { "count": 4, "outputValue": 1.0, "option": 2.0 } } } }, "subLeafs": { } } }, "subLeafs": { } }, "1.0": { "attribute": "Pclass", "options": { "3.0": { "attribute": "Age", "options": { "3.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "1.0": { "count": 2, "outputValue": 0.0, "option": 1.0 } } }, "2.0": { "attribute": "SibSp", "options": { }, "subLeafs": { } }, "1.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "1.0": { "count": 5, "outputValue": 1.0, "option": 1.0 } } } }, "subLeafs": { } }, "2.0": { "attribute": "Age", "options": { "2.0": { "attribute": "SibSp", "options": { }, "subLeafs": { "2.0": { "count": 4, "outputValue": 0.0, "option": 2.0 } } } }, "subLeafs": { "1.0": { "count": 9, "outputValue": 1.0, "option": 1.0 } } }, "1.0": { "attribute": "Age", "options": { "3.0": { "attribute": "SibSp", "options": { }, "subLeafs": { } }, "2.0": { "attribute": "SibSp", "options": { }, "subLeafs": { } } }, "subLeafs": { "1.0": { "count": 3, "outputValue": 1.0, "option": 1.0 } } } }, "subLeafs": { } } }, "subLeafs": { } }
用GUI的方式来显示json,部分结果如下:
可以看到,大致已经有了雏形。而且可以验证的就是,至少我们的根节点是正确的。
下面是老外的成果图:(是基于Python做出来的,不过没太看懂里面的结果,感觉只有一条边有label说明~)
下一篇文章,如果不出意外,将介绍一下Dot Language的应用。
后面的树状图,将会使用Dot Language以及相应的软件来进行展示。
PS:明天又要开始上班了~ 哎,可以用来学习的时间要少得多了...