日撸代码300行:第65天(集成学习之 AdaBoosting-3)

  代码来自闵老师的《日撸 Java 三百行(61-70天)》

日撸 Java 三百行(61-70天,决策树与集成学习)_闵帆的博客-CSDN博客

今天的代码完成的是基础分类器的集成,类名为Booster。构造函数传入数据集后,设置了训练集和测试集,两个是一样的都是传入的数据集。

setNumBaseClassifier方法设置基础分类器的个数,里面的变量classifierWeights保存的是分类器的权值。这里要区分开实例权重和分类器权重。

classify方法里,所有分类器都对给定实例进行分类。分类的时候各个分类器根据自身的权重投票。最终所有分类器给出的所有标签中,得分最高的就是预测的类标签。

train方法里面对每一个分类器进行训练。基础分类器是昨天的树桩分类器。在遍历分类器个数的for循环里,i=0时初始化对象tempWeightInstances;else分支里则对实例权重进行调整,调用的是前天代码里的adjustWeights方法。接下来第二步是训练树桩分类器。"Set the classifier weight."这一步就是分类器权重的计算公式,套进去就行了。至于后面的if语句,是为了防止出现无意义的(负的或过小的)权重。测试了将classifierWeights[i]赋值为0和1e-6,结果是一样的。最后一步是:如果训练精度达到0.999999,就在这一轮终止循环,说明已经达到预期的精度,并输出是在第几轮停止的。测试了一下,大概会在150轮左右跳出。

package machinelearning.adaboosting;

import java.io.FileReader;

import weka.core.Instance;
import weka.core.Instances;

/**
 * The booster which ensembles base classifiers.
 * @author WX873
 *
 */
public class Booster {

	/**
	 * The base classifiers. Only the first numClassifiers entries are trained
	 * and take part in voting.
	 */
	SimpleClassifier[] classifiers;

	/**
	 * Number of classifiers actually trained so far. May be smaller than
	 * classifiers.length when training converges early.
	 */
	int numClassifiers;

	/**
	 * Whether or not to stop early once the training accuracy reaches 1.0.
	 */
	boolean stopAfterConverge = false;

	/**
	 * The weights of classifiers (the alpha values of AdaBoost).
	 */
	double[] classifierWeights;

	/**
	 * The training data.
	 */
	Instances trainingData;

	/**
	 * The testing data.
	 */
	Instances testingData;

	/**
	 * ******************************************************
	 * The first constructor. The testing set is the same as the training set.
	 * 
	 * @param paraTrainingFilename  The data filename (ARFF format).
	 * ******************************************************
	 */
	public Booster(String paraTrainingFilename) {
		// Step 1. Read the data set.
		try {
			FileReader tempFileReader = new FileReader(paraTrainingFilename);
			trainingData = new Instances(tempFileReader);
			tempFileReader.close();
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + e);
			// Exit with a non-zero status to signal the failure to the caller.
			System.exit(1);
		} // Of try

		// Step 2. Set the last attribute as the class index.
		trainingData.setClassIndex(trainingData.numAttributes() - 1);

		// Step 3. The testing data is the same as the training data.
		testingData = trainingData;

		stopAfterConverge = true;
		System.out.println("****************Data**********\r\n" + trainingData);
	}// Of the first constructor

	/**
	 * *********************************************************************
	 * Set the number of base classifiers, and allocate space for them and
	 * their weights.
	 * 
	 * @param paraNumBaseClassifiers  The number of base classifiers.
	 * *********************************************************************
	 */
	public void setNumBaseClassifier(int paraNumBaseClassifiers) {
		numClassifiers = paraNumBaseClassifiers;

		// Step 1. Allocate space (only references) for classifiers.
		classifiers = new SimpleClassifier[numClassifiers];

		// Step 2. Allocate space for classifier weights.
		classifierWeights = new double[numClassifiers];
	}// Of setNumBaseClassifier

	/**
	 ****************** 
	 * Classify an instance by weighted voting: every trained classifier votes
	 * for its predicted label with its own weight, and the label with the
	 * highest accumulated weight wins.
	 * 
	 * @param paraInstance
	 *            The given instance.
	 * @return The predicted label (index of the class value).
	 ****************** 
	 */
	public int classify(Instance paraInstance) {
		double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
		// Accumulate the weight of each classifier on its predicted label.
		for (int i = 0; i < numClassifiers; i++) {
			int tempLabel = classifiers[i].classify(paraInstance);
			tempLabelsCountArray[tempLabel] += classifierWeights[i];
		} // Of for i

		// The label with the maximal accumulated weight is the prediction.
		int resultLabel = -1;
		double tempMax = -1;
		for (int i = 0; i < tempLabelsCountArray.length; i++) {
			if (tempMax < tempLabelsCountArray[i]) {
				tempMax = tempLabelsCountArray[i];
				resultLabel = i;
			} // Of if
		} // Of for i

		return resultLabel;
	}// Of classify

	/**
	 * ***************************************************************
	 * Compute the training accuracy of the booster. It is not weighted.
	 * 
	 * @return  The training accuracy in [0, 1].
	 * ***************************************************************
	 */
	public double computeTrainingAccuracy() {
		double tempCorrect = 0;

		for (int i = 0; i < trainingData.numInstances(); i++) {
			if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		return tempCorrect / trainingData.numInstances();
	}// Of computeTrainingAccuracy

	/**
	 * **********************************************************************
	 * Train the booster with the standard AdaBoost scheme: each round trains
	 * a stump on re-weighted instances and assigns it the weight
	 * alpha = 0.5 * ln(1 / error - 1). Training stops early when the training
	 * accuracy converges to 1.0 (if stopAfterConverge is set).
	 * 
	 * @throws IllegalStateException
	 *             if setNumBaseClassifier has not been called first.
	 * @see algorithm.StumpClassifier#train()
	 * **********************************************************************
	 */
	public void train() {
		// Step 1. Initialize.
		if (classifiers == null) {
			// Guard: the classifier array must be allocated first, otherwise
			// the loop below would throw a bare NullPointerException.
			throw new IllegalStateException("Call setNumBaseClassifier() before train().");
		} // Of if

		WeightedInstances tempWeightInstances = null;
		double tempError;
		numClassifiers = 0;

		// Step 2. Build the classifiers one by one.
		for (int i = 0; i < classifiers.length; i++) {
			// Step 2.1 Key code: construct or adjust the weighted instances.
			if (i == 0) {
				// The first round starts from uniform instance weights.
				tempWeightInstances = new WeightedInstances(trainingData);
			} else {
				// Boost: raise the weights of instances misclassified by the
				// previous classifier, lower the weights of correct ones.
				tempWeightInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
						classifierWeights[i - 1]);
			} // Of if

			// Step 2.2 Train the next classifier.
			classifiers[i] = new StumpClassifier(tempWeightInstances);
			classifiers[i].train();

			tempError = classifiers[i].computeWeightedError();
			// Guard against a zero (or tiny) error, which would otherwise
			// yield an infinite classifier weight from the formula below.
			if (tempError < 1e-10) {
				tempError = 1e-10;
			} // Of if

			// Key code: set the classifier weight (the AdaBoost alpha).
			classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
			// A classifier no better than random guessing gets zero weight.
			if (classifierWeights[i] < 1e-6) {
				classifierWeights[i] = 0;
			} // Of if

			System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
					+ classifierWeights[i] + "\r\n");

			// Only the classifiers trained so far take part in voting.
			numClassifiers++;

			// Step 2.3 Stop early when the training accuracy is high enough.
			if (stopAfterConverge) {
				double tempTrainingAccuracy = computeTrainingAccuracy();
				System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
				if (tempTrainingAccuracy > 0.999999) {
					System.out.println("Stop at the round: " + i + " due to converge.\r\n");
					break;
				} // Of if
			} // Of if
		} // Of for i
	}// Of train

	/**
	 * ********************************************************
	 * Test the booster on the testing data (here identical to the training
	 * data, see the constructor).
	 * 
	 * @return  The classification accuracy.
	 * ********************************************************
	 */
	public double test() {
		System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

		return test(testingData);
	}// Of test

	/**
	 * **********************************************************
	 * Test the booster on the given data set.
	 * 
	 * @param paraInstances  The testing set.
	 * @return   The classification accuracy.
	 * **********************************************************
	 */
	public double test(Instances paraInstances) {
		double tempCorrect = 0;
		paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

		for (int i = 0; i < paraInstances.numInstances(); i++) {
			Instance tempInstance = paraInstances.instance(i);
			if (classify(tempInstance) == (int) tempInstance.classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		double resultAccuracy = tempCorrect / paraInstances.numInstances();
		System.out.println("The accuracy is: " + resultAccuracy);

		return resultAccuracy;
	}// Of test

	/**
	 * *****************************************************
	 * For integration test.
	 * 
	 * @param args  Not used.
	 * *****************************************************
	 */
	public static void main(String args[]) {
		System.out.println("Starting AdaBoosting...");
		Booster tempBooster = new Booster("E:/Datasets/UCIdatasets/其他数据集/iris.arff");

		tempBooster.setNumBaseClassifier(100);
		tempBooster.train();

		System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuracy());
		tempBooster.test();
	}// Of main

}// Of Booster

你可能感兴趣的:(集成学习,机器学习,java)