代码来自闵老师“日撸 Java 三百行(51-60天)”,链接:https://blog.csdn.net/minfanphd/article/details/116975957
代码实现的基本思路:先用构造函数KnnClassification读入数据,紧接着将读入的数据按照比例分为训练集和测试集,用训练集对测试集进行预测。预测的过程中需要计算距离,这里使用了两种距离计算公式,分别是欧氏距离和曼哈顿距离。根据计算出的距离,找出距离要预测的实例(测试集中的实例)最近的k个对象(训练集中的对象)。最后由最近的k个邻居投票决定要预测的实例的标签。预测结束后,根据预测结果计算预测的准确度。具体实现的代码如下:
package machinelearning.knn;
import java.io.FileReader;
import java.util.Random;
import javax.annotation.processing.RoundEnvironment;
import java.util.Arrays;
import java.util.PrimitiveIterator.OfDouble;
import weka.core.Instances;
/**
* KNN Classification
*
* @author WX873
*
*/
public class KnnClassification {
    // Workflow: load an ARFF dataset (last attribute is the class), randomly
    // split the instance indices into training and testing parts, and label
    // every testing instance by a majority vote of its nearest training
    // instances.

    /**
     * Manhattan distance.
     */
    public static final int MANHATTAN = 0;

    /**
     * Euclidean distance.
     */
    public static final int EUCLIDEAN = 1;

    /**
     * The distance measure, either {@link #MANHATTAN} or {@link #EUCLIDEAN}.
     */
    public int distanceMeasure = EUCLIDEAN;

    /**
     * A random instance, shared by all classifiers for shuffling.
     */
    public static final Random random = new Random();

    /**
     * The number of neighbors.
     */
    int numNeighbors = 7;

    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * The training set. Represented by the indices of the data.
     */
    int[] trainingSet;

    /**
     * The testing set. Represented by the indices of the data.
     */
    int[] testingSet;

    /**
     * The predictions, one per entry of {@link #testingSet}.
     */
    int[] predictions;

    /**
     * **************************************************
     * The first constructor. Reads the dataset and uses the last attribute
     * as the class attribute. On any failure an error message is printed and
     * the JVM exits.
     *
     * @param paraFilename The arff filename.
     * **************************************************
     */
    public KnnClassification(String paraFilename) {
        // try-with-resources closes the reader even when parsing fails.
        try (FileReader fileReader = new FileReader(paraFilename)) {
            dataset = new Instances(fileReader);
            dataset.setClassIndex(dataset.numAttributes() - 1);
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read \'" + paraFilename + "\' in KnnClassification constructor.\r\n" + ee);
            // NOTE(review): exit status 0 signals success to the caller even
            // though loading failed; kept for compatibility - confirm intent.
            System.exit(0);
        } // of try
    }// of the first constructor

    /**
     * **************************************************
     * Get a random permutation of the indices 0 .. paraLength - 1 using a
     * Fisher-Yates shuffle.
     *
     * @param paraLength The number of indices; 0 yields an empty array.
     * @return An array holding every index in [0, paraLength) exactly once,
     *         in random order.
     * **************************************************
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        // Step 1. Initialize with the identity permutation.
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        } // of for i

        // Step 2. Fisher-Yates: swap position i with a random position in
        // [0, i]. The loop never runs for lengths 0 and 1, so
        // random.nextInt is never called with a non-positive bound.
        for (int i = paraLength - 1; i > 0; i--) {
            int tempOther = random.nextInt(i + 1);
            int tempValue = resultIndices[i];
            resultIndices[i] = resultIndices[tempOther];
            resultIndices[tempOther] = tempValue;
        } // of for i

        return resultIndices;
    }// of getRandomIndices

    /**
     * **************************************************************
     * Split the data into training and testing parts. The first
     * (int) (size * fraction) shuffled indices become the training set, the
     * remaining ones the testing set.
     *
     * @param paraTrainingFraction The fraction of the training set.
     * @throws IllegalArgumentException If the fraction yields a training
     *         size outside [0, number of instances].
     * **************************************************************
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        int tempSize = dataset.numInstances();
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);
        if (tempTrainingSize < 0 || tempTrainingSize > tempSize) {
            throw new IllegalArgumentException("Training fraction " + paraTrainingFraction
                    + " gives an invalid training size " + tempTrainingSize + " for " + tempSize
                    + " instances.");
        } // of if

        int[] tempIndices = getRandomIndices(tempSize);
        trainingSet = Arrays.copyOfRange(tempIndices, 0, tempTrainingSize);
        testingSet = Arrays.copyOfRange(tempIndices, tempTrainingSize, tempSize);
    }// of splitTrainingTesting

    /**
     * ******************************************************************
     * The distance between two instances, computed over all attributes
     * except the last one (the class attribute). For {@link #EUCLIDEAN} the
     * SQUARED distance is returned (no square root is taken); this keeps the
     * ordering of neighbors unchanged. For an unsupported measure a message
     * is printed and 0 is returned.
     *
     * @param paraI The index of the first instance.
     * @param paraJ The index of the second instance.
     * @return The distance.
     * ******************************************************************
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        int tempNumConditions = dataset.numAttributes() - 1;

        switch (distanceMeasure) {
        case MANHATTAN:
            for (int i = 0; i < tempNumConditions; i++) {
                tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                resultDistance += Math.abs(tempDifference);
            } // of for i
            break;
        case EUCLIDEAN:
            for (int i = 0; i < tempNumConditions; i++) {
                tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                resultDistance += tempDifference * tempDifference;
            } // of for i
            break;
        default:
            System.out.println("Unsupported distance measure: " + distanceMeasure);
        }// of switch

        return resultDistance;
    }// of distance

    /**
     * ***********************************************************************
     * Compute the nearest neighbors of an instance among the training set.
     * Distances are computed once; then one neighbor is selected per scan.
     * Ties are resolved in favor of the earlier position in the training set.
     * At most as many neighbors as there are training instances are
     * returned, so no training instance is ever reported twice.
     *
     * @param paraCurrent Current instance. We are comparing it with all
     *                    training instances.
     * @return The indices (into the dataset) of the nearest training
     *         instances, nearest first. The length is
     *         min(numNeighbors, trainingSet.length).
     * ***********************************************************************
     */
    public int[] computeNearests(int paraCurrent) {
        int tempTrainingSize = trainingSet.length;
        int tempK = Math.min(numNeighbors, tempTrainingSize);

        // Compute every distance exactly once instead of once per scan.
        double[] tempDistances = new double[tempTrainingSize];
        for (int j = 0; j < tempTrainingSize; j++) {
            double tempDistance = distance(paraCurrent, trainingSet[j]);
            // NaN never compares smaller than anything; treat it as
            // infinitely far so the selection below stays well defined.
            tempDistances[j] = Double.isNaN(tempDistance) ? Double.POSITIVE_INFINITY : tempDistance;
        } // of for j

        int[] resultNearests = new int[tempK];
        boolean[] tempSelected = new boolean[tempTrainingSize];

        // Select the nearest tempK indices, one per scan.
        for (int i = 0; i < tempK; i++) {
            // -1 means "nothing chosen yet in this scan". Because
            // tempK <= tempTrainingSize, an unselected candidate always
            // exists, so the index is valid after the inner loop.
            int tempMinimalIndex = -1;
            for (int j = 0; j < tempTrainingSize; j++) {
                if (tempSelected[j]) {
                    continue;
                } // of if

                if (tempMinimalIndex < 0 || tempDistances[j] < tempDistances[tempMinimalIndex]) {
                    tempMinimalIndex = j;
                } // of if
            } // of for j

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        } // of for i

        return resultNearests;
    }// of computeNearests

    /**
     * ********************************************************************
     * Voting using the instances. Each neighbor casts one vote for its own
     * class; on a tie the class with the smallest index wins.
     *
     * @param paraNeighbors The indices of the neighbors.
     * @return The predicted label (0 when there are no neighbors).
     * ********************************************************************
     */
    public int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        } // of for i

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < tempVotes.length; i++) {
            // Strict comparison keeps the first class among equal counts.
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            } // of if
        } // of for i

        return tempMaximalVotingIndex;
    }// of simpleVoting

    /**
     * ****************************************************************
     * Predict for given instance.
     *
     * @param paraIndex The index (into the dataset) of the instance to
     *                  predict.
     * @return The prediction.
     * ****************************************************************
     */
    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        return simpleVoting(tempNeighbors);
    }// of predict

    /**
     * ****************************************************************
     * Predict for the whole testing set. The results are stored in
     * predictions. Requires splitTrainingTesting to have been called.
     *
     * @see #predictions
     * ****************************************************************
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        } // of for i
    }// of predict

    /**
     * ******************************************************************
     * Get the accuracy of the classifier. Requires predict() to have been
     * called. With an empty testing set the result is NaN (0.0 / 0).
     *
     * @return The fraction of testing instances predicted correctly.
     * ******************************************************************
     */
    public double getAccuracy() {
        // A double divides an int gets another double.
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            } // of if
        } // of for i

        return tempCorrect / predictions.length;
    }// of getAccuracy

    /**
     * *********************************************************************
     * The entrance of the program.
     *
     * @param args Optional. args[0] is the arff filename; when absent the
     *             default "E:/Datasets/iris.arff" is used.
     * *********************************************************************
     */
    public static void main(String[] args) {
        String tempFilename = "E:/Datasets/iris.arff";
        if (args != null && args.length > 0) {
            tempFilename = args[0];
        } // of if

        KnnClassification tempClassifier = new KnnClassification(tempFilename);
        tempClassifier.splitTrainingTesting(0.5);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }// of main
}// of class KnnClassification