日撸Java Day 51-53

KNN Classifier

package machineLearning.knn;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

/**
 * ClassName: KnnClassification
 * Package: machineLearning.knn
 * Description: KNN classification on a Weka arff dataset.
 *
 * @Author: luv_x_c
 * @Create: 2023/6/7 17:54
 */
public class KnnClassification {

    /**
     * Manhattan distance.
     */
    public static final int MANHATTAN = 0;

    /**
     * Euclidean distance.
     */
    public static final int EUCLIDEAN = 1;

    /**
     * The distance measure.
     */
    private int distanceMeasure;

    /**
     * A random instance.
     */
    public static final Random RANDOM = new Random();

    /**
     * The number of neighbors.
     */
    private int numNeighbors;

    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * The training set. Represented by the indices of the data.
     */
    int[] trainingSet;

    /**
     * The testing set. Represented by the indices of the data.
     */
    int[] testingSet;

    /**
     * The predictions.
     */
    int[] predictions;

    /**
     * Setter for the distance measure.
     */
    public void setDistanceMeasure(int paraDistanceMeasure) {
        this.distanceMeasure = paraDistanceMeasure;
    }// Of setDistanceMeasure

    /**
     * Setter for the number of neighbors.
     */
    public void setNumNeighbors(int paraNumNeighbors) {
        this.numNeighbors = paraNumNeighbors;
    }// Of setNumNeighbors

    /**
     * The first constructor.
     *
     * @param paraFileName The arff filename.
     */
    public KnnClassification(String paraFileName) {
        try {
            FileReader fileReader = new FileReader(paraFileName);
            dataset = new Instances(fileReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception e) {
            System.out.println("Error occurred while trying to read \'" + paraFileName
                    + "\' in KnnClassification constructor.\r\n" + e);
            System.exit(0);
        }// Of try
    }// Of the first constructor

    /**
     * Get random indices for data randomization.
     *
     * @param paraLength The length of the sequence.
     * @return An array of indices.
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        // Step 1. Initialize.
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }// Of for i

        // Step 2. Randomly swap.
        int tempFirst, tempSecond, tempValue;
        for (int i = 0; i < paraLength; i++) {
            // Generate two random indices.
            tempFirst = RANDOM.nextInt(paraLength);
            tempSecond = RANDOM.nextInt(paraLength);

            // Swap.
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        }// Of for i

        return resultIndices;
    }//Of getRandomIndices

    /**
     * Split the data into training and testing parts.
     *
     * @param paraTrainingFraction The fraction of the training set.
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];

        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }// Of for i

        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[i + tempTrainingSize];
        }// Of for i
    }// Of splitTrainingTesting

    /**
     * Predict for the whole testing set. The results are stored in predictions.
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }// Of for i
    }// Of predict

    /**
     * Predict for a given instance.
     *
     * @param paraIndex The given index.
     * @return The prediction.
     */
    private int predict(int paraIndex) {
        int[] tempNeighbors = computeNearest(paraIndex);

        return simpleVoting(tempNeighbors);
    }// Of  predict

    /**
     * Voting using the instances.
     *
     * @param paraNeighbors The indices of the neighbors.
     * @return The predicted label.
     */
    private int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        for (int paraNeighbor : paraNeighbors) {
            tempVotes[(int) dataset.instance(paraNeighbor).classValue()]++;
        }// Of for i

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }// Of if
        }// Of for i

        return tempMaximalVotingIndex;

    }// Of simpleVoting

//    /**
//     * Compute the nearest k neighbors.
//     *
//     * @param paraCurrent current instance. We are comparing it with all others.
//     * @return The indices of the nearest instances.
//     */
//    private int[] computeNearest(int paraCurrent) {
//        int[] resultNearest = new int[numNeighbors];
//        boolean[] tempSelected = new boolean[trainingSet.length];
//        double tempMinimalDistance;
//        int tempMinimalIndex = 0;
//
//        // Compute all distances to avoid redundant computation.
//        double[] tempDistances = new double[trainingSet.length];
//        for (int i = 0; i < trainingSet.length; i++) {
//            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
//        }// Of for i
//
//        // Select the nearest paraK indices.
//        for (int i = 0; i < numNeighbors; i++) {
//            tempMinimalDistance = Double.MAX_VALUE;
//
//            for (int j = 0; j < trainingSet.length; j++) {
//                if (tempSelected[j]) {
//                    continue;
//                }// Of if
//
//                if (tempDistances[j] < tempMinimalDistance) {
//                    tempMinimalDistance = tempDistances[j];
//                    tempMinimalIndex = j;
//                }// Of if
//            }// OF for j
//
//            resultNearest[i] = trainingSet[tempMinimalIndex];
//            tempSelected[tempMinimalIndex] = true;
//        }// Of for i
//
//        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearest));
//        return resultNearest;
//    }// Of computeNearest

    /**
     * Compute the nearest k neighbors.
     *
     * @param paraCurrent current instance. We are comparing it with all others.
     * @return The indices of the nearest instances.
     */
    private int[] computeNearest(int paraCurrent) {
        int[] resultNearest = new int[numNeighbors];
        double[] tempDistance = new double[numNeighbors];

        double tempCurrentDistance;
        int tempCurrentIndex = 0;

        // Scan the training set, maintaining the k nearest neighbors with insertion sort.
        for (int i = 0; i < trainingSet.length; i++) {
            tempCurrentDistance = distance(paraCurrent, trainingSet[i]);

            // Search and insert.
            tempCurrentIndex = i < numNeighbors ? i : numNeighbors - 1;
            boolean tempIsMove = false;
            while (tempCurrentIndex > 0 && tempDistance[tempCurrentIndex - 1] > tempCurrentDistance) {
                tempDistance[tempCurrentIndex] = tempDistance[tempCurrentIndex - 1];
                resultNearest[tempCurrentIndex] = resultNearest[tempCurrentIndex - 1];
                tempCurrentIndex--;
                tempIsMove = true;
            }// Of while
            // Insert when the neighbor array is not yet full, when elements were
            // shifted, or when the new distance beats the current k-th smallest one.
            if (tempIsMove || i < numNeighbors
                    || tempCurrentDistance < tempDistance[tempCurrentIndex]) {
                tempDistance[tempCurrentIndex] = tempCurrentDistance;
                resultNearest[tempCurrentIndex] = trainingSet[i];
            }// Of if
        }// Of for i

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearest));
        return resultNearest;
    }// Of computeNearest
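
The distance method called by computeNearest does not appear in the listing above. A minimal sketch of one possible implementation, assuming it switches on distanceMeasure, skips the class attribute, and (for the Euclidean case) omits the square root, which does not change the neighbor ordering:

    /**
     * The distance between two instances, ignoring the class attribute.
     *
     * @param paraI The index of the first instance.
     * @param paraJ The index of the second instance.
     * @return The distance under the current distance measure.
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        for (int i = 0; i < dataset.numAttributes() - 1; i++) {
            tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
            switch (distanceMeasure) {
                case MANHATTAN:
                    resultDistance += Math.abs(tempDifference);
                    break;
                case EUCLIDEAN:
                    // The square root is omitted; it does not affect which neighbors are nearest.
                    resultDistance += tempDifference * tempDifference;
                    break;
                default:
                    System.out.println("Unsupported distance measure: " + distanceMeasure);
            }// Of switch
        }// Of for i

        return resultDistance;
    }// Of distance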

Added setter methods for the number of neighbors and the distance measure, and re-implemented the nearest-neighbor computation with insertion sort.
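
As a usage sketch, a main method that could sit inside KnnClassification (the arff file name, the 0.8 split fraction, and k = 7 are placeholder values, not taken from the post):

    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification("sample.arff"); // Placeholder file.
        tempClassifier.setDistanceMeasure(EUCLIDEAN);
        tempClassifier.setNumNeighbors(7);
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The predictions are: " + Arrays.toString(tempClassifier.predictions));

        // Count correct predictions on the testing set.
        double tempCorrect = 0;
        for (int i = 0; i < tempClassifier.testingSet.length; i++) {
            int tempActual = (int) tempClassifier.dataset.instance(tempClassifier.testingSet[i]).classValue();
            if (tempClassifier.predictions[i] == tempActual) {
                tempCorrect++;
            }// Of if
        }// Of for i
        System.out.println("The accuracy is: " + tempCorrect / tempClassifier.testingSet.length);
    }// Of main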
