日撸 Java 三百行学习笔记day61-62

近几天才返校,在忙毕业论文相关事宜,抽空写点心得体会,所以做得比较慢。

决策树对于我来说也算是一个全新的概念,是数据挖掘的一个重要分支,根据闵老师所写的决策树快问快答可以大致了解,算是通俗易懂了。

内容还是很多,一部分一部分拆开来解析。

我们建造一颗决策树,从选择属性开始:

/**
	 ********************************** 
	 * Select the best attribute.
	 * 
	 * @return The best attribute index.
	 ********************************** 
	 */
	public int selectBestAttribute() {
		splitAttribute = -1;
		double tempMinimalEntropy = 10000;
		double tempEntropy;
		for (int i = 0; i < availableAttributes.length; i++) {
			tempEntropy = conditionalEntropy(availableAttributes[i]);
			if (tempMinimalEntropy > tempEntropy) {
				tempMinimalEntropy = tempEntropy;
				splitAttribute = availableAttributes[i];
			} // Of if
		} // Of for i
		return splitAttribute;
	}// Of selectBestAttribute

	/**
	 ********************************** 
	 * Compute the conditional entropy of an attribute.
	 * 
	 * @param paraAttribute
	 *            The given attribute.
	 * 
	 * @return The entropy.
	 ********************************** 
	 */
	public double conditionalEntropy(int paraAttribute) {
		// Step 1. Statistics.
		int tempNumClasses = dataset.numClasses();
		int tempNumValues = dataset.attribute(paraAttribute).numValues();
		int tempNumInstances = availableInstances.length;
		double[] tempValueCounts = new double[tempNumValues];
		double[][] tempCountMatrix = new double[tempNumValues][tempNumClasses];

		int tempClass, tempValue;
		for (int i = 0; i < tempNumInstances; i++) {
			tempClass = (int) dataset.instance(availableInstances[i]).classValue();
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
			tempValueCounts[tempValue]++;
			tempCountMatrix[tempValue][tempClass]++;
		} // Of for i

		// Step 2.
		double resultEntropy = 0;
		double tempEntropy, tempFraction;
		for (int i = 0; i < tempNumValues; i++) {
			if (tempValueCounts[i] == 0) {
				continue;
			} // Of if
			tempEntropy = 0;
			for (int j = 0; j < tempNumClasses; j++) {
				tempFraction = tempCountMatrix[i][j] / tempValueCounts[i];
				if (tempFraction == 0) {
					continue;
				} // Of if
				tempEntropy += -tempFraction * Math.log(tempFraction);
			} // Of for j
			resultEntropy += tempValueCounts[i] / tempNumInstances * tempEntropy;
		} // Of for i

		return resultEntropy;
	}// Of conditionalEntropy

我们需要找到conditionalEntropy最小的那个可用属性,selectBestAttribute()就是简单的一个比较的方法,在conditionalEntropy()中,首先进行数据统计,再根据统计的数据进行数据计算,算出paraAttribute的resultEntropy,放到selectBestAttribute()中进行比较,选出最小的熵。

/**
	 ********************************** 
	 * Split the data according to the given attribute.
	 * 
	 * @return The blocks.
	 ********************************** 
	 */
	public int[][] splitData(int paraAttribute) {
		int tempNumValues = dataset.attribute(paraAttribute).numValues();
		// System.out.println("Dataset " + dataset + "\r\n");
		// System.out.println("Attribute " + paraAttribute + " has " +
		// tempNumValues + " values.\r\n");
		int[][] resultBlocks = new int[tempNumValues][];
		int[] tempSizes = new int[tempNumValues];

		// First scan to count the size of each block.
		int tempValue;
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
			tempSizes[tempValue]++;
		} // Of for i

		// Allocate space.
		for (int i = 0; i < tempNumValues; i++) {
			resultBlocks[i] = new int[tempSizes[i]];
		} // Of for i

		// Second scan to fill.
		Arrays.fill(tempSizes, 0);
		for (int i = 0; i < availableInstances.length; i++) {
			tempValue = (int) dataset.instance(availableInstances[i]).value(paraAttribute);
			// Copy data.
			resultBlocks[tempValue][tempSizes[tempValue]] = availableInstances[i];
			tempSizes[tempValue]++;
		} // Of for i

		return resultBlocks;
	}// Of splitData

这里的splitData方法返回的是一个二维数组,因为你要根据选中的属性来划分数据,这也是间址的运用。首先根据选择的属性的各个值来将所有availableInstances分成paraAttribute.length这么多类,像我们用的weather这个数据集按照outlook这个属性分的话,就会分为Sunny, Overcast, Rain这三类,然后分配空间,再分别装入数据。再次说明,返回的是一个二维数组,最后如何填满这个二维数组还是值得细看的。

接下来就是重头戏了:

/**
	 ********************************** 
	 * Build the tree recursively.
	 ********************************** 
	 */
	public void buildTree() {
		if (pureJudge(availableInstances)) {
			return;
		} // Of if
		if (availableInstances.length <= smallBlockThreshold) {
			return;
		} // Of if

		selectBestAttribute();
		int[][] tempSubBlocks = splitData(splitAttribute);
		children = new ID3[tempSubBlocks.length];

		// Construct the remaining attribute set.
		int[] tempRemainingAttributes = new int[availableAttributes.length - 1];
		for (int i = 0; i < availableAttributes.length; i++) {
			if (availableAttributes[i] < splitAttribute) {
				tempRemainingAttributes[i] = availableAttributes[i];
			} else if (availableAttributes[i] > splitAttribute) {
				tempRemainingAttributes[i - 1] = availableAttributes[i];
			} // Of if
		} // Of for i

		// Construct children.
		for (int i = 0; i < children.length; i++) {
			if ((tempSubBlocks[i] == null) || (tempSubBlocks[i].length == 0)) {
				children[i] = null;
				continue;
			} else {
				// System.out.println("Building children #" + i + " with
				// instances " + Arrays.toString(tempSubBlocks[i]));
				children[i] = new ID3(dataset, tempSubBlocks[i], tempRemainingAttributes);

				// Important code: do this recursively
				children[i].buildTree();
			} // Of if
		} // Of for i
	}// Of buildTree

开头的两个if语句,意义很简单,如果已经纯了就返回,这里我们设置的smallBlockThreshold = 3,如果小于3也返回。日撸 Java 三百行学习笔记day61-62_第1张图片

 在这一段卡了一会,没看懂,结果发现自己太瓜了,就是把剩下的重新整合一个数组....

之后就是把决策树往下衍生,进行递归。

/**
	 ********************************** 
	 * Classify an instance.
	 * 
	 * @param paraInstance
	 *            The given instance.
	 * @return The prediction.
	 ********************************** 
	 */
	public int classify(Instance paraInstance) {
		if (children == null) {
			return label;
		} // Of if

		ID3 tempChild = children[(int) paraInstance.value(splitAttribute)];
		if (tempChild == null) {
			return label;
		} // Of if

		return tempChild.classify(paraInstance);
	}// Of classify

	/**
	 ********************************** 
	 * Test on a testing set.
	 * 
	 * @param paraDataset
	 *            The given testing data.
	 * @return The accuracy.
	 ********************************** 
	 */
	public double test(Instances paraDataset) {
		double tempCorrect = 0;
		for (int i = 0; i < paraDataset.numInstances(); i++) {
			if (classify(paraDataset.instance(i)) == (int) paraDataset.instance(i).classValue()) {
				tempCorrect++;
			} // Of i
		} // Of for i

		return tempCorrect / paraDataset.numInstances();
	}// Of test

这就是最后一个要点,进行分类label,然后测算精确度了。

最后贴出结果:

 结果也是以三叉树形式表现的,outlook分3种,然后再分别往下延伸,1,0就表示P还是N最后的就是准确度了。

你可能感兴趣的:(学习)