归纳决策树ID3(Java实现)

ID3算法本身这里就不介绍了。最终的决策树保存在XML文件中，XML操作使用Dom4J；注意，如果要让Dom4J支持按XPath选择节点，还需要引入jaxen.jar包。程序要求输入文件符合ARFF格式，并且所有属性都是标称（nominal）变量。

package dt;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.dom4j.Document;
import org.dom4j.DocumentHelper;
import org.dom4j.Element;
import org.dom4j.io.OutputFormat;
import org.dom4j.io.XMLWriter;
020
021 public class ID3 {
022     private ArrayList<String> attribute = new ArrayList<String>(); // 存储属性的名称
023     private ArrayList<ArrayList<String>> attributevalue = new ArrayList<ArrayList<String>>(); // 存储每个属性的取值
024     private ArrayList<String[]> data = new ArrayList<String[]>();; // 原始数据
025     int decatt; // 决策变量在属性集中的索引
026     public static final String patternString = "@attribute(.*)[{](.*?)[}]";
027
028     Document xmldoc;
029     Element root;
030
031     public ID3() {
032         xmldoc = DocumentHelper.createDocument();
033         root = xmldoc.addElement("root");
034         root.addElement("DecisionTree").addAttribute("value", "null");
035     }
036
037     public static void main(String[] args) {
038         ID3 inst = new ID3();
039         inst.readARFF(new File("/home/orisun/test/weather.nominal.arff"));
040         inst.setDec("play");
041         LinkedList<Integer> ll=new LinkedList<Integer>();
042         for(int i=0;i<inst.attribute.size();i++){
043             if(i!=inst.decatt)
044                 ll.add(i);
045         }
046         ArrayList<Integer> al=new ArrayList<Integer>();
047         for(int i=0;i<inst.data.size();i++){
048             al.add(i);
049         }
050         inst.buildDT("DecisionTree", "null", al, ll);
051         inst.writeXML("/home/orisun/test/dt.xml");
052         return;
053     }
054
055     //读取arff文件,给attribute、attributevalue、data赋值
056     public void readARFF(File file) {
057         try {
058             FileReader fr = new FileReader(file);
059             BufferedReader br = new BufferedReader(fr);
060             String line;
061             Pattern pattern = Pattern.compile(patternString);
062             while ((line = br.readLine()) != null) {
063                 Matcher matcher = pattern.matcher(line);
064                 if (matcher.find()) {
065                     attribute.add(matcher.group(1).trim());
066                     String[] values = matcher.group(2).split(",");
067                     ArrayList<String> al = new ArrayList<String>(values.length);
068                     for (String value : values) {
069                         al.add(value.trim());
070                     }
071                     attributevalue.add(al);
072                 } else if (line.startsWith("@data")) {
073                     while ((line = br.readLine()) != null) {
074                         if(line=="")
075                             continue;
076                         String[] row = line.split(",");
077                         data.add(row);
078                     }
079                 } else {
080                     continue;
081                 }
082             }
083             br.close();
084         } catch (IOException e1) {
085             e1.printStackTrace();
086         }
087     }
088
089     //设置决策变量
090     public void setDec(int n) {
091         if (n < 0 || n >= attribute.size()) {
092             System.err.println("决策变量指定错误。");
093             System.exit(2);
094         }
095         decatt = n;
096     }
097     public void setDec(String name) {
098         int n = attribute.indexOf(name);
099         setDec(n);
100     }
101
102     //给一个样本(数组中是各种情况的计数),计算它的熵
103     public double getEntropy(int[] arr) {
104         double entropy = 0.0;
105         int sum = 0;
106         for (int i = 0; i < arr.length; i++) {
107             entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
108             sum += arr[i];
109         }
110         entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
111         entropy /= sum;
112         return entropy;
113     }
114
115     //给一个样本数组及样本的算术和,计算它的熵
116     public double getEntropy(int[] arr, int sum) {
117         double entropy = 0.0;
118         for (int i = 0; i < arr.length; i++) {
119             entropy -= arr[i] * Math.log(arr[i]+Double.MIN_VALUE)/Math.log(2);
120         }
121         entropy += sum * Math.log(sum+Double.MIN_VALUE)/Math.log(2);
122         entropy /= sum;
123         return entropy;
124     }
125
126     public boolean infoPure(ArrayList<Integer> subset) {
127         String value = data.get(subset.get(0))[decatt];
128         for (int i = 1; i < subset.size(); i++) {
129             String next=data.get(subset.get(i))[decatt];
130             //equals表示对象内容相同,==表示两个对象指向的是同一片内存
131             if (!value.equals(next))
132                 return false;
133         }
134         return true;
135     }
136
137     // 给定原始数据的子集(subset中存储行号),当以第index个属性为节点时计算它的信息熵
138     public double calNodeEntropy(ArrayList<Integer> subset, int index) {
139         int sum = subset.size();
140         double entropy = 0.0;
141         int[][] info = new int[attributevalue.get(index).size()][];
142         for (int i = 0; i < info.length; i++)
143             info[i] = new int[attributevalue.get(decatt).size()];
144         int[] count = new int[attributevalue.get(index).size()];
145         for (int i = 0; i < sum; i++) {
146             int n = subset.get(i);
147             String nodevalue = data.get(n)[index];
148             int nodeind = attributevalue.get(index).indexOf(nodevalue);
149             count[nodeind]++;
150             String decvalue = data.get(n)[decatt];
151             int decind = attributevalue.get(decatt).indexOf(decvalue);
152             info[nodeind][decind]++;
153         }
154         for (int i = 0; i < info.length; i++) {
155             entropy += getEntropy(info[i]) * count[i] / sum;
156         }
157         return entropy;
158     }
159
160     // 构建决策树
161     public void buildDT(String name, String value, ArrayList<Integer> subset,
162             LinkedList<Integer> selatt) {
163         Element ele = null;
164         @SuppressWarnings("unchecked")
165         List<Element> list = root.selectNodes("//"+name);
166         Iterator<Element> iter=list.iterator();
167         while(iter.hasNext()){
168             ele=iter.next();
169             if(ele.attributeValue("value").equals(value))
170                 break;
171         }
172         if (infoPure(subset)) {
173             ele.setText(data.get(subset.get(0))[decatt]);
174             return;
175         }
176         int minIndex = -1;
177         double minEntropy = Double.MAX_VALUE;
178         for (int i = 0; i < selatt.size(); i++) {
179             if (i == decatt)
180                 continue;
181             double entropy = calNodeEntropy(subset, selatt.get(i));
182             if (entropy < minEntropy) {
183                 minIndex = selatt.get(i);
184                 minEntropy = entropy;
185             }
186         }
187         String nodeName = attribute.get(minIndex);
188         selatt.remove(new Integer(minIndex));
189         ArrayList<String> attvalues = attributevalue.get(minIndex);
190         for (String val : attvalues) {
191             ele.addElement(nodeName).addAttribute("value", val);
192             ArrayList<Integer> al = new ArrayList<Integer>();
193             for (int i = 0; i < subset.size(); i++) {
194                 if (data.get(subset.get(i))[minIndex].equals(val)) {
195                     al.add(subset.get(i));
196                 }
197             }
198             buildDT(nodeName, val, al, selatt);
199         }
200     }
201
202     // 把xml写入文件
203     public void writeXML(String filename) {
204         try {
205             File file = new File(filename);
206             if (!file.exists())
207                 file.createNewFile();
208             FileWriter fw = new FileWriter(file);
209             OutputFormat format = OutputFormat.createPrettyPrint(); // 美化格式
210             XMLWriter output = new XMLWriter(fw, format);
211             output.write(xmldoc);
212             output.close();
213         } catch (IOException e) {
214             System.out.println(e.getMessage());
215         }
216     }
217 }

以weather.nominal.arff为输入、play为决策变量，最终生成的XML文件如下：

<?xml version="1.0" encoding="UTF-8"?>
<root>
  <DecisionTree value="null">
    <outlook value="sunny">
      <humidity value="high">no</humidity>
      <humidity value="normal">yes</humidity>
    </outlook>
    <outlook value="overcast">yes</outlook>
    <outlook value="rainy">
      <windy value="TRUE">no</windy>
      <windy value="FALSE">yes</windy>
    </outlook>
  </DecisionTree>
</root>

你可能感兴趣的:(归纳决策树ID3(Java实现))