C4.5决策树(Java实现)

说明

此前已经上传了ID3决策树的Java实现,C4.5整体架构与之相差不大。
可参考:http://blog.csdn.net/xiaohukun/article/details/78041676
此次将结点的实现由Dom4J改为自定义类实现,更加自由和轻便。

代码已打包并上传

代码

数据仍采用ARFF格式
train.arff

@relation weather.symbolic 
@attribute outlook {sunny,overcast,rainy} 
@attribute temperature {hot,mild,cool} 
@attribute humidity {high,normal} 
@attribute windy {TRUE,FALSE} 
@attribute play {yes,no} 

@data 
sunny,hot,high,FALSE,no 
sunny,hot,high,TRUE,no 
overcast,hot,high,FALSE,yes 
rainy,mild,high,FALSE,yes 
rainy,cool,normal,FALSE,yes 
rainy,cool,normal,TRUE,no 
overcast,cool,normal,TRUE,yes 
sunny,mild,high,FALSE,no 
sunny,cool,normal,FALSE,yes 
rainy,mild,normal,FALSE,yes 
sunny,mild,normal,TRUE,yes 
overcast,mild,high,TRUE,yes 
overcast,hot,normal,FALSE,yes 
rainy,mild,high,TRUE,no

C4.5主类(DecisionTree)

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.lang.Math.*;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;



public class DecisionTree {

    private ArrayList train_AttributeName = new ArrayList(); // 存储训练集属性的名称
    private ArrayList> train_attributeValue = new ArrayList>(); // 存储训练集每个属性的取值
    private ArrayList trainData = new ArrayList(); // 训练集数据 ,即arff文件中的data字符串

    public static final String patternString = "@attribute(.*)[{](.*?)[}]";
    //正则表达,其中*? 表示重复任意次,但尽可能少重复,防止匹配到更后面的"}"符号

    private int decatt; // 决策变量在属性集中的索引(即类标所在列)
    private InfoGain infoGain;
    private TreeNode root;


    public void train(String data_path, String targetAttr){
        //模型初始化操作
        read_trainARFF(new File(data_path));
        //printData();
        setDec(targetAttr);
        infoGain=new InfoGain(trainData, decatt);

        //拼装行与列
        LinkedList ll=new LinkedList(); //LinkList用于增删比ArrayList有优势
        for(int i = 0; i< train_AttributeName.size(); i++){
            if(i!=decatt) ll.add(i);  //防止类别变量不在最后一列发生错误
        }
        ArrayList al=new ArrayList();
        for(int i=0;i//构建决策树
        root = buildDT("root", "null", al, ll);
        //剪枝
        cutBranch(root);
    }

    /**
     * 构建决策树
     * @param fatherName 节点名称
     * @param fatherValue 节点值
     * @param subset 数据行子集
     * @param subset 数据列子集
     * @return 返回根节点
     */
    public TreeNode buildDT(String fatherName, String fatherValue, ArrayList subset,LinkedList selatt){
        TreeNode node=new TreeNode();
        Map targetNum = infoGain.get_AttributeNum(subset,decatt);//计算类-频率
        String targetValue=infoGain.get_targetValue(targetNum);//判定分类
        node.setTargetNum(targetNum);
        node.setAttributeName(fatherName);
        node.setAttributeValue(fatherValue);
        node.setTargetValue(targetValue);

        //终止条件为类标单一/树深度达到特征长度(还有可能是信息增益率不存在)
        if (infoGain.isPure(targetNum) | selatt.isEmpty() ) {
            node.setNodeType("leafNode");
            return node;
        }
        int maxIndex = infoGain.getGainRatioMax(subset,selatt);
        selatt.remove(new Integer(maxIndex));  //这样可以remove object
        String childName = train_AttributeName.get(maxIndex);

        Map> childSubset = infoGain.get_AttributeSubset(subset, maxIndex);
        ArrayList childNode = new ArrayList();
        for (String childValue : childSubset.keySet()){
            TreeNode child = buildDT(childName, childValue, childSubset.get(childValue), selatt);
            child.setFatherTreeNode(node);  //顺序很重要:回溯
            childNode.add(child);
        }
        node.setChildTreeNode(childNode);
        return  node;
    }

    /**
     * 剪枝函数
     * @param node 判断结点
     * @return 剪枝之后的叶子结点集
     */
    public ArrayList<int[]> cutBranch(TreeNode node){
        ArrayList<int[]> resultNum = new ArrayList<int[]>();
        if (node.getNodeType() =="leafNode"){
            int[] tempNum = get_leafNum(node);
            resultNum.add(tempNum);
            return resultNum;
        }else{
            int sumNum = 0;
            double oldRatio = 0;
            for (TreeNode child : node.getChildTreeNode()){
                for(int[] leafNum : cutBranch(child)){
                    resultNum.add(leafNum);
                    oldRatio += 0.5 + leafNum[0];
                    sumNum += leafNum[1];
                }
            }
            double oldNum =oldRatio;
            oldRatio /= sumNum;
            double sd = Math.sqrt(sumNum*oldRatio*(1-oldRatio));
            int temLeaf[] = get_leafNum(node);
            double newNum = temLeaf[0] + 0.5;
            if(newNum < oldNum + sd){//符合剪枝条件,剪枝并返回本身
                node.setChildTreeNode(null);
                node.setNodeType("leafNode");
                resultNum.clear();
                resultNum.add(temLeaf);
            }//不符合剪枝条件,返回叶子结点
            return resultNum;
        }
    }

    //获得叶子结点的数目
    public int[] get_leafNum(TreeNode node){
        int[] resultNum= new int[2];
        Map targetNum = node.getTargetNum();
        int minNum = Integer.MAX_VALUE;
        int sumNum = 0;
        for(int num : targetNum.values()){
            minNum = Integer.min(minNum, num);
            sumNum += num;
        }
        if (targetNum.size() == 1) minNum = 0;
        resultNum[0] = minNum;
        resultNum[1] = sumNum;
        return  resultNum;
    }

    /**
     * 读取arff文件,给attribute、attributevalue、data赋值
     * @param file  传入的文件
     */
    public void read_trainARFF(File file) {
        try {
            FileReader fr = new FileReader(file);
            BufferedReader br = new BufferedReader(fr);
            String line;
            Pattern pattern = Pattern.compile(patternString);
            while ((line = br.readLine()) != null) {
                Matcher matcher = pattern.matcher(line);
                if (matcher.find()) {
                    train_AttributeName.add(matcher.group(1).trim()); //获取第一个括号里的内容
                    //涉及取值,尽量加.trim(),后面也可以看到,即使是换行符也可能会造成字符串不相等
                    String[] values = matcher.group(2).split(",");
                    ArrayList al = new ArrayList(values.length);
                    for (String value : values) {
                        al.add(value.trim());
                    }
                    train_attributeValue.add(al);
                } else if (line.startsWith("@data")) {
                    while ((line = br.readLine()) != null) {
                        if(line=="")
                            continue;
                        String[] row = line.split(",");
                        trainData.add(row);
                    }
                } else {
                    continue;
                }
            }
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * 打印Data
     */
    public void printData(){
        System.out.println("当前的ATTR为");
        for(String attr : train_AttributeName){
            System.out.print(attr+" ");
        }
        System.out.println();
        System.out.println("---------------------------------");
        System.out.println("当前的DATA为");
        for(String[] row: trainData){
            for (String value : row){
                System.out.print(value+" ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
    }

    //将决策树存储到xml文件中
    public void write_DecisionTree(String filename) {
        try {
            File file = new File(filename);
            if (!file.exists())
                file.createNewFile();
            FileOutputStream fs = new FileOutputStream(filename);
            BufferedOutputStream bos = new BufferedOutputStream(fs);
            write_Node(bos, root, "");
            bos.flush();
            bos.close();
            fs.close();
        }catch (IOException e){
            e.printStackTrace();
        }
    }

    private void write_Node(BufferedOutputStream bos, TreeNode node, String block){
        String outputWords1 = block + "<" + node.getAttributeName()+ " value=\"" + node.getAttributeValue() + "\"";
        String outputWords2;
        Map targetNum = node.getTargetNum();
        for (String value : targetNum.keySet()){
            outputWords1 += " " + value + ":" + targetNum.get(value);
        }
        outputWords1 += ">";
        if(node.getNodeType()=="leafNode"){
            outputWords1 += node.getTargetValue();
            outputWords2 = " + node.getAttributeName() + ">" + "\n";
        }else{
            outputWords1 += "\n";
            outputWords2 = block + " + node.getAttributeName() + ">" + "\n";
        }

        try {
            bos.write(outputWords1.getBytes());
        }catch (IOException e){
            e.printStackTrace();
        }
        ArrayList childNode=node.getChildTreeNode();
        if (childNode !=null){
            for (TreeNode child : childNode){
                write_Node(bos, child, block+"  ");
            }
        }

        try {
            bos.write(outputWords2.getBytes());
        }catch (IOException e){
            System.out.println(e.getMessage());
        }
    }

    //设置决策变量
    public void setDec(int n) {
        if (n < 0 || n >= train_AttributeName.size()) {
            System.err.println("决策变量指定错误。");
            System.exit(2);
        }
        decatt = n;
    }
    public void setDec(String targetAttr) {
        int n = train_AttributeName.indexOf(targetAttr);
        setDec(n);
    }



    public static void main(String[] args) {
        DecisionTree dt=new DecisionTree();
        dt.train("files/train.arff", "play");
        dt.write_DecisionTree("files/Tree.xml");
    }

}

节点类

import java.util.ArrayList;
import java.util.Map;

/**
 * A node of the decision tree.
 *
 * <p>A plain mutable holder. In this project DecisionTree fills it while building the tree and
 * rewrites its children and node type while pruning.
 */
public class TreeNode {

    private String nodeType;                   // DecisionTree stores "leafNode" on leaves; otherwise null
    private String attributeName;              // attribute the parent split on ("root" for the root)
    private String attributeValue;             // value of that attribute leading here ("null" for the root)
    private ArrayList<TreeNode> childTreeNode; // child nodes; null when none were set
    private TreeNode fatherTreeNode;           // parent node; null when none was set
    private Map<String, Integer> targetNum;    // class value -> number of training rows at this node
    private String targetValue;                // predicted (majority) class at this node

    /** Creates an empty node; all fields start as null. */
    public TreeNode() {
    }

    public String getNodeType() {
        return nodeType;
    }

    public void setNodeType(String nodeType) {
        this.nodeType = nodeType;
    }

    public String getAttributeName() {
        return attributeName;
    }

    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }

    public String getAttributeValue() {
        return attributeValue;
    }

    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }

    /** Returns the child list as stored (not a copy); may be null. */
    public ArrayList<TreeNode> getChildTreeNode() {
        return childTreeNode;
    }

    /** Stores {@code childTreeNode} by reference; pass null to turn the node into a leaf. */
    public void setChildTreeNode(ArrayList<TreeNode> childTreeNode) {
        this.childTreeNode = childTreeNode;
    }

    public TreeNode getFatherTreeNode() {
        return fatherTreeNode;
    }

    public void setFatherTreeNode(TreeNode fatherTreeNode) {
        this.fatherTreeNode = fatherTreeNode;
    }

    /** Returns the class-count map as stored (not a copy). */
    public Map<String, Integer> getTargetNum() {
        return targetNum;
    }

    public void setTargetNum(Map<String, Integer> targetNum) {
        this.targetNum = targetNum;
    }

    public String getTargetValue() {
        return targetValue;
    }

    public void setTargetValue(String targetValue) {
        this.targetValue = targetValue;
    }
}

信息熵相关类

import java.util.*;


/**
 * Entropy, information-gain and gain-ratio helpers over a fixed training set.
 *
 * <p>Data subsets are passed around as lists of row indices into {@code trainData}; attributes are
 * identified by their column index.
 */
public class InfoGain {
    /** Split information below this value is treated as zero (the attribute is constant). */
    private static final double SPLIT_INFO_EPSILON = 1e-12;

    private ArrayList<String[]> trainData;
    private int decatt; // column index of the class attribute

    /**
     * @param trainData training rows, one String[] of attribute values per row
     * @param decatt    column index of the class attribute
     */
    public InfoGain(ArrayList<String[]> trainData, int decatt) {
        this.trainData = trainData;
        this.decatt = decatt;
    }

    /**
     * Computes the Shannon entropy (in bits) of a value-to-count distribution, using
     * H = (N*log2(N) - sum(n_i*log2(n_i))) / N.
     *
     * @param attributeNum value -> count map
     * @return the entropy, or 0 for an empty distribution
     */
    public double getEntropy(Map<String, Integer> attributeNum) {
        double entropy = 0.0;
        int sum = 0;
        for (int num : attributeNum.values()) {
            sum += num;
            if (num > 0) { // 0 * log(0) is taken as 0
                entropy += (-1) * num * Math.log(num) / Math.log(2);
            }
        }
        if (sum == 0) {
            return 0.0; // empty distribution: avoid 0/0
        }
        entropy += sum * Math.log(sum) / Math.log(2);
        entropy /= sum;
        return entropy;
    }

    /**
     * Computes the entropy of column {@code attributeIndex} over the rows in {@code subset}.
     */
    public double getEntropy(ArrayList<Integer> subset, int attributeIndex) {
        return getEntropy(get_AttributeNum(subset, attributeIndex));
    }

    /**
     * Finds the candidate attribute with the highest gain ratio (information gain / split info).
     *
     * @param subset rows to evaluate
     * @param selatt candidate attribute columns
     * @return the best attribute column, or {@code decatt} when no candidate has a positive gain
     *         ratio (the caller should then stop splitting)
     */
    public int getGainRatioMax(ArrayList<Integer> subset, LinkedList<Integer> selatt) {
        double oldEntropy = getEntropy(get_AttributeNum(subset, decatt));
        double maxGainRatio = 0;
        int maxIndex = decatt;

        for (int attributeIndex : selatt) {
            double splitInfo = getEntropy(subset, attributeIndex);
            if (splitInfo < SPLIT_INFO_EPSILON) {
                // The attribute takes a single value here. Its gain ratio is 0/0: rounding in the
                // gain could otherwise yield +/-Infinity or NaN, so this attribute is skipped.
                continue;
            }

            Map<String, ArrayList<Integer>> attributeSubset = get_AttributeSubset(subset, attributeIndex);
            int sum = 0;
            double newEntropy = 0;
            for (ArrayList<Integer> part : attributeSubset.values()) {
                int num = part.size();
                sum += num;
                newEntropy += num * getEntropy(part, decatt);
            }
            newEntropy /= sum; // weighted class entropy after the split

            double gainRatio = (oldEntropy - newEntropy) / splitInfo;
            if (gainRatio > maxGainRatio) {
                maxGainRatio = gainRatio;
                maxIndex = attributeIndex;
            }
        }
        return maxIndex;
    }

    /**
     * @param targetNum class -> count map
     * @return true if at most one class is present
     */
    public boolean isPure(Map<String, Integer> targetNum) {
        return targetNum.size() <= 1;
    }

    /**
     * Counts how often each value of column {@code attributeIndex} occurs in {@code subset}.
     *
     * @param subset         row indices
     * @param attributeIndex column index
     * @return value -> frequency map
     */
    public Map<String, Integer> get_AttributeNum(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, Integer> attributeNum = new HashMap<String, Integer>();
        for (int rowIndex : subset) {
            attributeNum.merge(trainData.get(rowIndex)[attributeIndex], 1, Integer::sum);
        }
        return attributeNum;
    }

    /**
     * Partitions {@code subset} by the value of column {@code attributeIndex}.
     *
     * @param subset         row indices
     * @param attributeIndex column index
     * @return value -> row indices having that value, each list in the original row order
     */
    public Map<String, ArrayList<Integer>> get_AttributeSubset(ArrayList<Integer> subset, int attributeIndex) {
        Map<String, ArrayList<Integer>> attributeSubset = new HashMap<String, ArrayList<Integer>>();
        for (int rowIndex : subset) {
            String value = trainData.get(rowIndex)[attributeIndex];
            attributeSubset.computeIfAbsent(value, k -> new ArrayList<Integer>()).add(rowIndex);
        }
        return attributeSubset;
    }

    /**
     * Returns the majority class from a class -> count map.
     * On ties, the first maximum in the map's iteration order wins. An empty map yields "".
     */
    public String get_targetValue(Map<String, Integer> targetNum) {
        int maxNum = 0;
        String targetValue = "";
        for (Map.Entry<String, Integer> entry : targetNum.entrySet()) {
            int count = entry.getValue();
            if (count > maxNum) {
                maxNum = count;
                targetValue = entry.getKey();
            }
        }
        return targetValue;
    }
}

感受

决策树是比较基本的分类算法,但在编写代码的过程中,我对递归的运用和代码实现有了更进一步的认识。
在C4.5中有两块工作比较重要和复杂,其一,自然是生成决策树;其二,便是实现剪枝。
这二者都是通过递归实现的,并且都包含自上而下(top-down)和自下而上(bottom-up)两个阶段:前者在自上而下的过程中完成主要工作(选择划分特征并生成子结点),回溯只是为了返回生成的结点;而后者自上而下只是为了找到各个叶子结点,真正的剪枝工作是在回溯(自下而上)的过程中完成的。

问题

此次的代码中并没有实现对连续特征的处理以及缺失值的处理。
后者的处理方式因具体情况差异较大;前者利用目前已有的函数应该可以比较方便地实现,这里就不再展开了。如果读者希望实现完整的功能,可以自行补充。

你可能感兴趣的:(算法实现,C4.5决策树)