基于 Java 简单实现决策树 ID3 算法

前言

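最近在重新学习机器学习里的一些经典算法，希望能温故而知新。看到决策树时，发现它的实现还是蛮简单的，就用 Java 写了一个小例子。因为最近比较忙，具体的公式和例子请参照下图。简单来说，ID3 以信息增益作为属性选择标准：经验熵 $H(D)=-\sum_{k}\frac{|C_k|}{|D|}\log_2\frac{|C_k|}{|D|}$，条件熵 $H(D\mid A)=\sum_{i}\frac{|D_i|}{|D|}H(D_i)$，信息增益 $g(D,A)=H(D)-H(D\mid A)$，其中 $C_k$ 是第 $k$ 类样本的集合，$D_i$ 是属性 $A$ 取第 $i$ 个值的样本子集；每次选择信息增益最大的属性进行划分。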
图 1：决策树 ID3 算法的公式与示例

代码

import java.lang.Math;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Minimal ID3 decision-tree demo: computes the information gain of every attribute of a small
 * hard-coded dataset whose class labels are "见" / "不见" (meet / don't meet).
 */
public class DesionTree {
    public static void main(String[] args) {
        String[][] data = Dataprocess();
        String[] R = Result();
        String[] a_num = attribute_num();
        String[][] attribute = attribute();
        double[] G = ID3(data, R, a_num, attribute);
        // Information gain of attribute 1 (looks).
        System.out.println(G[1]);
    }

    /** Raw samples; each row holds the values of the four attributes, in column order. */
    public static String[][] Dataprocess() {
        String[] A = {"老", "帅", "高", "不会"};
        String[] B = {"年轻", "一般", "低", "会"};
        String[] C = {"年轻", "丑", "高", "不会"};
        String[] D = {"年轻", "一般", "高", "会"};
        String[] E = {"年轻", "一般", "低", "不会"};
        String[][] Data = {A, B, C, D, E};
        return Data;
    }

    /** Class label of each sample; {@code Result()[i]} belongs to {@code Dataprocess()[i]}. */
    public static String[] Result() {
        String[] R = {"不见", "见", "不见", "见", "不见"};
        return R;
    }

    /** Flat list of every attribute value occurring in the dataset. */
    public static String[] attribute_num() {
        String[] a_num = {"老", "年轻", "帅", "一般", "丑", "高", "低", "不会", "会"};
        return a_num;
    }

    /**
     * Possible values of each attribute, indexed like the columns of {@link #Dataprocess()}.
     * "丑" belongs to the looks attribute (column 1), not to height (column 2).
     */
    public static String[][] attribute() {
        String[][] attribute = {{"老", "年轻"}, {"帅", "一般", "丑"}, {"高", "低"}, {"不会", "会"}};
        return attribute;
    }

    /**
     * ID3 attribute selection: computes the information gain of every attribute (column).
     *
     * <p>Gain(D, A) = H(D) - sum over values v of A of |D_v| / |D| * H(D_v), where H is the
     * base-2 Shannon entropy of the class labels and D_v is the subset of samples whose
     * attribute A equals v. Any number of distinct class labels is supported, and strings are
     * compared by value.
     *
     * @param Data      samples; {@code Data[row][col]} is the value of attribute {@code col}
     * @param R         class labels; {@code R[row]} is the label of {@code Data[row]}
     * @param a_num     flat list of all attribute values. It is kept for backward compatibility
     *                  and is not needed, because the values are read directly from {@code Data}.
     * @param attribute possible values of each attribute; its length decides how many leading
     *                  columns are evaluated and the length of the returned array
     * @return {@code G[col]} is the information gain of attribute {@code col}. All entries are
     *         0 when {@code Data} is empty.
     * @throws IllegalArgumentException if {@code Data} and {@code R} differ in length, or a
     *         sample has fewer than {@code attribute.length} columns
     */
    public static double[] ID3(String[][] Data, String[] R, String[] a_num, String[][] attribute) {
        if (Data.length != R.length) {
            throw new IllegalArgumentException(
                    "Data has " + Data.length + " rows but R has " + R.length + " labels");
        }
        int n = Data.length;
        double[] G = new double[attribute.length];
        if (n == 0) {
            return G;
        }
        double HD = entropy(Arrays.asList(R));
        for (int col = 0; col < attribute.length; col++) {
            // Partition the labels by this attribute's value: value -> labels of matching rows.
            Map<String, List<String>> groups = new HashMap<>();
            for (int row = 0; row < n; row++) {
                if (Data[row].length <= col) {
                    throw new IllegalArgumentException(
                            "row " + row + " has no value for attribute " + col);
                }
                String value = Data[row][col];
                List<String> labels = groups.get(value);
                if (labels == null) {
                    labels = new ArrayList<>();
                    groups.put(value, labels);
                }
                labels.add(R[row]);
            }
            // Conditional entropy H(D|A): entropy of each subset, weighted by its share of D.
            double conditional = 0;
            for (List<String> labels : groups.values()) {
                conditional += (double) labels.size() / n * entropy(labels);
            }
            G[col] = HD - conditional;
        }
        return G;
    }

    /** Returns the base-2 Shannon entropy of the label distribution (0 for an empty list). */
    private static double entropy(List<String> labels) {
        Map<String, Integer> counts = new HashMap<>();
        for (String label : labels) {
            Integer c = counts.get(label);
            counts.put(label, c == null ? 1 : c + 1);
        }
        double h = 0;
        for (int c : counts.values()) {
            // c >= 1 for every key, so p > 0 and log(p) is finite.
            double p = (double) c / labels.size();
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

}

代码写得有点糙，有兴趣的同学可以私信交流。

你可能感兴趣的:(java,决策树,算法)