机器学习——kMeans聚类

文章目录

  • 一、kMeans是什么?
  • 二、算法步骤
  • 三、实现代码

一、kMeans是什么?

kMeans算法是最常用的聚类算法,该算法的主要作用是将相似的样本自动归到一个类别中。

kMeans算法十分简单易懂而且非常有效,但是合理的确定K值和K个初始类簇中心点对于聚类效果的好坏有很大的影响。

同时,因为每次分簇是我们是依据每个散点到中心点的平均距离来确定的,因此任意选取点总是围绕中心点为一定半径范围内,因此kMeans很适合于球形数据。

二、算法步骤

(1)给定K值和K个初始类簇中心点
(2)把每个点分到离其最近的类簇中心点所代表的类簇中
(3)所有点分配完毕之后,根据一个类簇内的所有点重新计算该类簇的中心点(取平均值)
(4)然后再迭代的进行分配点和更新类簇中心点的步骤,直至类簇中心点的变化很小,或者达到指定的迭代次数。

三、实现代码

KMeans的getRandomIndices()方法和 kNN的完全相同。

package machinelearning.kmeans;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;

/**
 * 
 * @author Ling Lin E-mail:[email protected]
 * 
 * @version 创建时间:2022年4月30日 下午9:33:54
 * 
 */
public class kMeans {

	// Manhattan distance.
	// 曼哈顿距离
	public static final int MANHATTAN = 0;

	// Euclidean distance.
	// 欧几里得距离
	public static final int EUCLIDEAN = 1;

	// The distance measure.
	public int distanceMeasure = EUCLIDEAN;

	// A random instance;
	public static final Random random = new Random();

	// The data.数据集
	Instances dataset;

	// The number of clusters.
	int numClusters = 2;

	// The clusters.
	int[][] clusters;

	/**
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The data filename. 读文件
	 */
	public kMeans(String paraFilename) {
		dataset = null;
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor

	/**
	 * A setter.给定K值
	 */
	public void setNumClusters(int paraNumClusters) {
		numClusters = paraNumClusters;
	}// Of the setter

	/**
	 * Get a random indices for data randomization.
	 * 
	 * @param paraLength
	 *            The length of the sequence.
	 * @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
	 */
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices

	/**
	 * The distance between two instances.
	 * 
	 * @param paraI
	 *            The index of the first instance.数据集中点的下标
	 * @param paraArray
	 *            The array representing a point in the space.每簇的中心点
	 * @return The distance.
	 */
	public double distance(int paraI, double[] paraArray) {
		int resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch

		return resultDistance;
	}// Of distance

	/**
	 * Clustering.
	 */
	public void clustering() {

		int[] tempOldClusterArray = new int[dataset.numInstances()];
		tempOldClusterArray[0] = -1;

		int[] tempClusterArray = new int[dataset.numInstances()];
		Arrays.fill(tempClusterArray, 0);

		// 中心点的二维数组,有几簇就有几个中心点,即numClusters行。
		// dataset.numAttributes() - 1是要去掉iris数据集中的决策属性
		double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];

		// Step 1. Initialize centers.
		// 从打乱了的数据集中随机选取前K个作为初始中心点
		int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
		for (int i = 0; i < numClusters; i++) {
			for (int j = 0; j < tempCenters[0].length; j++) {
				tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
			} // Of for j
		} // Of for i

		int[] tempClusterLengths = null;

		while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
			System.out.println("New loop ...");
			tempOldClusterArray = tempClusterArray;
			tempClusterArray = new int[dataset.numInstances()];

			// Step 2.1 Minimization. Assign cluster to each instance.
			// 给数据集中的每一行数据找到离它最近的中心点

			int tempNearestCenter;
			double tempNearestDistance;
			double tempDistance;

			for (int i = 0; i < dataset.numInstances(); i++) {
				tempNearestCenter = -1;
				tempNearestDistance = Double.MAX_VALUE;

				for (int j = 0; j < numClusters; j++) {
					tempDistance = distance(i, tempCenters[j]);
					if (tempNearestDistance > tempDistance) {
						tempNearestDistance = tempDistance;
						tempNearestCenter = j;
					} // Of if
				} // Of for j

				// 保存每i行数据属于哪个中心点,即中心点数组的下标tempNearestCenter
				tempClusterArray[i] = tempNearestCenter;
			} // Of for i

			// Step 2.2 Mean. Find new centers.
			// 分好簇之后,从各个簇里取平均值找到新的中心点

			tempClusterLengths = new int[numClusters];// 保存每一簇里有多少个点
			Arrays.fill(tempClusterLengths, 0);

			double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];

			// 遍历数据集,把每一簇里的每一行数据的每一列累加到对应中心点的每一列上
			for (int i = 0; i < dataset.numInstances(); i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
				} // Of for j

				tempClusterLengths[tempClusterArray[i]]++;
			} // Of for i

			// Step 2.3 Now average
			// 取平均数得到新的中心点
			for (int i = 0; i < tempNewCenters.length; i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[i][j] /= tempClusterLengths[i];
				} // Of for j
			} // Of for i

			System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
			tempCenters = tempNewCenters;
		} // Of while

		// Step 3. Form clusters.
		// 将一维数组转换成二维数组
		clusters = new int[numClusters][];
		int[] tempCounters = new int[numClusters];
		for (int i = 0; i < numClusters; i++) {
			clusters[i] = new int[tempClusterLengths[i]];// clusters数组的每一行个数等于对应中心点所属簇的个数
		} // Of for i

		for (int i = 0; i < tempClusterArray.length; i++) {
			clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
			tempCounters[tempClusterArray[i]]++;
		} // Of for i

		System.out.println("The clusters are: " + Arrays.deepToString(clusters));
	}// Of clustering

	/**
	 ******************************* 
	 * Clustering.
	 ******************************* 
	 */
	public static void testClustering() {
		kMeans tempKMeans = new kMeans("D:/00/data/iris.arff");
		tempKMeans.setNumClusters(3);
		tempKMeans.clustering();
	}// Of testClustering

	/**
	 ************************* 
	 * A testing method.
	 ************************* 
	 */
	public static void main(String arags[]) {
		testClustering();
	}// Of main

}// Of class kMeans


机器学习——kMeans聚类_第1张图片

你可能感兴趣的:(聚类,机器学习,kmeans)