packageID3Tree;import java.util.*;public classUtilID3 {
TreeNode root;private boolean[] flag;//训练集
privateObject[] trainArrays;//节点索引
private intnodeIndex;public static voidmain(String[] args)
{//初始化训练集数组
Object[] arrays = newObject[]{new String[]{"是","是","正常","否"},new String[]{"是","是","高","是"},new String[]{"是","是","很高","是"},new String[]{"否","是","正常","否"},new String[]{"否","否","高","否"},new String[]{"否","是","很高","是"},new String[]{"是","否","高","是"}};
UtilID3 ID3Tree= newUtilID3();
ID3Tree.create(arrays,3);
}//创建
public void create(Object[] arrays, intindex)
{this.trainArrays =arrays;
initial(arrays, index);
createDTree(arrays);
printDTree(root);
}//初始化
public void initial(Object[] dataArray, intindex)
{this.nodeIndex =index;//数据初始化
this.flag = new boolean[((String[])dataArray[0]).length];for (int i = 0; i
{if (i ==index)
{this.flag[i] = true;
}else{this.flag[i] = false;
}
}
}//创建决策树
public voidcreateDTree(Object[] arrays)
{
Object[] ob=getMaxGain(arrays);if (this.root == null)
{this.root = newTreeNode();
root.parent= null;
root.parentAttribute= null;
root.attributes= getAttributes(((Integer)ob[1]).intValue());
root.nodeName= getNodeName(((Integer)ob[1]).intValue());
root.childNodes= newTreeNode[root.attributes.length];
insert(arrays, root);
}
}//插入决策树
public voidinsert(Object[] arrays, TreeNode parentNode)
{
String[] attributes=parentNode.attributes;for (int i = 0; i < attributes.length; i++)
{
Object[] Arrays=pickUpAndCreateArray(arrays, attributes[i],getNodeIndex(parentNode.nodeName));
Object[] info=getMaxGain(Arrays);double gain = ((Double)info[0]).doubleValue();if (gain != 0)
{int index = ((Integer)info[1]).intValue();
TreeNode currentNode= newTreeNode();
currentNode.parent=parentNode;
currentNode.parentAttribute=attributes[i];
currentNode.attributes=getAttributes(index);
currentNode.nodeName=getNodeName(index);
currentNode.childNodes= newTreeNode[currentNode.attributes.length];
parentNode.childNodes[i]=currentNode;
insert(Arrays, currentNode);
}else{
TreeNode leafNode= newTreeNode();
leafNode.parent=parentNode;
leafNode.parentAttribute=attributes[i];
leafNode.attributes= new String[0];
leafNode.nodeName=getLeafNodeName(Arrays);
leafNode.childNodes= new TreeNode[0];
parentNode.childNodes[i]=leafNode;
}
}
}//输出
public voidprintDTree(TreeNode node)
{
System.out.println(node.nodeName);
TreeNode[] childs=node.childNodes;for (int i = 0; i < childs.length; i++)
{if (childs[i] != null)
{
System.out.println("如果:"+childs[i].parentAttribute);
printDTree(childs[i]);
}
}
}//剪取数组
public Object[] pickUpAndCreateArray(Object[] arrays, String attribute, intindex)
{
List list = new ArrayList();for (int i = 0; i < arrays.length; i++)
{
String[] strs=(String[])arrays[i];if(strs[index].equals(attribute))
{
list.add(strs);
}
}returnlist.toArray();
}//取得节点名
public String getNodeName(intindex)
{
String[] strs= new String[]{"头痛","肌肉痛","体温","患流感"};for (int i = 0; i < strs.length; i++)
{if (i ==index)
{returnstrs[i];
}
}return null;
}//取得叶子节点名
publicString getLeafNodeName(Object[] arrays)
{if (arrays != null && arrays.length > 0)
{
String[] strs= (String[])arrays[0];returnstrs[nodeIndex];
}return null;
}//取得节点索引
public intgetNodeIndex(String name)
{
String[] strs= new String[]{"头痛","肌肉痛","体温","患流感"};for (int i = 0; i < strs.length; i++)
{if(name.equals(strs[i]))
{returni;
}
}return -1;
}//得到最大信息增益
publicObject[] getMaxGain(Object[] arrays)
{
Object[] result= new Object[2];double gain = 0;int index = -1;for (int i = 0; i
{if (!this.flag[i])
{double value =gain(arrays, i);if (gain
{
gain=value;
index=i;
}
}
}
result[0] =gain;
result[1] =index;if (index != -1)
{this.flag[index] = true;
}returnresult;
}//取得属性数组
public String[] getAttributes(intindex)
{
@SuppressWarnings("unchecked")
TreeSet set = new TreeSet(newComparisons());for (int i = 0; i
{
String[] strs= (String[])this.trainArrays[i];
set.add(strs[index]);
}
String[] result= newString[set.size()];returnset.toArray(result);
}//计算信息增益
public double gain(Object[] arrays, intindex)
{
String[] playBalls= getAttributes(this.nodeIndex);int[] counts = new int[playBalls.length];for (int i = 0; i
{
counts[i]= 0;
}for (int i = 0; i
{
String[] strs=(String[])arrays[i];for (int j = 0; j
{if (strs[this.nodeIndex].equals(playBalls[j]))
{
counts[j]++;
}
}
}double entropyS = 0;for (int i = 0;i
{
entropyS= entropyS +Entropy.getEntropy(counts[i], arrays.length);
}
String[] attributes=getAttributes(index);double total = 0;for (int i = 0; i
{
total= total +entropy(arrays, index, attributes[i], arrays.length);
}return entropyS -total;
}public double entropy(Object[] arrays, int index, String attribute, inttotals)
{
String[] playBalls= getAttributes(this.nodeIndex);int[] counts = new int[playBalls.length];for (int i = 0; i < counts.length; i++)
{
counts[i]= 0;
}for (int i = 0; i < arrays.length; i++)
{
String[] strs=(String[])arrays[i];if(strs[index].equals(attribute))
{for (int k = 0; k
{if (strs[this.nodeIndex].equals(playBalls[k]))
{
counts[k]++;
}
}
}
}int total = 0;double entropy = 0;for (int i = 0; i < counts.length; i++)
{
total= total +counts[i];
}for (int i = 0; i < counts.length; i++)
{
entropy= entropy +Entropy.getEntropy(counts[i], total);
}return Entropy.getShang(total, totals)*entropy;
}
}