转自 http://blog.csdn.net/luowen3405/article/details/6250731
决策树是以实例为基础的归纳学习算法。它从一组无次序、无规则的元组中推理出决策树表示形式的分类规则。它采用自顶向下的递归方式,在决策树的内部结点进行属性值的比较,并根据不同的属性值从该结点向下分支,叶结点是要学习划分的类。从根到叶结点的一条路径就对应着一条合取规则,整个决策树就对应着一组析取表达式规则。
1986年Quinlan提出了著名的ID3算法。在ID3算法的基础上,1993年Quinlan又提出了C4.5算法。为了适应处理大规模数据集的需要,后来又提出了若干改进的算法,其中SLIQ (super-vised learning in quest)和SPRINT (scalable parallelizableinduction of decision trees)是比较有代表性的两个算法。
(1) ID3算法
ID3算法的核心是:在决策树各级结点上选择属性时,用信息增益(information gain)作为属性的选择标准,以使得在每一个非叶结点进行测试时,能获得关于被测试记录最大的类别信息。其具体方法是:检测所有的属性,选择信息增益最大的属性产生决策树结点,由该属性的不同取值建立分支,再对各分支的子集递归调用该方法建立决策树结点的分支,直到所有子集仅包含同一类别的数据为止。最后得到一棵决策树,它可以用来对新的样本进行分类。
ID3算法的优点是:
算法的理论清晰,方法简单,学习能力较强。其缺点是:只对比较小的数据集有效,且对噪声比较敏感,当训练数据集加大时,决策树可能会随之改变。
(2) C4.5算法
C4.5算法继承了ID3算法的优点,并在以下几方面对ID3算法进行了改进:
1) 用信息增益率来选择属性,克服了用信息增益选择属性时偏向选择取值多的属性的不足;
2) 在树构造过程中进行剪枝;
3) 能够完成对连续属性的离散化处理;
4) 能够对不完整数据进行处理。
C4.5算法与其它分类算法如统计方法、神经网络等比较起来有如下优点:产生的分类规则易于理解,准确率较高。其缺点是:在构造树的过程中,需要对数据集进行多次的顺序扫描和排序,因而导致算法的低效。此外,C4.5只适合于能够驻留于内存的数据集,当训练集大得无法在内存容纳时程序无法运行。
(3) SLIQ算法
SLIQ算法对C4.5决策树分类算法的实现方法进行了改进,在决策树的构造过程中采用了“预排序”和“广度优先策略”两种技术。
1) 预排序。对于连续属性在每个内部结点寻找其最优分裂标准时,都需要对训练集按照该属性的取值进行排序,而排序是很浪费时间的操作。为此,SLIQ算法采用了预排序技术。所谓预排序,就是针对每个属性的取值,把所有的记录按照从小到大的顺序进行排序,以消除在决策树的每个结点对数据集进行的排序。具体实现时,需要为训练数据集的每个属性创建一个属性列表,为类别属性创建一个类别列表。
2) 广度优先策略。在C4.5算法中,树的构造是按照深度优先策略完成的,需要对每个属性列表在每个结点处都进行一遍扫描,费时很多,为此,SLIQ采用广度优先策略构造决策树,即在决策树的每一层只需对每个属性列表扫描一次,就可以为当前决策树中每个叶子结点找到最优分裂标准。
SLIQ算法由于采用了上述两种技术,使得该算法能够处理比C4.5大得多的训练集,在一定范围内具有良好的随记录个数和属性个数增长的可伸缩性。
然而它仍然存在如下缺点:
1)由于需要将类别列表存放于内存,而类别列表的元组数与训练集的元组数是相同的,这就一定程度上限制了可以处理的数据集的大小。
2) 由于采用了预排序技术,而排序算法的复杂度本身并不是与记录个数成线性关系,因此,使得SLIQ算法不可能达到随记录数目增长的线性可伸缩性。
(4) SPRINT算法
为了减少驻留于内存的数据量,SPRINT算法进一步改进了决策树算法的数据结构,去掉了在SLIQ中需要驻留于内存的类别列表,将它的类别列合并到每个属性列表中。这样,在遍历每个属性列表寻找当前结点的最优分裂标准时,不必参照其他信息,将对结点的分裂表现在对属性列表的分裂,即将每个属性列表分成两个,分别存放属于各个结点的记录。
SPRINT算法的优点是在寻找每个结点的最优分裂标准时变得更简单。其缺点是对非分裂属性的属性列表进行分裂变得很困难。解决的办法是对分裂属性进行分裂时用哈希表记录下每个记录属于哪个孩子结点,若内存能够容纳下整个哈希表,其他属性列表的分裂只需参照该哈希表即可。由于哈希表的大小与训练集的大小成正比,当训练集很大时,哈希表可能无法在内存容纳,此时分裂只能分批执行,这使得SPRINT算法的可伸缩性仍然不是很好。
本人对ID3的算法实现做了如下假设与处理:
1. 假设所有的属性值域都是分类型或名词离散型的
2.求信息增益时,log函数本来应以2为底,但是为了方便起见,直接调用了java.util.Math类中的以e为底的log函数,无论以什么为底均不会对影响结果产生影响
3.最后的输出并没有以树结构的形式给出,但是可以根据输出结果分析出决策树的结构
java实现代码如下
决策树结点类 class TreeNode
package DecisionTree; import java.util.ArrayList; /** * 决策树结点类 * @author mgq * @data 2012.01.09 */ public class TreeNode { private String name; //节点名(分裂属性的名称) private ArrayList<String> rule; //结点的分裂规则 ArrayList<TreeNode> child; //子结点集合 private ArrayList<ArrayList<String>> datas; //划分到该结点的训练元组 private ArrayList<String> candAttr; //划分到该结点的候选属性 public TreeNode() { this.name = ""; this.rule = new ArrayList<String>(); this.child = new ArrayList<TreeNode>(); this.datas = null; this.candAttr = null; } public ArrayList<TreeNode> getChild() { return child; } public void setChild(ArrayList<TreeNode> child) { this.child = child; } public ArrayList<String> getRule() { return rule; } public void setRule(ArrayList<String> rule) { this.rule = rule; } public String getName() { return name; } public void setName(String name) { this.name = name; } public ArrayList<ArrayList<String>> getDatas() { return datas; } public ArrayList<String> getCandAttr() { return candAttr; } public void setCandAttr(ArrayList<String> candAttr) { this.candAttr = candAttr; } public void setDatas(ArrayList<ArrayList<String>> datas2) { // TODO Auto-generated method stub this.datas = datas2; } }
package DecisionTree; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; /** * 决策树构造类 * @author mgq * @data 2012.01.09 */ public class DecisionTree { private Integer attrSelMode; //最佳分裂属性选择模式,1表示以信息增益度量,2表示以信息增益率度量。暂未实现2 public DecisionTree(){ this.attrSelMode = 1; } public DecisionTree(int attrSelMode) { this.attrSelMode = attrSelMode; } public void setAttrSelMode(Integer attrSelMode) { this.attrSelMode = attrSelMode; } /** * 获取指定数据集中的类别及其计数 * @param datas 指定的数据集 * @return 类别及其计数的map */ public Map<String, Integer> classOfDatas(ArrayList<ArrayList<String>> datas){ Map<String, Integer> classes = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(tuple.size() - 1); if (classes.containsKey(c)) { classes.put(c, classes.get(c) + 1); } else { classes.put(c, 1); } } return classes; } /** * 获取具有最大计数的类名,即求多数类 * @param classes 类的键值集合 * @return 多数类的类名 */ public String maxClass(Map<String, Integer> classes){ String maxC = ""; int max = -1; Iterator iter = classes.entrySet().iterator(); for(int i = 0; iter.hasNext(); i++) { Map.Entry entry = (Map.Entry) iter.next(); String key = (String)entry.getKey(); Integer val = (Integer) entry.getValue(); if(val > max){ max = val; maxC = key; } } return maxC; } /** * 构造决策树 * @param datas 训练元组集合 * @param attrList 候选属性集合 * @return 决策树根结点 */ public TreeNode buildTree(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList){ System.out.print("候选属性列表: "); for (int i = 0; i < attrList.size(); i++) { System.out.print(" " + attrList.get(i) + " "); } System.out.println(); TreeNode node = new TreeNode(); node.setDatas(datas); node.setCandAttr(attrList); Map<String, Integer> classes = classOfDatas(datas); String maxC = maxClass(classes); if (classes.size() == 1 || attrList.size() == 0) { node.setName(maxC); return node; } Gain gain = new Gain(datas, attrList); int bestAttrIndex = gain.bestGainAttrIndex(); ArrayList<String> rules = gain.getValues(datas, bestAttrIndex); node.setRule(rules); node.setName(attrList.get(bestAttrIndex)); if(rules.size() > 2){ //?此处有待商榷 attrList.remove(bestAttrIndex); } for (int i = 0; i < rules.size(); i++) { String rule = rules.get(i); ArrayList<ArrayList<String>> di = gain.datasOfValue(bestAttrIndex, rule); for (int j = 0; j < di.size(); j++) { di.get(j).remove(bestAttrIndex); } if (di.size() == 0) { TreeNode leafNode = new TreeNode(); leafNode.setName(maxC); leafNode.setDatas(di); leafNode.setCandAttr(attrList); node.getChild().add(leafNode); } else { TreeNode newNode = buildTree(di, attrList); node.getChild().add(newNode); } } return node; } }
package DecisionTree; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.Map; /** * 选择最佳分裂属性 * @author mgq * @data 2012.01.09 */ public class Gain { private ArrayList<ArrayList<String>> D = null; //训练元组 private ArrayList<String> attrList = null; //候选属性集 public Gain(ArrayList<ArrayList<String>> datas, ArrayList<String> attrList) { this.D = datas; this.attrList = attrList; } /** * 获取最佳侯选属性列上的值域(假定所有属性列上的值都是有限的名词或分类类型的) * @param attrIndex 指定的属性列的索引 * @return 值域集合 */ public ArrayList<String> getValues(ArrayList<ArrayList<String>> datas, int attrIndex){ ArrayList<String> values = new ArrayList<String>(); String r = ""; for (int i = 0; i < datas.size(); i++) { r = datas.get(i).get(attrIndex); if (!values.contains(r)) { values.add(r); } } return values; } /** * 获取指定数据集中指定属性列索引的域值及其计数 * @param d 指定的数据集 * @param attrIndex 指定的属性列索引 * @return 类别及其计数的map */ public Map<String, Integer> valueCounts(ArrayList<ArrayList<String>> datas, int attrIndex){ Map<String, Integer> valueCount = new HashMap<String, Integer>(); String c = ""; ArrayList<String> tuple = null; for (int i = 0; i < datas.size(); i++) { tuple = datas.get(i); c = tuple.get(attrIndex); if (valueCount.containsKey(c)) { valueCount.put(c, valueCount.get(c) + 1); } else { valueCount.put(c, 1); } } return valueCount; } /** * 求对datas中元组分类所需的期望信息,即datas的熵 * @param datas 训练元组 * @return datas的熵值 */ public double infoD(ArrayList<ArrayList<String>> datas){ double info = 0.000; int total = datas.size(); Map<String, Integer> classes = valueCounts(datas, attrList.size()); Iterator iter = classes.entrySet().iterator(); Integer[] counts = new Integer[classes.size()]; for(int i = 0; iter.hasNext(); i++) { Map.Entry entry = (Map.Entry) iter.next(); Integer val = (Integer) entry.getValue(); counts[i] = val; } for (int i = 0; i < counts.length; i++) { double base = DecimalCalculate.div(counts[i], total, 3); info += (-1) * base * Math.log(base); } return info; } /** * 获取指定属性列上指定值域的所有元组 * @param attrIndex 指定属性列索引 * @param value 指定属性列的值域 * @return 指定属性列上指定值域的所有元组 */ public ArrayList<ArrayList<String>> datasOfValue(int attrIndex, String value){ ArrayList<ArrayList<String>> Di = new ArrayList<ArrayList<String>>(); ArrayList<String> t = null; for (int i = 0; i < D.size(); i++) { t = D.get(i); if(t.get(attrIndex).equals(value)){ Di.add(t); } } return Di; } /** * 基于按指定属性划分对D的元组分类所需要的期望信息 * @param attrIndex 指定属性的索引 * @return 按指定属性划分的期望信息值 */ public double infoAttr(int attrIndex){ double info = 0.000; ArrayList<String> values = getValues(D, attrIndex); for (int i = 0; i < values.size(); i++) { ArrayList<ArrayList<String>> dv = datasOfValue(attrIndex, values.get(i)); info += DecimalCalculate.mul(DecimalCalculate.div(dv.size(), D.size(), 3), infoD(dv)); } return info; } /** * 获取最佳分裂属性的索引 * @return 最佳分裂属性的索引 */ public int bestGainAttrIndex(){ int index = -1; double gain = 0.000; double tempGain = 0.000; for (int i = 0; i < attrList.size(); i++) { tempGain = infoD(D) - infoAttr(i); if (tempGain > gain) { gain = tempGain; index = i; } } return index; } }决策树算法测试类 class TestDecisionTree
package DecisionTree; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.StringTokenizer; /** * 决策树算法测试类 * @author mgq * @data 2012.01.09 */ public class TestDecisionTree { /** * 读取候选属性 * @return 候选属性集合 * @throws IOException */ public ArrayList<String> readCandAttr() throws IOException{ ArrayList<String> candAttr = new ArrayList<String>(); BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); String str = ""; while (!(str = reader.readLine()).equals("")) { StringTokenizer tokenizer = new StringTokenizer(str); while (tokenizer.hasMoreTokens()) { candAttr.add(tokenizer.nextToken()); } } return candAttr; } /** * 读取训练元组 * @return 训练元组集合 * @throws IOException */ public ArrayList<ArrayList<String>> readData() throws IOException { ArrayList<ArrayList<String>> datas = new ArrayList<ArrayList<String>>(); BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); String str = ""; while (!(str = reader.readLine()).equals("")) { StringTokenizer tokenizer = new StringTokenizer(str); ArrayList<String> s = new ArrayList<String>(); while (tokenizer.hasMoreTokens()) { s.add(tokenizer.nextToken()); } datas.add(s); } return datas; } /** * 递归打印树结构 * @param root 当前待输出信息的结点 */ public void printTree(TreeNode root){ System.out.println("name:" + root.getName()); ArrayList<String> rules = root.getRule(); System.out.print("node rules: {"); for (int i = 0; i < rules.size(); i++) { System.out.print(rules.get(i) + " "); } System.out.print("}"); System.out.println(""); ArrayList<TreeNode> children = root.getChild(); int size =children.size(); if (size == 0) { System.out.println("-->leaf node!<--"); } else { System.out.println("size of children:" + children.size()); for (int i = 0; i < children.size(); i++) { System.out.print("child " + (i + 1) + " of node " + root.getName() + ": "); printTree(children.get(i)); } } } /** * 主函数,程序入口 * @param args */ public static void main(String[] args) { TestDecisionTree tdt = new TestDecisionTree(); ArrayList<String> candAttr = null; ArrayList<ArrayList<String>> datas = null; try { System.out.println("请输入候选属性"); candAttr = tdt.readCandAttr(); //System.out.println("候选属性是:"+candAttr); System.out.println("请输入训练数据"); datas = tdt.readData(); } catch (IOException e) { e.printStackTrace(); } DecisionTree tree = new DecisionTree(); TreeNode root = tree.buildTree(datas, candAttr); tdt.printTree(root); } }
package DecisionTree; import java.math.BigDecimal; /* * @author mgq * @data 2012.01.09 */ public class DecimalCalculate { /** * 由于Java的简单类型不能够精确的对浮点数进行运算,这个工具类提供精 * 确的浮点数运算,包括加减乘除和四舍五入。 */ //默认除法运算精度 private static final int DEF_DIV_SCALE = 10; //这个类不能实例化 private DecimalCalculate(){ } /** * 提供精确的加法运算。 * @param v1 被加数 * @param v2 加数 * @return 两个参数的和 */ public static double add(double v1,double v2){ BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.add(b2).doubleValue(); } /** * 提供精确的减法运算。 * @param v1 被减数 * @param v2 减数 * @return 两个参数的差 */ public static double sub(double v1,double v2){ BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.subtract(b2).doubleValue(); } /** * 提供精确的乘法运算。 * @param v1 被乘数 * @param v2 乘数 * @return 两个参数的积 */ public static double mul(double v1,double v2){ BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.multiply(b2).doubleValue(); } /** * 提供(相对)精确的除法运算,当发生除不尽的情况时,精确到 * 小数点以后10位,以后的数字四舍五入。 * @param v1 被除数 * @param v2 除数 * @return 两个参数的商 */ public static double div(double v1,double v2){ return div(v1,v2,DEF_DIV_SCALE); } /** * 提供(相对)精确的除法运算。当发生除不尽的情况时,由scale参数指 * 定精度,以后的数字四舍五入。 * @param v1 被除数 * @param v2 除数 * @param scale 表示表示需要精确到小数点以后几位。 * @return 两个参数的商 */ public static double div(double v1,double v2,int scale){ if(scale<0){ throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b1 = new BigDecimal(Double.toString(v1)); BigDecimal b2 = new BigDecimal(Double.toString(v2)); return b1.divide(b2,scale,BigDecimal.ROUND_HALF_UP).doubleValue(); } /** * 提供精确的小数位四舍五入处理。 * @param v 需要四舍五入的数字 * @param scale 小数点后保留几位 * @return 四舍五入后的结果 */ public static double round(double v,int scale){ if(scale<0){ throw new IllegalArgumentException( "The scale must be a positive integer or zero"); } BigDecimal b = new BigDecimal(Double.toString(v)); BigDecimal one = new BigDecimal("1"); return b.divide(one,scale,BigDecimal.ROUND_HALF_UP).doubleValue(); } /** * 提供精确的类型转换(Float) * @param v 需要被转换的数字 * @return 返回转换结果 */ public static float convertsToFloat(double v){ BigDecimal b = new BigDecimal(v); return b.floatValue(); } /** * 提供精确的类型转换(Int)不进行四舍五入 * @param v 需要被转换的数字 * @return 返回转换结果 */ public static int convertsToInt(double v){ BigDecimal b = new BigDecimal(v); return b.intValue(); } /** * 提供精确的类型转换(Long) * @param v 需要被转换的数字 * @return 返回转换结果 */ public static long convertsToLong(double v){ BigDecimal b = new BigDecimal(v); return b.longValue(); } /** * 返回两个数中大的一个值 * @param v1 需要被对比的第一个数 * @param v2 需要被对比的第二个数 * @return 返回两个数中大的一个值 */ public static double returnMax(double v1,double v2){ BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.max(b2).doubleValue(); } /** * 返回两个数中小的一个值 * @param v1 需要被对比的第一个数 * @param v2 需要被对比的第二个数 * @return 返回两个数中小的一个值 */ public static double returnMin(double v1,double v2){ BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.min(b2).doubleValue(); } /** * 精确对比两个数字 * @param v1 需要被对比的第一个数 * @param v2 需要被对比的第二个数 * @return 如果两个数一样则返回0,如果第一个数比第二个数大则返回1,反之返回-1 */ public static int compareTo(double v1,double v2){ BigDecimal b1 = new BigDecimal(v1); BigDecimal b2 = new BigDecimal(v2); return b1.compareTo(b2); } }
package DecisionTree; /* * @author mgq * @data 2012.01.09 * */ public class testDecimalCalc { public static void main(String[] args) { double info = 0.000; double base = DecimalCalculate.div(5, 14, 3); System.out.println(base); System.out.println(Math.log(base)); info += (-1) * base * Math.log(base); System.out.println(info); } }
测试数据:
程序输出结果:
根据输出结果画出的决策树,如下图所示:
出处:http://blog.csdn.net/luowen3405/archive/2011/03/15/6250731.aspx
http://blog.csdn.net/luowen3405/article/details/6249373