决策树id3算法 java_决策树ID3算法的java实现

决策树的分类过程和人的决策过程比较相似,就是先挑“权重”最大的那个考虑,然后再往下细分。比如你去看医生,症状是流鼻涕,咳嗽等,那么医生就会根据你的流鼻涕这个权重最大的症状先认为你是感冒,接着再根据你咳嗽等症状细分你是否为病毒性感冒等等。决策树的过程其实也是基于极大似然估计。那么我们用一个什么标准来衡量某个特征是权重最大的呢,这里有信息增益和基尼系数两个。ID3算法采用的是信息增益这个量。

根据《统计学习方法》中的描述,G(D,A)表示数据集D在特征A的划分下的信息增益。具体公式:

G(D,A)=H(D)-H(D|A)。其中H(D)表示数据集D的熵,熵可以用来描述其混乱度,计算公式为

H(D)=

99e545aaa8f01f5ea9240f2832037a5b.png可见对于数据集D而言,|Dk|表示类别为k的数量,其类别越多,越混乱。

而H(D|A)表示数据集D在A的划分下的的不确定性。他们的差也即是信息增益,表示由于特征A使得数据集D的分类的不确定减少的差,所以这个值越大说明A的分类对D越有效,也就是权重越大。

H(D|A)=

6ca3e39598e1e795fd9e857fd8f6ba26.png|Dik|表示在特征A中value为i的划分下数据集类别为k的数量。

有了这两个公式,接下来就可以写代码了。这里为了清晰的表示这个结果,采用了xml来输出。由于刚开始学java所以只能即学即用(java和python简直不能比,python写ID3一百行代码妥妥的搞定,java用了将近300行。。。)

算法步骤:

输入:数据集D,特征集A(这里也可以输入一个阀值,如果信息增益小于该阀值就直接作为叶节点,这样可以避免过拟合)

输出:xml文件

1 如果D中的类别是同一类,则作为叶节点,标记为该类Ck

2 如果特征集A中没有特征了,那么作为叶节点,并且用数据集D中类别最多的类作为类标记

3 对D的各个特征求最大信息增益,选择信息增益最大的特征Ag

4 对特征Ag中各个值ai继续对数据集进行分割为Di

5 以Di为数据集,A-Ag为特征集为输入进行1-4步骤

具体代码:

1 importjava.io.BufferedReader;2 importjava.io.FileInputStream;3 importjava.io.FileWriter;4 importjava.io.IOException;5 importjava.io.InputStreamReader;6 importjava.util.ArrayList;7 importjava.util.HashMap;8 importjava.util.HashSet;9 importjava.util.Map;10 importjava.util.Set;11

12 importorg.dom4j.Document;13 importorg.dom4j.DocumentHelper;14 importorg.dom4j.Element;15 importorg.dom4j.io.XMLWriter;16

17

18

19

20

21

22 classUtils{23 //用于从文件中获取数据集

24 public static ArrayList> loadDataSet(String file) throwsIOException{25 ArrayList> dataSet=new ArrayList>();26 FileInputStream fis=newFileInputStream(file);27 InputStreamReader isr=new InputStreamReader(fis,"UTF-8");28 BufferedReader br=newBufferedReader(isr);29 String line="";30 line=br.readLine();31 while((line=br.readLine())!=null){32 String[] words=line.split(",");33 ArrayList data=new ArrayList();34 for(int i=0;i

45 public static ArrayList loadFeature(String file) throwsIOException{46 FileInputStream fis=newFileInputStream(file);47 InputStreamReader isr=new InputStreamReader(fis,"UTF-8");48 BufferedReader br=newBufferedReader(isr);49

50 String[] line=br.readLine().split(",");51 ArrayList feature=new ArrayList();52 for(int i=0;i

61 public static ArrayList getClassList(ArrayList>dataSet){62 ArrayList classList=new ArrayList();63 int length=dataSet.get(0).size();64 for(ArrayListdata:dataSet){65 String label=data.get(length-1);66 classList.add(label);67 }68 returnclassList;69 }70 //返回数据集中的特征数

71 public static int featureNum(ArrayList>dataList){72 int len=dataList.get(0).size()-1;73 returnlen;74 }75

76

77 //public static void writeToXML(String fileName) throws IOException{78 //Document document = DocumentHelper.createDocument();79 //Element root = document.addElement("DecisionTree");80 //Element outlook=root.addElement("outlook");81 //outlook.addAttribute("value","sunny");82 //Element humidity1=outlook.addElement("humidity");83 //humidity1.addAttribute("value","high");84 //humidity1.addText("no");85 //Element humidity2=outlook.addElement("humidity");86 //humidity2.addAttribute("value","normal");87 //humidity2.addText("yes");88 //

89 //XMLWriter writer=new XMLWriter(new FileWriter(fileName));90 //writer.write(document);91 //writer.close();92 //}93 //用于获得数据集中第index列的map映射,方便后续的遍历value和计算熵

94 public static Map getSubMap(ArrayList> dataSet,intindex){95 int total=dataSet.size();96 Map subMap=newHashMap();97 for(ArrayListdata:dataSet){98 String lable=data.get(index);99 if(subMap.get(lable)==null){100 subMap.put(lable,1);101 }else{102 subMap.put(lable,subMap.get(lable)+1);103 }104 }105 returnsubMap;106 }107 //打印map,用于debug的时候

108 public static void showMap(Mapmap){109 for(Map.Entryentry:map.entrySet()){110 System.out.println(entry.getKey()+":"+entry.getValue());111 }112 }113 //求熵

114 public static double getEntropy(ArrayList> dataSet,intindex){115 int total=dataSet.size();116 Map subMap=getSubMap(dataSet,index);117 double entropy=0;118 for(Map.Entryentry:subMap.entrySet()){119 double temp=entry.getValue()*1.0/total;120 entropy+=temp*(Math.log(temp)/Math.log(2));121 }122 return -entropy;123 }124 //求信息增益最大的分割点

125 public static String bestFeatureSplit(ArrayList> dataSet,ArrayListfeatureList){126 int length=dataSet.get(0).size();127 double totalEntropy=getEntropy(dataSet,length-1);128

129

130

131 int featureNum=dataSet.get(0).size()-1;132 int index=-1;133 double maxInfoGain=-1;134 for(int i=0;i map=getSubMap(dataSet,i);//获得该特征下的map

137 ArrayList lableList=new ArrayList();138 double entropySum=0;139

140 for(Map.Entry entry:map.entrySet()){//这里的Di就是map中的特征的value值

141 Map subMap=newHashMap();142

143

144 for(ArrayListdata:dataSet){145 if(data.get(i).compareTo(entry.getKey())==0){146 if(subMap.get(data.get(length-1))==null){147

148 subMap.put(data.get(length-1),1);149 }else{150 subMap.put(data.get(length-1),subMap.get(data.get(length-1))+1);151 }152 }153 }154 double x=0;155 for(Map.EntrysubEntry:subMap.entrySet()){156 double temp=subEntry.getValue()*1.0/entry.getValue();157 x+=temp*(Math.log(temp)/Math.log(2));158 }159

160 entropySum+=-x*(entry.getValue())/dataSet.size();161 }162 entropySum=totalEntropy-entropySum;163 if(entropySum>maxInfoGain){164 index=i;165 maxInfoGain=entropySum;166 }167 }168 returnfeatureList.get(index);169 }170 //分割数据集,index为特征的下标

171 public static ArrayList> splitDataSet(ArrayList> dataSet,intindex,String value){172 ArrayList> subDataSet=new ArrayList>();173 for(ArrayListdata:dataSet){174 if(data.get(index).compareTo(value)==0){175 ArrayList temp=new ArrayList();176 for(int i=0;i

187 public static Map arrayToMap(ArrayListlist){188 Map map=newHashMap();189 for(String word:list){190 if(map.get(word)==null){191 map.put(word,1);192 }else{193 map.put(word,map.get(word)+1);194 }195 }196 returnmap;197 }198 //求label中某个数量最多的类别

199 public static String major(ArrayListlabelList){200 Map map=arrayToMap(labelList);201 int max=0;202 String label="";203 for(Map.Entryentry:map.entrySet()){204 if(entry.getValue()>max){205 label=entry.getKey();206 }207 }208 returnlabel;209 }210

211 public static Set getValueFromDataSet(ArrayList> dataSet,intindex){212 ArrayList values=new ArrayList();213 for(ArrayListdata:dataSet){214 try{215 values.add(data.get(index));216 }catch(Exception e){217

218 System.out.println("index is "+index);219 }220 }221 Set set=newHashSet();222 for(String value:values){223 set.add(value);224 }225 returnset;226 }227

228 public static ArrayList copyArrayList(ArrayListsrc){229 ArrayList dest=new ArrayList();230 for(String s:src){231 dest.add(s);232 }233 returndest;234 }235

236

237 public static void showArrayList(ArrayList>dataSet){238 for(ArrayListdata:dataSet){239 System.out.println(data);240 }241 }242

243 }244

245

246 public classDecisionTree {247

248

249 public static int createTree(ArrayList> dataSet,ArrayListfeatureList,Element e){250 ArrayList labelList=Utils.getClassList(dataSet);//获取数据集中label的列表

251 if(Utils.arrayToMap(labelList).size()==1){//表示label中只有一种类别,所以此时不需要再分类了

252 e.addText(labelList.get(0));253 return 1;254 }255 if(dataSet.get(0).size()==1){//表示此时已经没有特征了,所以也不需要再继续了,此时以label中最多的类别为该节点的类别

256 e.addText(Utils.major(labelList));257 return 1;258 }259

260 ArrayList subFeatureList=Utils.copyArrayList(featureList);261

262

263

264 String feature=Utils.bestFeatureSplit(dataSet,featureList);265 subFeatureList.remove(feature);266 int index=featureList.indexOf(feature);267

268 Set valueSet=Utils.getValueFromDataSet(dataSet,index);269 //Element next=e.addElement(feature);//原来的代码位置

270 for(String value:valueSet){271 Element next=e.addElement(feature);//后来放到这里之后,xml的输出就正确了,原因在于每递归一次就需要创建一个element,所以应该在for内创建。

272 next.addAttribute("value",value);273 ArrayList> subDataSet=Utils.splitDataSet(dataSet,index,value);274 createTree(subDataSet,subFeatureList,next);275 }276 return 1;277 }278

279 public static void main(String[] args) throwsIOException {280 //TODO Auto-generated method stub

281 String file="C:/Users/Administrator/Desktop/upload/DT.txt";282 String xml="C:/Users/Administrator/Desktop/upload/DT1.xml";283 ArrayList> dataSet=Utils.loadDataSet(file);284 ArrayList featureList=Utils.loadFeature(file);285 Document document =DocumentHelper.createDocument();286 Element root = document.addElement("DecisionTree");287 createTree(dataSet,featureList,root);288 XMLWriter writer=new XMLWriter(newFileWriter(xml));289 writer.write(document);290 writer.close();291 System.out.println("finished");292 }293

294 }

这次除了算法上的理解更加深刻了外,在java上也学到了些关于xml解析,读写等方法。

另外对递归的使用也更加形象些,对于递归一个容易错的点就是函数上的参数,一定要认真对待,要清楚该参数该在什么时候初始化,什么时候被用到。我一开始在第269行上就出现错误了,一开始没有考虑清楚这个next该在什么时候分配,后来发现每次创建节点的时候我们在xml就要创建一个相应的节点用来描述他,所以应该是在for循环里面创建,如果在for外面创建就表示,该特征下的所有值都只有一个element。

当然对于set,map的遍历啥的也更加清晰了。

你可能感兴趣的:(决策树id3算法,java)