日撸代码300行:第51天(kNN分类器)

代码来自闵老师“日撸 Java 三百行(51-60天)”,链接:https://blog.csdn.net/minfanphd/article/details/116975957

代码实现的基本思路:先用构造函数KnnClassification读入数据,紧接着将读入的数据按照比例分为训练集和测试集,用训练集对测试集进行预测。预测的过程中需要计算距离,这里使用了两种距离计算公式,分别是欧氏距离和曼哈顿距离。根据计算出的距离,找出距离要预测的实例(测试集中的实例)最近的k个对象(训练集中的对象)。最后由最近的k个邻居投票决定要预测的实例的标签。预测结束后,根据预测结果计算预测的准确度。具体实现的代码如下:

package machinelearning.knn;

import java.io.FileReader;
import java.util.Random;

import javax.annotation.processing.RoundEnvironment;

import java.util.Arrays;
import java.util.PrimitiveIterator.OfDouble;

import weka.core.Instances;

/**
 * KNN Classification
 * 
 * @author WX873
 *
 */


public class KnnClassification {
	
	/**
	 * Identifier of the Manhattan (sum of absolute differences) distance.
	 */
	public static final int MANHATTAN = 0;
	
	/**
	 * Identifier of the Euclidean distance. Note that {@link #distance(int, int)}
	 * returns the squared value for this measure.
	 */
	public static final int EUCLIDEAN = 1;
	
	/**
	 * The distance measure in use; one of {@link #MANHATTAN} or {@link #EUCLIDEAN}.
	 */
	public int distanceMeasure = EUCLIDEAN;
	
	/**
	 * Shared random number generator used for shuffling the data indices.
	 */
	public static final Random random = new Random();
	
	/**
	 * The number of neighbors (k) taking part in the vote.
	 */
	int numNeighbors = 7;
	
	/**
	 * The whole data set.
	 */
	Instances dataset;
	
	/**
	 * The training set, represented by indices into {@link #dataset}.
	 */
	int[] trainingSet;
	
	/**
	 * The testing set, represented by indices into {@link #dataset}.
	 */
	int[] testingSet;
	
	/**
	 * The predicted labels; entry i belongs to the instance testingSet[i].
	 */
	int[] predictions;
	
	/**
	 * **************************************************
	 * The first constructor. Reads the data set from the given file and marks
	 * the last attribute as the class attribute. If reading fails, the error
	 * is printed and the JVM is terminated with a non-zero exit status.
	 * 
	 * @param paraFilename  The arff filename.
	 * **************************************************
	 */
	public KnnClassification(String paraFilename) {
		// try-with-resources closes the reader even when parsing fails.
		try (FileReader fileReader = new FileReader(paraFilename)) {
			dataset = new Instances(fileReader);
			dataset.setClassIndex(dataset.numAttributes() - 1);
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename + "\' in KnnClassification constructor.\r\n" + ee);
			// A failure must not be reported as a successful exit (status 0).
			System.exit(1);
		}//of try
	}//of the first constructor
	
	/**
	 * **************************************************************
	 * Build a random permutation of the indices 0 .. paraLength - 1.
	 * 
	 * @param paraLength  The number of indices.
	 * @return  An array holding every index exactly once, in random order.
	 * **************************************************************
	 */
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];
		
		// Step 1. Initialize with the identity permutation.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		}//of for i
		
		// Step 2. Fisher-Yates shuffle. Position i receives a random element
		// of the not yet fixed prefix [0, i], so every position is visited
		// and every permutation is equally likely.
		int tempOther, tempValue;
		for (int i = paraLength - 1; i > 0; i--) {
			tempOther = random.nextInt(i + 1);
			tempValue = resultIndices[i];
			resultIndices[i] = resultIndices[tempOther];
			resultIndices[tempOther] = tempValue;
		}//of for i
		return resultIndices;
	}//of getRandomIndices
	
	/**
	 * **************************************************************
	 * Split the data into training and testing parts. The indices are
	 * shuffled first; the leading part becomes the training set and the
	 * remainder becomes the testing set.
	 * 
	 * @param paraTrainingFraction  The fraction of the training set, in [0, 1].
	 * @throws IllegalArgumentException  If the fraction is NaN or outside [0, 1].
	 * **************************************************************
	 */
	public void splitTrainingTesting(double paraTrainingFraction) {
		// Written as a negated conjunction so that NaN is rejected as well.
		if (!(paraTrainingFraction >= 0 && paraTrainingFraction <= 1)) {
			throw new IllegalArgumentException(
					"The training fraction must be in [0, 1], but is " + paraTrainingFraction);
		}//of if
		
		int tempSize = dataset.numInstances();
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
		
		trainingSet = Arrays.copyOfRange(tempIndices, 0, tempTrainingSize);
		testingSet = Arrays.copyOfRange(tempIndices, tempTrainingSize, tempSize);
	}//of splitTrainingTesting
	
	/**
	 * ******************************************************************
	 * The distance between two instances, computed over all attributes
	 * except the last one (the class attribute). For {@link #EUCLIDEAN} the
	 * squared distance is returned (no square root is taken), which ranks
	 * neighbors in the same order. For an unsupported measure a message is
	 * printed and 0 is returned.
	 * 
	 * @param paraI  The index of the first instance.
	 * @param paraJ  The index of the second instance.
	 * @return  The distance.
	 * ******************************************************************
	 */
	public double distance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		int tempNumConditions = dataset.numAttributes() - 1;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < tempNumConditions; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += Math.abs(tempDifference);
			}//of for i
			break;
			
		case EUCLIDEAN:
			for (int i = 0; i < tempNumConditions; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			}//of for i
			break;

		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// of switch
		return resultDistance;
	}//of distance
	
	/**
	 * ***********************************************************************
	 * Compute the nearest neighbors of an instance within the training set.
	 * The number of neighbors is {@link #numNeighbors}, reduced to the size of
	 * the training set if that is smaller, so no training instance is ever
	 * returned twice. Distances are computed once; then one neighbor is
	 * selected in each scan. Ties are resolved in favor of the instance that
	 * appears first in the training set.
	 * 
	 * @param paraCurrent Current instance. We are comparing it with all
	 *                    training instances.
	 * @return    The indices of the nearest instances, nearest first.
	 * @throws IllegalStateException  If the training set is missing or empty.
	 * ***********************************************************************
	 */
	public int[] computeNearests(int paraCurrent) {
		if (trainingSet == null || trainingSet.length == 0) {
			throw new IllegalStateException(
					"The training set is empty; call splitTrainingTesting() with a positive fraction first.");
		}//of if
		
		// Never ask for more neighbors than there are training instances.
		int tempK = Math.min(numNeighbors, trainingSet.length);
		int[] tempResultNearests = new int[tempK];
		
		// Compute every distance once instead of once per selection scan.
		double[] tempDistances = new double[trainingSet.length];
		for (int j = 0; j < trainingSet.length; j++) {
			tempDistances[j] = distance(paraCurrent, trainingSet[j]);
		}//of for j
		
		boolean[] tempSelected = new boolean[trainingSet.length];
		int tempMinimalIndex;
		
		//Select the nearest tempK indices.
		for (int i = 0; i < tempK; i++) {
			tempMinimalIndex = -1;
			for (int j = 0; j < trainingSet.length; j++) {
				if (tempSelected[j]) {
					continue;
				}//of if
				
				// Double.compare orders NaN after every other value, so a
				// NaN distance is only chosen when nothing else is left.
				if (tempMinimalIndex < 0
						|| Double.compare(tempDistances[j], tempDistances[tempMinimalIndex]) < 0) {
					tempMinimalIndex = j;
				}//of if
			}//of for j
			
			// tempK <= trainingSet.length guarantees an unselected candidate.
			tempResultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		}//of for i
		
		return tempResultNearests;
	}//of computeNearests
	
	/**
	 * ********************************************************************
	 * Voting using the instances. Each neighbor casts one vote for its own
	 * class; on a tie the class with the smallest index wins.
	 * 
	 * @param paraNeighbors  The indices of the neighbors.
	 * @return  The predicted label.
	 * ********************************************************************
	 */
	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int)dataset.instance(paraNeighbors[i]).classValue()]++;
		}//of for i
		
		int tempMaximalVotingIndex = 0;
		int tempMaximalVoting = 0;
		for (int i = 0; i < tempVotes.length; i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			}//of if
		}//of for i
		return tempMaximalVotingIndex;
	}//of simpleVoting
	
	/**
	 * ****************************************************************
	 * Predict for given instance: find its nearest training neighbors and
	 * let them vote.
	 * 
	 * @param paraIndex The instance to predict.
	 * @return  The prediction.
	 * ****************************************************************
	 */
	public int predict(int paraIndex) {
		int[] tempNeighbors = computeNearests(paraIndex);
		int resultPrediction = simpleVoting(tempNeighbors);
		
		return resultPrediction;
	}//of predict
	
	/**
	 * ****************************************************************
	 * Predict for the whole testing set. The results are stored in
	 * {@link #predictions}.
	 * ****************************************************************
	 */
	public void predict() {
		predictions = new int[testingSet.length];
		
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		}//of for i
	}//of predict 
	
	/**
	 * ******************************************************************
	 * Get the accuracy of the classifier, i.e. the fraction of testing
	 * instances whose prediction equals the actual class. Requires
	 * {@link #predict()} to have been called. The result is NaN when the
	 * testing set is empty.
	 * 
	 * @return  The accuracy.
	 * ******************************************************************
	 */
	public double getAccuracy() {
		// A double divides an int gets another double.
		double tempCorrect = 0;
		for (int i = 0; i < predictions.length; i++) {
			// Class values are stored as doubles; compare them as label indices.
			if (predictions[i] == (int)dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			}//of if
		}//of for i
		
		return tempCorrect/predictions.length;
	}//of getAccuracy
	
	/**
	 * *********************************************************************
	 * The entrance of the program.
	 * 
	 * @param args  Optional. args[0] is the arff file to use; without it the
	 *              default path E:/Datasets/iris.arff is read.
	 * *********************************************************************
	 */
	public static void main(String args[]) {
		String tempFilename = "E:/Datasets/iris.arff";
		if (args != null && args.length > 0) {
			tempFilename = args[0];
		}//of if
		
		KnnClassification tempClassifier = new KnnClassification(tempFilename);
		tempClassifier.splitTrainingTesting(0.5);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
	}//of main

}//of class KnnClassification

你可能感兴趣的:(java,knn)