kMeans算法是最常用的聚类算法,该算法的主要作用是将相似的样本自动归到一个类别中。
kMeans算法十分简单易懂而且非常有效,但是合理的确定K值和K个初始类簇中心点对于聚类效果的好坏有很大的影响。
同时,因为每次分簇是我们是依据每个散点到中心点的平均距离来确定的,因此任意选取点总是围绕中心点为一定半径范围内,因此kMeans很适合于球形数据。
(1)给定K值和K个初始类簇中心点
(2)把每个点分到离其最近的类簇中心点所代表的类簇中
(3)所有点分配完毕之后,根据一个类簇内的所有点重新计算该类簇的中心点(取平均值)
(4)然后再迭代的进行分配点和更新类簇中心点的步骤,直至类簇中心点的变化很小,或者达到指定的迭代次数。
KMeans的getRandomIndices()方法和 kNN的完全相同。
package machinelearning.kmeans;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
/**
*
* @author Ling Lin E-mail:[email protected]
*
* @version 创建时间:2022年4月30日 下午9:33:54
*
*/
public class kMeans {
// Manhattan distance.
// 曼哈顿距离
public static final int MANHATTAN = 0;
// Euclidean distance.
// 欧几里得距离
public static final int EUCLIDEAN = 1;
// The distance measure.
public int distanceMeasure = EUCLIDEAN;
// A random instance;
public static final Random random = new Random();
// The data.数据集
Instances dataset;
// The number of clusters.
int numClusters = 2;
// The clusters.
int[][] clusters;
/**
* The first constructor.
*
* @param paraFilename
* The data filename. 读文件
*/
public kMeans(String paraFilename) {
dataset = null;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception ee) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
System.exit(0);
} // Of try
}// Of the first constructor
/**
* A setter.给定K值
*/
public void setNumClusters(int paraNumClusters) {
numClusters = paraNumClusters;
}// Of the setter
/**
* Get a random indices for data randomization.
*
* @param paraLength
* The length of the sequence.
* @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
// Step 1. Initialize.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
// Step 2. Randomly swap.
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
// Generate two random indices.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomIndices
/**
* The distance between two instances.
*
* @param paraI
* The index of the first instance.数据集中点的下标
* @param paraArray
* The array representing a point in the space.每簇的中心点
* @return The distance.
*/
public double distance(int paraI, double[] paraArray) {
int resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
} // Of if
} // Of for i
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
resultDistance += tempDifference * tempDifference;
} // Of for i
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}// Of switch
return resultDistance;
}// Of distance
/**
* Clustering.
*/
public void clustering() {
int[] tempOldClusterArray = new int[dataset.numInstances()];
tempOldClusterArray[0] = -1;
int[] tempClusterArray = new int[dataset.numInstances()];
Arrays.fill(tempClusterArray, 0);
// 中心点的二维数组,有几簇就有几个中心点,即numClusters行。
// dataset.numAttributes() - 1是要去掉iris数据集中的决策属性
double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];
// Step 1. Initialize centers.
// 从打乱了的数据集中随机选取前K个作为初始中心点
int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
for (int i = 0; i < numClusters; i++) {
for (int j = 0; j < tempCenters[0].length; j++) {
tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
} // Of for j
} // Of for i
int[] tempClusterLengths = null;
while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
System.out.println("New loop ...");
tempOldClusterArray = tempClusterArray;
tempClusterArray = new int[dataset.numInstances()];
// Step 2.1 Minimization. Assign cluster to each instance.
// 给数据集中的每一行数据找到离它最近的中心点
int tempNearestCenter;
double tempNearestDistance;
double tempDistance;
for (int i = 0; i < dataset.numInstances(); i++) {
tempNearestCenter = -1;
tempNearestDistance = Double.MAX_VALUE;
for (int j = 0; j < numClusters; j++) {
tempDistance = distance(i, tempCenters[j]);
if (tempNearestDistance > tempDistance) {
tempNearestDistance = tempDistance;
tempNearestCenter = j;
} // Of if
} // Of for j
// 保存每i行数据属于哪个中心点,即中心点数组的下标tempNearestCenter
tempClusterArray[i] = tempNearestCenter;
} // Of for i
// Step 2.2 Mean. Find new centers.
// 分好簇之后,从各个簇里取平均值找到新的中心点
tempClusterLengths = new int[numClusters];// 保存每一簇里有多少个点
Arrays.fill(tempClusterLengths, 0);
double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
// 遍历数据集,把每一簇里的每一行数据的每一列累加到对应中心点的每一列上
for (int i = 0; i < dataset.numInstances(); i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
} // Of for j
tempClusterLengths[tempClusterArray[i]]++;
} // Of for i
// Step 2.3 Now average
// 取平均数得到新的中心点
for (int i = 0; i < tempNewCenters.length; i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[i][j] /= tempClusterLengths[i];
} // Of for j
} // Of for i
System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
tempCenters = tempNewCenters;
} // Of while
// Step 3. Form clusters.
// 将一维数组转换成二维数组
clusters = new int[numClusters][];
int[] tempCounters = new int[numClusters];
for (int i = 0; i < numClusters; i++) {
clusters[i] = new int[tempClusterLengths[i]];// clusters数组的每一行个数等于对应中心点所属簇的个数
} // Of for i
for (int i = 0; i < tempClusterArray.length; i++) {
clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
tempCounters[tempClusterArray[i]]++;
} // Of for i
System.out.println("The clusters are: " + Arrays.deepToString(clusters));
}// Of clustering
/**
*******************************
* Clustering.
*******************************
*/
public static void testClustering() {
kMeans tempKMeans = new kMeans("D:/00/data/iris.arff");
tempKMeans.setNumClusters(3);
tempKMeans.clustering();
}// Of testClustering
/**
*************************
* A testing method.
*************************
*/
public static void main(String arags[]) {
testClustering();
}// Of main
}// Of class kMeans