决策树C4.5算法

数据挖掘中决策树C4.5预测算法的实现（半成品，还需要实现规则后剪枝以及对非离散数据的信息增益计算），原理将在下一篇博客中讲解。

package org.struct.decisiontree;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.TreeSet;

/**
 * A (partial) C4.5 decision-tree learner for categorical data.
 * <p>
 * Input layout: row 0 of the array passed to {@link #create(Object[], int)} is the
 * column header and every following row is a {@code String[]} sample. One column (the
 * "node index") holds the class label to predict. At each node the column with the
 * highest gain ratio is chosen. A column is excluded only along the path from the root
 * to the current node, so sibling branches may test the same column.
 * <p>
 * Not implemented yet: rule post-pruning and gain computation for continuous attributes.
 *
 * @author Leon.Chen
 */
public class DecisionTreeBaseC4p5 {

	/**
	 * Root of the learned tree; {@code null} until {@link #createDecisionTree(Object[])} runs.
	 */
	private DecisionTreeNode root;

	/**
	 * Per-column "already used" flags read and updated by {@link #getMaxGain(Object[])}.
	 * {@link #init(Object[], int)} marks only the class column.
	 */
	private boolean[] visable;

	private static final int NOT_FOUND = -1;

	/** Index of the first data row; row 0 is the column header. */
	private static final int DATA_START_LINE = 1;

	/** Training rows (each a {@code String[]}), header excluded. */
	private Object[] trainingArray;

	/** Column names taken from the header row. */
	private String[] columnHeaderArray;

	/**
	 * Column index of the class label being predicted.
	 */
	private int nodeIndex;

	/**
	 * Demo: learns the "buys_computer" sample data, prints the tree and classifies one sample.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		Object[] array = new Object[] {
				new String[] { "age",          "income",   "student", "credit_rating", "buys_computer" },
				new String[] { "youth",        "high",     "no",      "fair",          "no"  },
				new String[] { "youth",        "high",     "no",      "excellent",     "no"  },
				new String[] { "middle_aged",  "high",     "no",      "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "low",      "yes",     "excellent",     "no"  },
				new String[] { "middle_aged",  "low",      "yes",     "excellent",     "yes" },
				new String[] { "youth",        "medium",   "no",      "fair",          "no"  },
				new String[] { "youth",        "low",      "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "yes",     "fair",          "yes" },
				new String[] { "youth",        "medium",   "yes",     "excellent",     "yes" },
				new String[] { "middle_aged",  "medium",   "no",      "excellent",     "yes" },
				new String[] { "middle_aged",  "high",     "yes",     "fair",          "yes" },
				new String[] { "senior",       "medium",   "no",      "excellent",     "no"  },
		};

		DecisionTreeBaseC4p5 tree = new DecisionTreeBaseC4p5();
		tree.create(array, 4);
		System.out.println("===============END PRINT TREE===============");
		System.out.println("===============DECISION RESULT===============");
		// sample values are aligned with the header columns (class column omitted)
		String[] printData = new String[] { "youth", "medium", "yes", "fair" };
		tree.forecast(printData, tree.root);
	}

	/**
	 * Walks the tree following the sample's values and prints the predicted class label.
	 * <p>
	 * {@code printData[i]} is the sample's value for header column {@code i}. Nothing is
	 * printed when the sample has a value that no branch of a visited node matches.
	 *
	 * @param printData sample to classify, aligned with the header columns
	 * @param node      node to start from (normally the root)
	 * @throws IllegalArgumentException if {@code printData} lacks a column the tree tests
	 */
	public void forecast(String[] printData, DecisionTreeNode node) {
		if (node == null) {
			return;
		}
		DecisionTreeNode[] childs = node.childNodesArray;
		// Leaves have no children and carry the label. Detecting leaves by looking the
		// name up in the header misfires when a label equals a column name.
		if (childs == null || childs.length == 0) {
			System.out.println(node.nodeName);
			return;
		}
		int index = getColumnHeaderIndexByName(node.nodeName);
		if (index == NOT_FOUND) {
			System.out.println(node.nodeName);
			return;
		}
		if (index >= printData.length) {
			throw new IllegalArgumentException("printData has no value for column "
					+ node.nodeName + " (index " + index + ")");
		}
		for (DecisionTreeNode child : childs) {
			if (child != null && child.parentArrtibute.equals(printData[index])) {
				forecast(printData, child);
				return; // branch values are distinct, so at most one child matches
			}
		}
	}

	/**
	 * Builds the tree from {@code array} (header row + samples) and prints it.
	 *
	 * @param array header row followed by sample rows, each a {@code String[]}
	 * @param index column index of the class label
	 */
	public void create(Object[] array, int index) {
		this.trainingArray = Arrays.copyOfRange(array, DATA_START_LINE,
				array.length);
		init(array, index);
		createDecisionTree(this.trainingArray);
		printDecisionTree(root);
	}

	/**
	 * Finds the column with the highest gain ratio among columns not yet flagged in
	 * {@code visable}, and flags the winner as used.
	 *
	 * @param array rows to evaluate
	 * @return {@code [Double ratio, Integer column]}; the column is -1 and the ratio 0
	 *         when no unused column has a positive gain ratio
	 */
	public Object[] getMaxGain(Object[] array) {
		Object[] result = findMaxGain(array, visable);
		int index = ((Integer) result[1]).intValue();
		if (index != NOT_FOUND) {
			visable[index] = true;
		}
		return result;
	}

	/**
	 * Gain-ratio search over the columns whose {@code excluded} flag is false.
	 * Does not modify {@code excluded}.
	 */
	private Object[] findMaxGain(Object[] array, boolean[] excluded) {
		double best = 0;
		int bestIndex = NOT_FOUND;
		for (int i = 0; i < excluded.length; i++) {
			if (!excluded[i]) {
				double value = gainRatio(array, i, this.nodeIndex);
				// strict '>' keeps the first column on ties and skips ratios <= 0 (and NaN)
				if (value > best) {
					best = value;
					bestIndex = i;
				}
			}
		}
		return new Object[] { Double.valueOf(best), Integer.valueOf(bestIndex) };
	}

	/**
	 * Creates the root (if not created yet) and grows the whole tree below it.
	 *
	 * @param array training rows (header excluded)
	 */
	public void createDecisionTree(Object[] array) {
		if (root != null) {
			return;
		}
		Object[] maxgain = getMaxGain(array);
		int index = ((Integer) maxgain[1]).intValue();
		if (index == NOT_FOUND) {
			// No column separates the classes (e.g. every row has the same label):
			// the tree is a single leaf instead of crashing on column -1.
			root = newLeaf(null, null, getLeafNodeName(array, this.nodeIndex));
			return;
		}
		root = newInternalNode(null, null, index);
		insertDecisionTree(array, root);
	}

	/**
	 * Recursively fills {@code parentNode.childNodesArray}, one child per value of the
	 * column tested by {@code parentNode}.
	 *
	 * @param array      rows that reached {@code parentNode}
	 * @param parentNode internal node whose children are built
	 */
	public void insertDecisionTree(Object[] array, DecisionTreeNode parentNode) {
		String[] arrtibutes = parentNode.arrtibutesArray;
		int splitColumn = getColumnHeaderIndexByName(parentNode.nodeName);
		// Only columns tested on the path from the root are excluded below this node;
		// sibling branches may reuse a column.
		boolean[] usedOnPath = usedColumnsOnPath(parentNode);
		for (int i = 0; i < arrtibutes.length; i++) {
			Object[] pickArray = pickUpAndCreateSubArray(array, arrtibutes[i], splitColumn);
			if (pickArray.length == 0) {
				// No row reaching this node has this value: predict the parent's majority class.
				parentNode.childNodesArray[i] = newLeaf(parentNode, arrtibutes[i],
						getLeafNodeName(array, this.nodeIndex));
				continue;
			}
			Object[] info = findMaxGain(pickArray, usedOnPath);
			int index = ((Integer) info[1]).intValue();
			if (index != NOT_FOUND) {
				DecisionTreeNode currentNode = newInternalNode(parentNode, arrtibutes[i], index);
				parentNode.childNodesArray[i] = currentNode;
				insertDecisionTree(pickArray, currentNode);
			} else {
				parentNode.childNodesArray[i] = newLeaf(parentNode, arrtibutes[i],
						getLeafNodeName(pickArray, this.nodeIndex));
			}
		}
	}

	/** Flags the class column plus every column tested by {@code node} and its ancestors. */
	private boolean[] usedColumnsOnPath(DecisionTreeNode node) {
		boolean[] used = new boolean[columnHeaderArray.length];
		used[this.nodeIndex] = true;
		for (DecisionTreeNode n = node; n != null; n = n.parentNode) {
			int column = getColumnHeaderIndexByName(n.nodeName);
			if (column != NOT_FOUND) {
				used[column] = true;
			}
		}
		return used;
	}

	/** Builds an internal node testing {@code column}, with an empty slot per column value. */
	private DecisionTreeNode newInternalNode(DecisionTreeNode parent, String parentValue, int column) {
		DecisionTreeNode node = new DecisionTreeNode();
		node.parentNode = parent;
		node.parentArrtibute = parentValue;
		node.arrtibutesArray = getArrtibutesArray(column);
		node.nodeName = getColumnHeaderNameByIndex(column);
		node.childNodesArray = new DecisionTreeNode[node.arrtibutesArray.length];
		return node;
	}

	/** Builds a leaf (empty attribute/child arrays) predicting {@code label}. */
	private DecisionTreeNode newLeaf(DecisionTreeNode parent, String parentValue, String label) {
		DecisionTreeNode leaf = new DecisionTreeNode();
		leaf.parentNode = parent;
		leaf.parentArrtibute = parentValue;
		leaf.arrtibutesArray = new String[0];
		leaf.nodeName = label;
		leaf.childNodesArray = new DecisionTreeNode[0];
		return leaf;
	}

	/**
	 * Prints the tree depth-first. Each node name is followed by, for every child, the
	 * branch value and then that child's subtree.
	 *
	 * @param node subtree root; {@code null} prints nothing
	 */
	public void printDecisionTree(DecisionTreeNode node) {
		if (node == null) {
			return;
		}
		System.out.println(node.nodeName);
		DecisionTreeNode[] childs = node.childNodesArray;
		if (childs == null) {
			return;
		}
		for (DecisionTreeNode child : childs) {
			if (child != null) {
				System.out.println(child.parentArrtibute);
				printDecisionTree(child);
			}
		}
	}

	/**
	 * Records the header and the class column, and marks the class column as used.
	 *
	 * @param dataArray header row followed by samples
	 * @param index     class-label column index
	 * @throws IllegalArgumentException if {@code index} is not a header column
	 */
	public void init(Object[] dataArray, int index) {
		this.columnHeaderArray = (String[]) dataArray[0];
		if (index < 0 || index >= columnHeaderArray.length) {
			throw new IllegalArgumentException("class column index " + index
					+ " out of range [0, " + columnHeaderArray.length + ")");
		}
		this.nodeIndex = index;
		visable = new boolean[columnHeaderArray.length];
		visable[index] = true; // the class column is never a split candidate
	}

	/**
	 * Selects the rows whose value in column {@code index} equals {@code arrtibute}.
	 *
	 * @param array     rows to filter
	 * @param arrtibute value to match
	 * @param index     column to test
	 * @return matching rows, in input order
	 */
	public Object[] pickUpAndCreateSubArray(Object[] array, String arrtibute,
			int index) {
		List<String[]> list = new ArrayList<String[]>();
		for (Object row : array) {
			String[] strs = (String[]) row;
			if (strs[index].equals(arrtibute)) {
				list.add(strs);
			}
		}
		return list.toArray();
	}

	/**
	 * Information gain of splitting on column {@code index}: Gain(A) = Info(D) - Info_A(D).
	 *
	 * @param array     rows D
	 * @param index     candidate split column A
	 * @param nodeIndex class-label column
	 * @return information gain
	 */
	public double gain(Object[] array, int index, int nodeIndex) {
		int[] counts = separateToSameValueArrays(array, nodeIndex);
		String[] arrtibutes = getArrtibutesArray(index);
		double infoD = infoD(array, counts);
		double infoaD = infoaD(array, index, nodeIndex, arrtibutes);
		return infoD - infoaD;
	}

	/**
	 * Counts rows per distinct value of column {@code nodeIndex}. The values come from
	 * the whole training set, sorted.
	 *
	 * @param array     rows to count
	 * @param nodeIndex column to count by
	 * @return counts aligned with {@code getArrtibutesArray(nodeIndex)}
	 */
	public int[] separateToSameValueArrays(Object[] array, int nodeIndex) {
		String[] arrti = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrti.length];
		for (Object row : array) {
			String value = ((String[]) row)[nodeIndex];
			for (int j = 0; j < arrti.length; j++) {
				if (value.equals(arrti[j])) {
					counts[j]++;
					break; // values in arrti are distinct
				}
			}
		}
		return counts;
	}

	/**
	 * C4.5 gain ratio: GainRatio(A) = Gain(A) / SplitInfo_A(D); 0 when SplitInfo is 0.
	 *
	 * @param array     rows D
	 * @param index     candidate split column A
	 * @param nodeIndex class-label column
	 * @return gain ratio
	 */
	public double gainRatio(Object[] array, int index, int nodeIndex) {
		double gain = gain(array, index, nodeIndex);
		int[] counts = separateToSameValueArrays(array, index);
		double splitInfo = splitInfoaD(array, counts);
		if (splitInfo != 0) {
			return gain / splitInfo;
		}
		return 0;
	}

	/**
	 * Entropy Info(D) = -Σ p_i log2 p_i, with p_i = counts[i] / |D|.
	 *
	 * @param array  rows D (only the length is used)
	 * @param counts per-value counts
	 * @return entropy in bits
	 */
	public double infoD(Object[] array, int[] counts) {
		double infoD = 0;
		for (int count : counts) {
			infoD += DecisionTreeUtil.info(count, array.length);
		}
		return infoD;
	}

	/**
	 * Split information SplitInfo_A(D) = -Σ |D_j|/|D| * log2(|D_j|/|D|). This is the
	 * same computation as {@link #infoD(Object[], int[])}, applied to the split column's counts.
	 *
	 * @param array  rows D
	 * @param counts row counts per value of the split column
	 * @return split information in bits
	 */
	public double splitInfoaD(Object[] array, int[] counts) {
		return infoD(array, counts);
	}

	/**
	 * Expected information after splitting on column {@code index}:
	 * Info_A(D) = Σ |D_j|/|D| * Info(D_j).
	 *
	 * @param array      rows D
	 * @param index      split column A
	 * @param nodeIndex  class-label column
	 * @param arrtibutes values of column A
	 * @return Info_A(D)
	 */
	public double infoaD(Object[] array, int index, int nodeIndex,
			String[] arrtibutes) {
		double sv_total = 0;
		for (String arrtibute : arrtibutes) {
			sv_total += infoDj(array, index, nodeIndex, arrtibute,
					array.length);
		}
		return sv_total;
	}

	/**
	 * One term of Info_A(D): (|D_j| / |D|) * Info(D_j), where D_j is the set of rows
	 * whose column {@code index} equals {@code arrtibute}.
	 *
	 * @param array     rows D
	 * @param index     split column
	 * @param nodeIndex class-label column
	 * @param arrtibute split value defining D_j
	 * @param allTotal  |D|
	 * @return the weighted entropy term; 0 when D_j is empty
	 */
	public double infoDj(Object[] array, int index, int nodeIndex,
			String arrtibute, int allTotal) {
		String[] arrtibutes = getArrtibutesArray(nodeIndex);
		int[] counts = new int[arrtibutes.length];
		int total = 0;
		for (Object row : array) {
			String[] strs = (String[]) row;
			if (strs[index].equals(arrtibute)) {
				for (int k = 0; k < arrtibutes.length; k++) {
					if (strs[nodeIndex].equals(arrtibutes[k])) {
						counts[k]++;
						total++;
						break;
					}
				}
			}
		}
		// Empty D_j contributes nothing; avoids 0/0 = NaN poisoning the gain.
		if (total == 0 || allTotal == 0) {
			return 0;
		}
		double infoDj = 0;
		for (int count : counts) {
			infoDj += DecisionTreeUtil.info(count, total);
		}
		return DecisionTreeUtil.getPi(total, allTotal) * infoDj;
	}

	/**
	 * Distinct values of column {@code index} across the whole training set, sorted.
	 *
	 * @param index column index
	 * @return sorted distinct values
	 */
	@SuppressWarnings("unchecked")
	public String[] getArrtibutesArray(int index) {
		TreeSet<String> set = new TreeSet<String>(new SequenceComparator());
		for (Object row : trainingArray) {
			set.add(((String[]) row)[index]);
		}
		return set.toArray(new String[set.size()]);
	}

	/**
	 * @param index column index
	 * @return the header name of that column, or {@code null} if out of range
	 */
	public String getColumnHeaderNameByIndex(int index) {
		if (index < 0 || index >= columnHeaderArray.length) {
			return null;
		}
		return columnHeaderArray[index];
	}

	/**
	 * Majority class label in column {@code nodeIndex} among {@code array}. Ties go to
	 * the label that sorts first. For a pure subset this is simply its single label.
	 *
	 * @param array     rows reaching the leaf
	 * @param nodeIndex class-label column
	 * @return majority label, or {@code null} for a null/empty array
	 */
	public String getLeafNodeName(Object[] array, int nodeIndex) {
		if (array == null || array.length == 0) {
			return null;
		}
		String[] labels = getArrtibutesArray(nodeIndex);
		int[] counts = separateToSameValueArrays(array, nodeIndex);
		String best = null;
		int bestCount = 0;
		for (int i = 0; i < labels.length; i++) {
			if (counts[i] > bestCount) {
				bestCount = counts[i];
				best = labels[i];
			}
		}
		// fall back to the first row if none of its labels occur in the training set
		return best != null ? best : ((String[]) array[0])[nodeIndex];
	}

	/**
	 * @param name column name; {@code null} is never found
	 * @return the header index of {@code name}, or -1 if absent
	 */
	public int getColumnHeaderIndexByName(String name) {
		if (name == null) {
			return NOT_FOUND;
		}
		for (int i = 0; i < columnHeaderArray.length; i++) {
			if (name.equals(columnHeaderArray[i])) {
				return i;
			}
		}
		return NOT_FOUND;
	}
}

package org.struct.decisiontree;

/**
 * One node of the tree built by {@code DecisionTreeBaseC4p5}.
 * <p>
 * An internal node tests the column named by {@code nodeName}. {@code childNodesArray[i]}
 * is the subtree for value {@code arrtibutesArray[i]}. A leaf has empty arrays, and its
 * {@code nodeName} is the predicted class label.
 *
 * @author Leon.Chen
 */
public class DecisionTreeNode {

	/** Column tested here (internal node) or predicted class label (leaf). */
	String nodeName;

	/** Parent's column value leading to this node; {@code null} for the root. */
	String parentArrtibute;

	/** Enclosing node; {@code null} for the root. */
	DecisionTreeNode parentNode;

	/** Candidate values of the tested column; empty for a leaf. */
	String[] arrtibutesArray;

	/** Subtrees, index-aligned with {@code arrtibutesArray}; empty for a leaf. */
	DecisionTreeNode[] childNodesArray;
}

package org.struct.decisiontree;

/**
 * Static helpers for the entropy computations used by the decision-tree learner.
 *
 * @author Leon.Chen
 */
public class DecisionTreeUtil {

	/** ln 2, cached for base-2 logarithms. */
	private static final double LN_2 = Math.log(2);

	/**
	 * One entropy term {@code -p * log2(p)} with {@code p = x / total}; summing it over all
	 * classes gives Info(D) = -Σ p_i log2 p_i.
	 * <p>
	 * Returns 0 when {@code x == 0} (the limit of p·log p as p → 0) or when
	 * {@code total == 0} (an empty set has no entropy).
	 *
	 * @param x     count of one class
	 * @param total size of the set
	 * @return the entropy contribution in bits
	 */
	public static double info(int x, int total) {
		if (x == 0 || total == 0) {
			return 0;
		}
		double x_pi = getPi(x, total);
		return -(x_pi * logYBase2(x_pi));
	}

	/**
	 * Base-2 logarithm.
	 *
	 * @param y argument
	 * @return log2(y)
	 */
	public static double logYBase2(double y) {
		return Math.log(y) / LN_2;
	}

	/**
	 * Proportion p_i = |C_i,D| / |D|.
	 *
	 * @param x     count of the subset
	 * @param total size of the whole set
	 * @return x / total, or 0 for an empty set instead of NaN
	 */
	public static double getPi(int x, int total) {
		if (total == 0) {
			return 0;
		}
		return x / (double) total;
	}

}


package org.struct.decisiontree;

import java.util.Comparator;

/**
 * Orders attribute values lexicographically via {@link String#compareTo(String)}, giving
 * the tree a deterministic order of column values.
 *
 * @author Leon.Chen
 */
public class SequenceComparator implements Comparator<String> {

	@Override
	public int compare(String str1, String str2) {
		return str1.compareTo(str2);
	}
}

你可能感兴趣的:(编程,算法,数据挖掘,J#)