决策树分类算法:ID3 & C4.5 & CART

分类的概念

分类的基本任务就是根据给定的一系列属性集,最后去判别它属于的类型!

比如我们现在需要去给动物分类,类别可选项为哺乳类,爬行类,鸟类,鱼类,或者两栖类。给你一些属性集如这个动物的体温,是否胎生,是否为水生动物,是否为飞行动物,是否有腿,是否冬眠。

现在分类的基本任务就是,已知一个动物的属性集,判断或预测这个动物属于哪一种类别?

决策树分类法

简述

从根节点开始,每个分支都会包含一个属性测试条件,用于分开具有不同特性的记录,最终到达叶节点,即可得到类标号。

具体过程

从根节点开始,从众多的属性集里边选择一个属性,由这个属性把数据进行分类(该属性的一个值则形成一个孩子节点),得到这个根节点的多个孩子节点。
再由这些孩子节点开始选择剩余的属性来进行分类,递归的进行下去,直至所有属性都已经使用完毕!

问题

(1). 如何确定选择哪个属性来作为测试条件?
某个分类的熵值定义为:
这里写图片描述
所以对于一个属性来说,分类后的熵值越低说明数据的纯度越高,这个正是我们想要得到的结果,故使用这个指标来判断属性的优先选择权。
(2). 如何终止递归?避免过度拟合?
数据中可能会出现一些离群点,这会造成决策树在进行决策的过程中对这样的数据非常敏感,所以我们可以使用一个阈值来终止递归(即当前的节点下数据标号的纯度已经满足某个阈值)。

关键代码

private void buildDecisionTree(AttrNode node, String parentAttrValue, String[][] remainData, ArrayList remainAttr, boolean isID3) {
        node.setParentAttrValue(parentAttrValue);

        String attrName = "";
        double gainValue = 0;
        double tempValue = 0;

        if(remainAttr.size() == 1) {
            System.out.println("attr null");
            return ;
        }
        // 在所有剩余属性集里选择一个信息增益最大的属性
        for(int i = 0;i < remainAttr.size();i ++) {
            if(isID3) {
                // ID3算法计算信息增益
                tempValue = computeGain(remainData, remainAttr.get(i));
            } else {
                // C4.5算法计算信息增益比
                tempValue = computeGainRatio(remainData, remainAttr.get(i));
            }

            if(tempValue > gainValue) {
                gainValue = tempValue;
                // 找到最佳的属性
                attrName = remainAttr.get(i);
            }
        }

        node.setAttrName(attrName);
        // 得到这个属性下的所有取值 去进一步拓展孩子节点
        ArrayList valueTypes = attrValue.get(attrName);
        // 移除掉这个已经使用了的属性
        remainAttr.remove(attrName);

        AttrNode[] childNode = new AttrNode[valueTypes.size()];
        String[][] rData;
        // 遍历这个属性的所有取值
        for(int i = 0;i < valueTypes.size();i ++) {
            // 把该种取值下的数据提取出来
            rData = removeData(remainData, attrName, valueTypes.get(i));

            childNode[i] = new AttrNode();
            boolean sameClass = true;
            ArrayList indexArray = new ArrayList<>();
            // 遍历剩余的数据
            for(int k = 1;k < rData.length;k ++) {
                indexArray.add(rData[k][0]);
                if (!rData[k][attrNames.length - 1].equals(rData[1][attrNames.length - 1])) {
                    sameClass = false;
                    break;
                }
            }

            if(!sameClass) {
                buildDecisionTree(childNode[i], valueTypes.get(i), rData, remainAttr, isID3);
            } else {
                // 如果数据中标号全部相同(或者是达到了某个阈值)停止递归
                childNode[i].setParentAttrValue(valueTypes.get(i));
                childNode[i].setChildDataIndex(indexArray);
            }
        }
        // 递归完成后,给头结点设定孩子节点
        node.setChildAttrNode(childNode);
    }

总结

决策树分类算法是属于监督学习的算法,也就是他需要初始的数据来进行训练,去得到一个经过训练的模型。然后这个模型就可以用来根据属性集预测标号。它的不足在于它无法进行增量计算,也就是当新增一些已知的数据集的时候,只有重新结合之前的数据来重新构建决策树,而无法仅仅利用增量来构建强化。但是这类算法的思路非常简单,理解起来也不难。

引申

CART算法(Classification And Regression Tree):也是一种决策树分类算法,与之前的C4.5和ID3不同的是:
1. 每个非叶子节点都有两个孩子节点,这也就意味着划分条件仅为等于和不等于某个值,来对数据进行划分空间。
2. CART算法对于属性的值采用的是基于Gini系数值的方式做比较,举一个网上的一个例子:(划分条件为体温是否恒温)
比如体温为恒温时包含哺乳类5个、鸟类2个,则:

Gini(left_child)=1(57)2(27)2=2049

体温为非恒温时包含爬行类3个、鱼类3个、两栖类2个,则
Gini(right_child)=1(38)2(38)2(28)2=4264

所以如果按照“体温为恒温和非恒温”进行划分的话,我们得到 GINI的增益(类比信息增益):
Gini(A)=7152049+8154264

最好的划分就是使得GINI_Gain最小的划分
通过比较每个属性的最小的gini指数值,作为最后的结果。
3. CART算法在把数据进行分类之后,会对树进行一个剪枝,常用的用前剪枝和后剪枝法,而常见的后剪枝发包括代价复杂度剪枝,悲观误差剪枝等等.代价复杂度剪枝的公式为:
r=R(t)R(Tt)NTt1

其中 R(t) 表示如果对节点进行剪枝的话,最终的误差代价 = 该节点的误差率 * 该节点数据数目所占比例, R(Tt) 表示如果没有进行剪枝的话,这颗子树所有的叶子节点的误差代价之和, NTt 表示该子树叶子节点的个数。

scikit-learn使用

from sklearn import tree
# 有一些可选择参数 可以查看文档
clf = tree.DecisionTreeClassifier()

clf.fit(features_train, lables_train)

pre = clf.prediction(features_test)

你可能感兴趣的:(经典算法)