决策树算法（Java）：决策树 ID3 算法的 Java 实现（基本适用于所有 ID3 场景）

package ID3Tree;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;

/**
 * Builds and prints an ID3 decision tree from a training set of {@code String[]} rows.
 *
 * <p>Each row holds one value per column. The class-label column is chosen when calling
 * {@link #create}; every other column is a candidate split attribute. Column names come from
 * {@link #ATTRIBUTE_NAMES}, which is specific to the bundled flu example.
 *
 * <p>NOTE(review): {@code Entropy.getEntropy(count, total)} is assumed to return the term
 * -p*log2(p) with p = count/total, and {@code Entropy.getShang(a, b)} the ratio a/b
 * ("商" = quotient). TODO confirm against the Entropy class, which is not visible here.
 */
public class UtilID3 {

    /** Column names of the example data set, indexed by column position. */
    private static final String[] ATTRIBUTE_NAMES = {"头痛", "肌肉痛", "体温", "患流感"};

    /** Gains at or below this are treated as zero, so floating-point noise does not force a split. */
    private static final double MIN_GAIN = 1e-10;

    /** Root of the most recently built tree. */
    TreeNode root;

    /**
     * {@code flag[i]} is true when column i may not be used for a split at the current point of
     * the recursion: it is the class-label column, or it was already used on the current path.
     */
    private boolean[] flag;

    /** Full training set; the possible values (branches) of each attribute are taken from it. */
    private Object[] trainArrays;

    /** Column index of the class label. */
    private int nodeIndex;

    /** Builds and prints the tree for the bundled flu example (label column 3, "患流感"). */
    public static void main(String[] args) {
        // Columns: 头痛 (headache), 肌肉痛 (muscle pain), 体温 (temperature), 患流感 (has flu).
        Object[] arrays = new Object[] {
            new String[] {"是", "是", "正常", "否"},
            new String[] {"是", "是", "高", "是"},
            new String[] {"是", "是", "很高", "是"},
            new String[] {"否", "是", "正常", "否"},
            new String[] {"否", "否", "高", "否"},
            new String[] {"否", "是", "很高", "是"},
            new String[] {"是", "否", "高", "是"}
        };
        UtilID3 id3Tree = new UtilID3();
        id3Tree.create(arrays, 3);
    }

    /**
     * Builds a decision tree from {@code arrays} and prints it to standard output.
     *
     * @param arrays training rows, each a {@code String[]} with one value per column
     * @param index column index of the class label
     * @throws IllegalArgumentException if {@code arrays} is null or empty, or {@code index} is
     *     not a valid column index
     */
    public void create(Object[] arrays, int index) {
        if (arrays == null || arrays.length == 0) {
            throw new IllegalArgumentException("training set must not be empty");
        }
        this.trainArrays = arrays;
        // createDTree only builds when root is null; drop any tree from a previous call.
        this.root = null;
        initial(arrays, index);
        createDTree(arrays);
        printDTree(root);
    }

    /**
     * Records the class-label column and resets the per-column "unavailable" flags.
     *
     * @param dataArray training rows; the first row determines the number of columns
     * @param index column index of the class label
     * @throws IllegalArgumentException if {@code index} is outside the row width
     */
    public void initial(Object[] dataArray, int index) {
        int columns = ((String[]) dataArray[0]).length;
        if (index < 0 || index >= columns) {
            throw new IllegalArgumentException("label index " + index + " out of range [0, " + columns + ")");
        }
        this.nodeIndex = index;
        // All columns start as available split candidates except the class label itself.
        this.flag = new boolean[columns];
        this.flag[index] = true;
    }

    /**
     * Builds the tree rooted at the attribute with the highest information gain.
     * Does nothing if a root already exists.
     *
     * @param arrays training rows
     */
    public void createDTree(Object[] arrays) {
        if (this.root != null) {
            return;
        }
        Object[] ob = getMaxGain(arrays);
        int index = ((Integer) ob[1]).intValue();
        if (index == -1) {
            // No attribute has positive gain (e.g. every row has the same label): single leaf.
            this.root = newLeaf(null, null, getLeafNodeName(arrays));
            return;
        }
        TreeNode node = new TreeNode();
        node.parent = null;
        node.parentAttribute = null;
        node.attributes = getAttributes(index);
        node.nodeName = getNodeName(index);
        node.childNodes = new TreeNode[node.attributes.length];
        this.root = node;
        insert(arrays, node);
    }

    /**
     * Recursively creates one child of {@code parentNode} per value of its split attribute.
     *
     * @param arrays rows that reached {@code parentNode}
     * @param parentNode internal node whose {@code attributes} lists its branch values
     */
    public void insert(Object[] arrays, TreeNode parentNode) {
        String[] attributes = parentNode.attributes;
        int splitColumn = getNodeIndex(parentNode.nodeName);
        for (int i = 0; i < attributes.length; i++) {
            // getMaxGain marks the attribute it picks as used. Snapshot the flags so that mark
            // only applies to this child's subtree, not to its siblings.
            boolean[] savedFlag = this.flag.clone();

            Object[] subset = pickUpAndCreateArray(arrays, attributes[i], splitColumn);
            if (subset.length == 0) {
                // No row reaching this node has this value: predict the parent's majority label.
                parentNode.childNodes[i] = newLeaf(parentNode, attributes[i], getLeafNodeName(arrays));
            } else {
                Object[] info = getMaxGain(subset);
                int index = ((Integer) info[1]).intValue();
                if (index == -1) {
                    parentNode.childNodes[i] = newLeaf(parentNode, attributes[i], getLeafNodeName(subset));
                } else {
                    TreeNode currentNode = new TreeNode();
                    currentNode.parent = parentNode;
                    currentNode.parentAttribute = attributes[i];
                    currentNode.attributes = getAttributes(index);
                    currentNode.nodeName = getNodeName(index);
                    currentNode.childNodes = new TreeNode[currentNode.attributes.length];
                    parentNode.childNodes[i] = currentNode;
                    insert(subset, currentNode);
                }
            }

            this.flag = savedFlag;
        }
    }

    /** Creates a leaf labelled {@code label}, reached from {@code parent} via {@code branchValue}. */
    private static TreeNode newLeaf(TreeNode parent, String branchValue, String label) {
        TreeNode leaf = new TreeNode();
        leaf.parent = parent;
        leaf.parentAttribute = branchValue;
        leaf.attributes = new String[0];
        leaf.nodeName = label;
        leaf.childNodes = new TreeNode[0];
        return leaf;
    }

    /**
     * Prints the subtree rooted at {@code node} depth-first: the node name, then for each child
     * a "如果:" (if) line with the branch value followed by that child's subtree.
     *
     * @param node subtree root; must not be null
     */
    public void printDTree(TreeNode node) {
        System.out.println(node.nodeName);
        for (TreeNode child : node.childNodes) {
            if (child != null) {
                System.out.println("如果:" + child.parentAttribute);
                printDTree(child);
            }
        }
    }

    /**
     * Returns the rows whose value in column {@code index} equals {@code attribute}.
     *
     * @param arrays rows to filter
     * @param attribute value to match
     * @param index column to compare
     * @return matching rows (possibly empty), in their original order
     */
    public Object[] pickUpAndCreateArray(Object[] arrays, String attribute, int index) {
        List<Object> rows = new ArrayList<Object>();
        for (Object row : arrays) {
            String[] strs = (String[]) row;
            if (strs[index].equals(attribute)) {
                rows.add(strs);
            }
        }
        return rows.toArray();
    }

    /**
     * Returns the column name for {@code index}, or null if it is out of range.
     *
     * @param index column index
     * @return column name, or null
     */
    public String getNodeName(int index) {
        if (index < 0 || index >= ATTRIBUTE_NAMES.length) {
            return null;
        }
        return ATTRIBUTE_NAMES[index];
    }

    /**
     * Returns the most frequent class label among {@code arrays}. Ties go to the label seen
     * first; for a pure subset this is simply the common label.
     *
     * @param arrays rows to inspect
     * @return majority label, or null if {@code arrays} is null or empty
     */
    public String getLeafNodeName(Object[] arrays) {
        if (arrays == null || arrays.length == 0) {
            return null;
        }
        Map<String, Integer> counts = new LinkedHashMap<String, Integer>();
        for (Object row : arrays) {
            String label = ((String[]) row)[nodeIndex];
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen + 1);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    /**
     * Returns the column index for a column name, or -1 if the name is unknown.
     *
     * @param name column name (may be null)
     * @return column index, or -1
     */
    public int getNodeIndex(String name) {
        for (int i = 0; i < ATTRIBUTE_NAMES.length; i++) {
            if (ATTRIBUTE_NAMES[i].equals(name)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Finds the still-available column with the highest information gain on {@code arrays} and
     * marks it as used in {@code flag}.
     *
     * @param arrays rows to evaluate
     * @return {@code {Double gain, Integer index}}; index is -1 (gain 0) when no column has a
     *     gain above {@link #MIN_GAIN}
     */
    public Object[] getMaxGain(Object[] arrays) {
        double bestGain = 0;
        int bestIndex = -1;
        for (int i = 0; i < this.flag.length; i++) {
            if (this.flag[i]) {
                continue;
            }
            double value = gain(arrays, i);
            if (value > MIN_GAIN && value > bestGain) {
                bestGain = value;
                bestIndex = i;
            }
        }
        if (bestIndex != -1) {
            this.flag[bestIndex] = true;
        }
        return new Object[] {Double.valueOf(bestGain), Integer.valueOf(bestIndex)};
    }

    /**
     * Returns the distinct values of column {@code index} across the full training set, ordered
     * by {@code Comparisons}.
     *
     * @param index column index
     * @return distinct values of that column
     */
    public String[] getAttributes(int index) {
        @SuppressWarnings("unchecked")
        TreeSet<String> set = new TreeSet<String>(new Comparisons());
        for (Object row : this.trainArrays) {
            set.add(((String[]) row)[index]);
        }
        return set.toArray(new String[set.size()]);
    }

    /**
     * Information gain of splitting {@code arrays} on column {@code index}: the label entropy of
     * {@code arrays} minus the size-weighted label entropy of each value's subset.
     *
     * @param arrays rows to evaluate
     * @param index candidate split column
     * @return information gain, or 0 for an empty row set
     */
    public double gain(Object[] arrays, int index) {
        if (arrays.length == 0) {
            return 0;
        }
        int[] counts = countLabels(arrays, getAttributes(this.nodeIndex), -1, null);
        double entropyS = labelEntropy(counts, arrays.length);

        double total = 0;
        for (String attribute : getAttributes(index)) {
            total += entropy(arrays, index, attribute, arrays.length);
        }
        return entropyS - total;
    }

    /**
     * Weighted label entropy of the rows whose column {@code index} equals {@code attribute}.
     *
     * @param arrays rows to evaluate
     * @param index split column
     * @param attribute value selecting the subset
     * @param totals size of the whole row set, used as the weight denominator
     * @return (subset size / totals) * entropy(subset), or 0 if the subset is empty
     */
    public double entropy(Object[] arrays, int index, String attribute, int totals) {
        int[] counts = countLabels(arrays, getAttributes(this.nodeIndex), index, attribute);
        int total = 0;
        for (int count : counts) {
            total += count;
        }
        if (total == 0) {
            // The value does not occur among these rows, so it contributes nothing.
            return 0;
        }
        return Entropy.getShang(total, totals) * labelEntropy(counts, total);
    }

    /**
     * Counts rows per class label, optionally restricted to rows whose {@code column} equals
     * {@code value} (no restriction when {@code column < 0}).
     */
    private int[] countLabels(Object[] arrays, String[] labels, int column, String value) {
        int[] counts = new int[labels.length];
        for (Object row : arrays) {
            String[] strs = (String[]) row;
            if (column >= 0 && !strs[column].equals(value)) {
                continue;
            }
            for (int k = 0; k < labels.length; k++) {
                if (strs[this.nodeIndex].equals(labels[k])) {
                    counts[k]++;
                    break;
                }
            }
        }
        return counts;
    }

    /**
     * Sums the per-label entropy terms. Zero counts are skipped because the 0*log(0) term is 0,
     * and this avoids relying on how Entropy handles a zero count.
     */
    private static double labelEntropy(int[] counts, int total) {
        double sum = 0;
        for (int count : counts) {
            if (count > 0) {
                sum += Entropy.getEntropy(count, total);
            }
        }
        return sum;
    }
}

你可能感兴趣的:(决策树算法,java)