日撸 Java 三百行 Day 56-57

Contents

  • Notes
  • Day 56-57: kMeans Clustering
    • 1. Understanding kMeans Clustering
    • 2. Understanding the Code
      • 2.1 Variables in the Code
      • 2.2 Code Walkthrough

Notes

Link to Mr. Min's original article: 日撸 Java 三百行(总述)_minfanphd的博客-CSDN博客
I also keep the code I typed along on GitHub: https://github.com/fulisha-ok/sampledata

Day 56-57: kMeans Clustering

1. Understanding kMeans Clustering

kMeans is an unsupervised machine-learning algorithm. Here k is the number of clusters the data should be partitioned into: k data points are chosen as the centers of the different clusters, the distance (Euclidean or Manhattan) between every sample and each cluster center is computed, and each sample is assigned to the cluster whose center is nearest. Once all samples have been processed, the data set is split into k clusters, and the center of each cluster is then recomputed as the mean of its members. This is iterated until the centers no longer change or a maximum number of iterations is reached.
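In symbols, a minimal sketch of the two alternating steps (the notation below is mine, not from the original post): for samples x_1, …, x_n and centers c_1, …, c_k,

    a_i = \arg\min_{1 \le j \le k} d(x_i, c_j), \qquad
    d(x_i, c_j) = \sqrt{\sum_t (x_{i,t} - c_{j,t})^2} \text{ (Euclidean)} \quad\text{or}\quad \sum_t |x_{i,t} - c_{j,t}| \text{ (Manhattan)}

    c_j = \frac{1}{|\{i : a_i = j\}|} \sum_{i : a_i = j} x_i

The assignment step and the mean update alternate until the assignment a no longer changes.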

2. Understanding the Code

2.1 Variables in the Code

(mainly the variables used in the clustering() method)

  • tempClusterArray
    Records which cluster each sample belongs to in the current iteration. For example, a value of 2 at index 2 means that, after computing the distances between sample 2 and the k cluster centers, center 2 turned out to be the nearest, so sample 2 is assigned to cluster 2 (see the small toy example after this list).
  • tempOldClusterArray
    Stores the cluster assignment of the previous iteration; it is compared with tempClusterArray to decide whether the clustering has converged.
  • tempCenters
    Holds the cluster centers. It is initialized by randomly shuffling the data set and taking the first k shuffled instances as the initial centers.
  • tempNewCenters
    The centers recomputed after each assignment pass, i.e., the mean of the members of each cluster.
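To make these variables concrete, here is a tiny self-contained sketch (the toy data and class name are mine, not from the original code) that runs one assignment pass and one mean update on four 2-D points with k = 2; the two printed arrays play the roles of tempClusterArray and tempNewCenters:

import java.util.Arrays;

public class KMeansToyExample {
    public static void main(String[] args) {
        // Four 2-D samples and two hand-picked initial centers (hypothetical toy data).
        double[][] data = {{1, 1}, {1.5, 2}, {8, 8}, {9, 9}};
        double[][] centers = {{1, 1}, {9, 9}};

        // Assignment pass: the counterpart of tempClusterArray.
        int[] clusterArray = new int[data.length];
        for (int i = 0; i < data.length; i++) {
            double best = Double.MAX_VALUE;
            for (int j = 0; j < centers.length; j++) {
                double distance = 0;
                for (int t = 0; t < data[i].length; t++) {
                    double diff = data[i][t] - centers[j][t];
                    distance += diff * diff; // squared Euclidean distance
                }
                if (distance < best) {
                    best = distance;
                    clusterArray[i] = j;
                }
            }
        }

        // Mean update: the counterpart of tempNewCenters.
        double[][] newCenters = new double[centers.length][data[0].length];
        int[] counts = new int[centers.length];
        for (int i = 0; i < data.length; i++) {
            for (int t = 0; t < data[i].length; t++) {
                newCenters[clusterArray[i]][t] += data[i][t];
            }
            counts[clusterArray[i]]++;
        }
        for (int j = 0; j < newCenters.length; j++) {
            for (int t = 0; t < newCenters[j].length; t++) {
                newCenters[j][t] /= counts[j];
            }
        }

        // Prints [0, 0, 1, 1] and [[1.25, 1.5], [8.5, 8.5]].
        System.out.println("clusterArray = " + Arrays.toString(clusterArray));
        System.out.println("newCenters   = " + Arrays.deepToString(newCenters));
    }
}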

2.2 Code Walkthrough

Once you understand the core idea of kMeans, the code is easy to follow when read section by section.

  • 1. Choose the number of clusters K (set to 3 here).
  • 2. Initialize the cluster centers tempCenters (shuffle the data set randomly and take the first K instances as the centers).
  • 3. Assign each sample to a cluster (compute its distance to every center and pick the nearest one).
  • 4. Recompute the cluster centers (the mean of each cluster).
  • 5. Repeat steps 3 and 4. The loop exits when tempOldClusterArray equals tempClusterArray, i.e., when the cluster assignment no longer changes between two consecutive iterations.
package machinelearing.knn;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

/**
 * @author: fulisha
 * @date: 2023-05-28 10:36
 * @description: kMeans clustering.
 */
public class KMeans {
    /**
     * Manhattan distance.
     */
    public static final int MANHATTAN = 0;


    /**
     * Euclidean distance.
     */
    public static final int EUCLIDEAN = 1;


    /**
     * The distance measure.
     */
    public int distanceMeasure = EUCLIDEAN;

    /**
     * A random instance;
     */
    public static final Random random = new Random();

    /**
     * The data.
     */
    Instances dataset;

    /**
     * The number of clusters.
     */
    int numClusters = 2;

    /**
     * The clusters.
     */
    int[][] clusters;


    /**
     * The first constructor.
     * @param paraFilename  The data filename.
     */
    public KMeans(String paraFilename) {
        dataset = null;
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            fileReader.close();
        } catch (Exception ee) {
            System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
            System.exit(0);
        }
    }

    public void setNumClusters(int paraNumClusters) {
        numClusters = paraNumClusters;
    }

    /**
     * Get random indices for data randomization.
     * @param paraLength The length of the sequence.
     * @return  An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        // Step 1. Initialize.
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }

        // Step 2. Randomly swap.
        int tempFirst, tempSecond, tempValue;
        for (int i = 0; i < paraLength; i++) {
            // Generate two random indices.
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);

            // Swap.
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        }
        return resultIndices;
    }
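    // Side note (not from the original code): another common way to obtain a random
    // permutation is java.util.Collections.shuffle on a boxed List<Integer>; the
    // primitive swap loop above avoids boxing and serves the same purpose.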


    /**
     * The distance between two instances.
     * @param paraI The index of the first instance.
     * @param paraArray  The array representing a point in the space.
     * @return The distance.
     */
    public double distance(int paraI, double[] paraArray) {
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
            case MANHATTAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    }
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }

        return resultDistance;
    }


    public void clustering() {
        int[] tempOldClusterArray = new int[dataset.numInstances()];
        // Make the old assignment differ from the (all-zero) initial assignment so that
        // the first convergence check fails and the while loop runs at least once.
        tempOldClusterArray[0] = -1;
        int[] tempClusterArray = new int[dataset.numInstances()];
        Arrays.fill(tempClusterArray, 0);
        double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];

        // Step 1. Initialize centers.
        int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
        for (int i = 0; i < numClusters; i++) {
            for (int j = 0; j < tempCenters[0].length; j++) {
                tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
            }
        }

        int[] tempClusterLengths = null;
        while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
            System.out.println("New loop ...");
            tempOldClusterArray = tempClusterArray;
            tempClusterArray = new int[dataset.numInstances()];

            // Step 2.1 Minimization. Assign cluster to each instance.
            int tempNearestCenter;
            double tempNearestDistance;
            double tempDistance;

            for (int i = 0; i < dataset.numInstances(); i++) {
                tempNearestCenter = -1;
                tempNearestDistance = Double.MAX_VALUE;

                for (int j = 0; j < numClusters; j++) {
                    tempDistance = distance(i, tempCenters[j]);
                    if (tempNearestDistance > tempDistance) {
                        tempNearestDistance = tempDistance;
                        tempNearestCenter = j;
                    }
                }

                tempClusterArray[i] = tempNearestCenter;
            }

            // Step 2.2 Mean. Find new centers.
            tempClusterLengths = new int[numClusters];
            Arrays.fill(tempClusterLengths, 0);
            double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
            // Arrays.fill(tempNewCenters, 0);
            for (int i = 0; i < dataset.numInstances(); i++) {
                for (int j = 0; j < tempNewCenters[0].length; j++) {
                    tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
                }
                tempClusterLengths[tempClusterArray[i]]++;
            }

            // Step 2.3 Now average
            for (int i = 0; i < tempNewCenters.length; i++) {
                for (int j = 0; j < tempNewCenters[0].length; j++) {
                    tempNewCenters[i][j] /= tempClusterLengths[i];
                }
            }

            System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
            tempCenters = tempNewCenters;
        }

        // Step 3. Form clusters.
        clusters = new int[numClusters][];
        int[] tempCounters = new int[numClusters];
        for (int i = 0; i < numClusters; i++) {
            clusters[i] = new int[tempClusterLengths[i]];
        }

        for (int i = 0; i < tempClusterArray.length; i++) {
            clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
            tempCounters[tempClusterArray[i]]++;
        }

        System.out.println("The clusters are: " + Arrays.deepToString(clusters));
    }


    public static void testClustering() {
        KMeans tempKMeans = new KMeans("C:/Users/fls/Desktop/sampledata/iris.arff");
        tempKMeans.setNumClusters(3);
        tempKMeans.clustering();
    }

    public static void main(String[] args) {
        testClustering();
    }
}
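A small usage note: because the EUCLIDEAN branch of distance() returns the squared distance and the square root is monotonic, the nearest-center comparison is not affected. The public distanceMeasure field can also be switched to Manhattan distance before clustering; below is a minimal sketch, a hypothetical variant of testClustering() that assumes the same iris.arff path:

public static void testManhattanClustering() {
    KMeans tempKMeans = new KMeans("C:/Users/fls/Desktop/sampledata/iris.arff");
    tempKMeans.setNumClusters(3);
    // Switch the distance measure from the default EUCLIDEAN to MANHATTAN.
    tempKMeans.distanceMeasure = KMeans.MANHATTAN;
    tempKMeans.clustering();
}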

  • Result
    The program prints the new centers after each iteration and, at the end, the instance indices of each cluster (screenshot of the console output omitted).
