The code is from Mr. Min's "日撸 Java 三百行" series (Days 61-70).
Source: 日撸 Java 三百行(61-70天,决策树与集成学习)_闵帆的博客-CSDN博客
Today's code builds the ensemble of base classifiers; the class is named Booster. The constructor reads a dataset from the given file and uses it as both the training set and the testing set, i.e., the two are the same data.
The setNumBaseClassifier method sets the number of base classifiers and allocates space for them; the member classifierWeights stores one weight per classifier. Be careful to distinguish instance weights (one per training example, kept in WeightedInstances) from classifier weights (one per base classifier, used in the final vote).
In the classify method, all classifiers classify the given instance, and each vote counts with that classifier's own weight. The label with the highest total score over all classifiers is the predicted class label, as the worked example below shows.
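A quick worked example with made-up numbers: suppose three trained classifiers have weights 1.2, 0.3 and 0.4, and for some instance they predict labels 0, 1 and 1 respectively. The score array is then {label 0: 1.2, label 1: 0.3 + 0.4 = 0.7}, so the booster predicts label 0, even though a plain majority vote would have chosen label 1. That is the point of the weighted vote: one strong classifier can outvote several weak ones.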
The train method trains the classifiers one by one; each one is yesterday's stump classifier. Inside the for loop over the classifiers, when i = 0 the WeightedInstances object tempWeightInstances is initialized; in the else branch the instance weights are adjusted by calling the adjustWeights method from the code of the day before yesterday, passing the previous classifier's correctness array and weight. Step 2.2 then trains the stump classifier on the reweighted data. The line marked "Set the classifier weight." is just the classifier-weight formula with the weighted error plugged in. The if statement that follows is presumably there to prevent meaningless weights; I tested assigning classifierWeights[i] both 0 and 1e-6 there, and the results were the same. Finally, if the training accuracy reaches 0.999999, the loop terminates in that round, since the expected accuracy has been achieved, and the program prints in which round it stopped. In my test it broke out at around round 150 (which of course requires allowing at least that many base classifiers).
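For reference, the training loop implements the standard AdaBoost formulas below (the symbols ε_i, α_i, w_j, h_i are my notation, not from the original post):

\varepsilon_i = \sum_j w_j \,[\, h_i(x_j) \ne y_j \,]   % weighted error of classifier i
\alpha_i = \frac{1}{2} \ln \frac{1 - \varepsilon_i}{\varepsilon_i}
         = \frac{1}{2} \ln \Big( \frac{1}{\varepsilon_i} - 1 \Big)   % classifier weight, as in the code
w_j \leftarrow w_j \, e^{\pm \alpha_i} / Z_i   % + if instance j was misclassified, - if correct; Z_i renormalizes

Note that ε_i = 0.5 gives α_i = 0 and ε_i > 0.5 gives a negative α_i, which is why the if statement clamps tiny or negative weights to 0: a classifier no better than chance should not get a vote.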
package machinelearning.adaboosting;
import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;
/**
* The booster which ensembles base classifiers.
* @author WX873
*
*/
public class Booster {
/**
* Classifiers.
*/
SimpleClassifier[] classifiers;
/**
* Number of classifiers.
*/
int numClassifiers;
/**
* Whether or not to stop once the training error reaches 0.
*/
boolean stopAfterConverge = false;
/**
* The weights of classifiers.
*/
double[] classifierWeights;
/**
* The training data.
*/
Instances trainingData;
/**
* The test data.
*/
Instances testingData;
/**
* ******************************************************
* The first constructor. The testing set is the same as the training set.
*
* @param paraTrainingFilename The data filename.
* ******************************************************
*/
public Booster(String paraTrainingFilename) {
try {
FileReader tempFileReader = new FileReader(paraTrainingFilename);
trainingData = new Instances(tempFileReader);
tempFileReader.close();
} catch (Exception e) {
System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + e);
System.exit(0);
}//of try
// Step 2. Set the last attribute as the class index.
trainingData.setClassIndex(trainingData.numAttributes() - 1);
// Step 3. The testing data is the same as the training data.
testingData = trainingData;
stopAfterConverge = true;
System.out.println("****************Data**********\r\n" + trainingData);
}//Of the first constructor
/**
* *********************************************************************
* Set the number of base classifiers, and allocate space for them.
*
* @param paraNumBaseClassifiers The number of base classifiers.
* *********************************************************************
*/
public void setNumBaseClassifier(int paraNumBaseClassifiers) {
numClassifiers = paraNumBaseClassifiers;
//Step 1. Allocate space (only reference) for classifiers
classifiers = new SimpleClassifier[numClassifiers];
//Step 2. Initialize classifier weights.
classifierWeights = new double[numClassifiers];
}//of setNumBaseClassifier
/**
******************
* Classify an instance.
*
* @param paraInstance
* The given instance.
* @return The predicted label.
******************
*/
public int classify(Instance paraInstance) {
double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
//Every classifier classifies the given instance; each vote counts with the classifier's own weight.
for (int i = 0; i < numClassifiers; i++) {
int tempLabel = classifiers[i].classify(paraInstance);
tempLabelsCountArray[tempLabel] += classifierWeights[i];
}//of for i
int resultLabel = -1;
double tempMax = -1;
//The label with the highest accumulated score (weight sum) is the prediction.
for (int i = 0; i < tempLabelsCountArray.length; i++) {
if (tempMax < tempLabelsCountArray[i]) {
tempMax = tempLabelsCountArray[i];
resultLabel = i;
}//of if
}//of for i
return resultLabel;
}//of classify
/**
* ***************************************************************
* Compute the training accuracy of the booster. It is not weighted.
*
* @return The training accuracy.
* ***************************************************************
*/
public double computeTrainingAccuracy() {
double tempCorrect = 0;
for (int i = 0; i < trainingData.numInstances(); i++) {
if (classify(trainingData.instance(i)) == (int)trainingData.instance(i).classValue()) {
tempCorrect++;
}//of if
}//of for i
double tempAccuracy = tempCorrect / trainingData.numInstances();
return tempAccuracy;
}//of computeTrainingAccuracy
/**
* **********************************************************************
* Train the booster.
*
* @see machinelearning.adaboosting.StumpClassifier#train()
* **********************************************************************
*/
public void train() {
// Step 1. Initialize.
WeightedInstances tempWeightInstances = null;
double tempError;
numClassifiers = 0;
//Step 2. Build other classifiers.
for (int i = 0; i < classifiers.length; i++) {
// Step 2.1 Key code: Construct or adjust the weightedInstances
if (i == 0) {
tempWeightInstances = new WeightedInstances(trainingData);
} else {
// Adjust the weights of the data.
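// Standard AdaBoost rule (implemented in WeightedInstances.adjustWeights):
// multiply an instance's weight by e^{alpha} if the previous classifier
// misclassified it, by e^{-alpha} if it was correct, then renormalize.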
tempWeightInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(), classifierWeights[i - 1]);
}//of if
// Step 2.2 Train the next classifier.
classifiers[i] = new StumpClassifier(tempWeightInstances);
classifiers[i].train();
tempError = classifiers[i].computeWeightedError();
// Key code: Set the classifier weight.
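// alpha = 0.5 * ln((1 - epsilon) / epsilon), written below in the
// equivalent form 0.5 * ln(1 / epsilon - 1).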
classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
if (classifierWeights[i] < 1e-6) {
classifierWeights[i] = 0;
}//of if
System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
+ classifierWeights[i] + "\r\n");
numClassifiers++;
// The accuracy is enough.
if (stopAfterConverge) {
double tempTrainingAccuracy = computeTrainingAccuracy();
System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
if (tempTrainingAccuracy > 0.999999) {
System.out.println("Stop at the round: " + i + " due to converge.\r\n");
break;
}//of if
}//of if
}//of for i
}//of train
/**
* ********************************************************
* Test the booster on the testing data (here identical to the training data).
*
* @return The classification accuracy.
* ********************************************************
*/
public double test() {
System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");
return test(testingData);
}//of test
/**
* **********************************************************
* Test the booster.
*
* @param paraInstances The testing set.
* @return The classification accuracy.
* **********************************************************
*/
public double test(Instances paraInstances) {
double tempCorrect = 0;
paraInstances.setClassIndex(paraInstances.numAttributes() - 1);
for (int i = 0; i < paraInstances.numInstances(); i++) {
Instance tempInstance = paraInstances.instance(i);
if (classify(tempInstance) == (int)tempInstance.classValue()) {
tempCorrect ++;
}//of if
}// of for i
double resultAccuracy = tempCorrect / paraInstances.numInstances();
System.out.println("The accuracy is: " + resultAccuracy);
return resultAccuracy;
}//of test
/**
* *****************************************************
* For integration test.
*
* @param args
* *****************************************************
*/
public static void main(String[] args) {
System.out.println("Starting AdaBoosting...");
Booster tempBooster = new Booster("E:/Datasets/UCIdatasets/其他数据集/iris.arff");
tempBooster.setNumBaseClassifier(100);
tempBooster.train();
System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuracy());
tempBooster.test();
}//of main
}//of Booster