Clustering is a form of unsupervised learning: the samples are not labeled in advance, and the algorithm groups them according to their mutual similarity or distance.
Commonly used similarity/distance measures include the Manhattan distance and the Euclidean distance, both of which appear in the code below.
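For two samples $x, y \in \mathbb{R}^m$ (with $m$ condition attributes), the two measures used in this post are
$d_{Manhattan}(x, y) = \sum_{i=1}^{m} |x_i - y_i|$ and $d_{Euclidean}(x, y) = \sqrt{\sum_{i=1}^{m} (x_i - y_i)^2}$.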
KMeans clustering partitions $n$ samples into $k$ disjoint sets, and each sample belongs to exactly one cluster.
The algorithm proceeds as follows:
Step 1. Randomly pick $k$ samples as the initial cluster centers.
Step 2. For each sample, compute its distance to every center and assign it to the cluster of the nearest center.
Step 3. For each cluster obtained in Step 2, compute the mean of its samples and take it as the new cluster center.
Step 4. Repeat Steps 2 and 3 until the assignment no longer changes, then stop.
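As a tiny illustration (not part of the original post), take the one-dimensional points {0, 1, 9, 10} with $k = 2$ and initial centers 0 and 1. The first assignment gives clusters {0} and {1, 9, 10}, whose means are 0 and 20/3 ≈ 6.67; reassigning with these centers gives {0, 1} and {9, 10}, with means 0.5 and 9.5; a third pass leaves the assignment unchanged, so the algorithm stops.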
1. Member variables
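All snippets below come from a single KMeans class. For completeness, they assume the following imports (Instances is Weka's data container, read from an ARFF file in the constructor):

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.Instances;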
/**
 * Manhattan distance.
 */
public static final int MANHATTAN = 0;

/**
 * Euclidean distance.
 */
public static final int EUCLIDEAN = 1;

/**
 * The distance measure.
 */
public int distanceMeasure = EUCLIDEAN;

/**
 * A random instance.
 */
public static final Random random = new Random();

/**
 * The data.
 */
Instances dataset;

/**
 * The number of clusters.
 */
int numClusters = 2;

/**
 * The clusters.
 */
int[][] clusters;
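Note that clusters is only filled in at the end of clustering(): clusters[i] will hold the row indices of the instances assigned to cluster i.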
2. Constructor
Read in the data.
/**
 *************************
 * The first constructor.
 *
 * @param paraFilename The data filename.
 *************************
 */
public KMeans(String paraFilename) {
    dataset = null;
    try {
        FileReader fileReader = new FileReader(paraFilename);
        dataset = new Instances(fileReader);
        fileReader.close();
    } catch (Exception ee) {
        System.out.println("Cannot read the file: " + paraFilename + "\r\n" + ee);
        System.exit(0);
    } // Of try
} // Of the first constructor
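Weka's Instances(Reader) constructor parses the ARFF file directly. The class (label) attribute stays in the dataset, but, as shown below, it is skipped when distances are computed.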
3. Set the number of clusters
/**
 *************************
 * A setter for the number of clusters.
 *
 * @param paraNumClusters The number of clusters.
 *************************
 */
public void setNumClusters(int paraNumClusters) {
    numClusters = paraNumClusters;
} // Of the setter
4. Shuffle the indices
This method shuffles the integers in [0, paraLength) and is used to pick the initial cluster centers at random. The same routine was also used in $k$-NN to shuffle the data.
/**
 *************************
 * Get random indices for data randomization.
 *
 * @param paraLength The length of the sequence.
 * @return An array of indices.
 *************************
 */
public static int[] getRandomIndices(int paraLength) {
    int[] resultIndices = new int[paraLength];

    // Step 1. Initialize.
    for (int i = 0; i < paraLength; i++) {
        resultIndices[i] = i;
    } // Of for i

    // Step 2. Randomly swap.
    int tempFirst, tempSecond, tempValue;
    for (int i = 0; i < paraLength; i++) {
        // Generate two random indices.
        tempFirst = random.nextInt(paraLength);
        tempSecond = random.nextInt(paraLength);

        // Swap.
        tempValue = resultIndices[tempFirst];
        resultIndices[tempFirst] = resultIndices[tempSecond];
        resultIndices[tempSecond] = tempValue;
    } // Of for i

    return resultIndices;
} // Of getRandomIndices
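The swap loop above exchanges two random positions per iteration, which is simple but not the textbook shuffle. A common alternative, shown here only as a sketch that is not part of the original class, is the Fisher-Yates shuffle built on the same random field:

/**
 *************************
 * A hypothetical alternative shuffle (Fisher-Yates), not in the original code.
 *
 * @param paraLength The length of the sequence.
 * @return An array of shuffled indices.
 *************************
 */
public static int[] getRandomIndicesFisherYates(int paraLength) {
    int[] resultIndices = new int[paraLength];
    for (int i = 0; i < paraLength; i++) {
        resultIndices[i] = i;
    } // Of for i

    // Walk backwards, swapping each position with a uniformly chosen earlier one.
    for (int i = paraLength - 1; i > 0; i--) {
        int tempIndex = random.nextInt(i + 1);
        int tempValue = resultIndices[i];
        resultIndices[i] = resultIndices[tempIndex];
        resultIndices[tempIndex] = tempValue;
    } // Of for i

    return resultIndices;
} // Of getRandomIndicesFisherYates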
5. Distance between two sample points
/**
 *************************
 * The distance between two instances.
 *
 * @param paraI The index of the first instance.
 * @param paraArray The array representing a point in the space.
 * @return The distance.
 *************************
 */
public double distance(int paraI, double[] paraArray) {
    // The accumulator must be a double; an int would truncate the differences.
    double resultDistance = 0;
    double tempDifference;
    switch (distanceMeasure) {
        case MANHATTAN:
            // Skip the last attribute, which is the class label.
            for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
                if (tempDifference < 0) {
                    resultDistance -= tempDifference;
                } else {
                    resultDistance += tempDifference;
                } // Of if
            } // Of for i
            break;
        case EUCLIDEAN:
            // The square root is omitted: the squared distance preserves the
            // ordering needed to find the nearest center.
            for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
                resultDistance += tempDifference * tempDifference;
            } // Of for i
            break;
        default:
            System.out.println("Unsupported distance measure: " + distanceMeasure);
    } // Of switch

    return resultDistance;
} // Of distance
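Since distanceMeasure is a public field, switching to the Manhattan measure is just an assignment before clustering; for example, adapting the test method shown later:

KMeans tempKMeans = new KMeans("G:/Program Files/Weka-3-8-6/data/iris.arff");
tempKMeans.distanceMeasure = KMeans.MANHATTAN;
tempKMeans.setNumClusters(3);
tempKMeans.clustering();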
6. Core code: clustering
/**
 *************************
 * Clustering.
 *************************
 */
public void clustering() {
    int[] tempOldClusterArray = new int[dataset.numInstances()];
    // Make the old and new assignment arrays differ so the loop runs at least once.
    tempOldClusterArray[0] = -1;
    int[] tempClusterArray = new int[dataset.numInstances()];
    Arrays.fill(tempClusterArray, 0);
    double[][] tempCenters = new double[numClusters][dataset.numAttributes() - 1];

    // Step 1. Initialize centers.
    int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
    for (int i = 0; i < numClusters; i++) {
        for (int j = 0; j < tempCenters[0].length; j++) {
            tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
        } // Of for j
    } // Of for i

    int[] tempClusterLengths = null;
    while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
        System.out.println("New loop ...");
        tempOldClusterArray = tempClusterArray;
        tempClusterArray = new int[dataset.numInstances()];

        // Step 2.1 Minimization. Assign a cluster to each instance.
        int tempNearestCenter;
        double tempNearestDistance;
        double tempDistance;
        for (int i = 0; i < dataset.numInstances(); i++) {
            tempNearestCenter = -1;
            tempNearestDistance = Double.MAX_VALUE;
            for (int j = 0; j < numClusters; j++) {
                tempDistance = distance(i, tempCenters[j]);
                if (tempNearestDistance > tempDistance) {
                    tempNearestDistance = tempDistance;
                    tempNearestCenter = j;
                } // Of if
            } // Of for j
            tempClusterArray[i] = tempNearestCenter;
        } // Of for i

        // Step 2.2 Mean. Find new centers by summing per cluster.
        tempClusterLengths = new int[numClusters];
        Arrays.fill(tempClusterLengths, 0);
        double[][] tempNewCenters = new double[numClusters][dataset.numAttributes() - 1];
        for (int i = 0; i < dataset.numInstances(); i++) {
            for (int j = 0; j < tempNewCenters[0].length; j++) {
                tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
            } // Of for j
            tempClusterLengths[tempClusterArray[i]]++;
        } // Of for i

        // Step 2.3 Now average. Note: an empty cluster would cause a division by zero here.
        for (int i = 0; i < tempNewCenters.length; i++) {
            for (int j = 0; j < tempNewCenters[0].length; j++) {
                tempNewCenters[i][j] /= tempClusterLengths[i];
            } // Of for j
        } // Of for i

        System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
        tempCenters = tempNewCenters;
    } // Of while

    // Step 3. Form clusters.
    clusters = new int[numClusters][];
    int[] tempCounters = new int[numClusters];
    for (int i = 0; i < numClusters; i++) {
        clusters[i] = new int[tempClusterLengths[i]];
    } // Of for i

    for (int i = 0; i < tempClusterArray.length; i++) {
        clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i;
        tempCounters[tempClusterArray[i]]++;
    } // Of for i

    System.out.println("The clusters are: " + Arrays.deepToString(clusters));
} // Of clustering
7. Testing
/**
 *************************
 * A method for testing the clustering.
 *************************
 */
public static void testClustering() {
    KMeans tempKMeans = new KMeans("G:/Program Files/Weka-3-8-6/data/iris.arff");
    tempKMeans.setNumClusters(3);
    tempKMeans.clustering();
} // Of testClustering
/**
 *************************
 * A testing method.
 *************************
 */
public static void main(String args[]) {
    testClustering();
} // Of main
As an extension, the virtual (mean) centers can be replaced by actual sample points: for each virtual center, compute its distance to every sample and move the center to the nearest sample. The snippet below does exactly that; it belongs inside clustering(), right after Step 2.3.
// Step 2.4 Get actual center.
for (int i = 0; i < numClusters; i++) {
    tempNearestCenter = -1;
    tempNearestDistance = Double.MAX_VALUE;
    for (int j = 0; j < dataset.numInstances(); j++) {
        tempDistance = distance(j, tempNewCenters[i]);
        if (tempNearestDistance > tempDistance) {
            tempNearestDistance = tempDistance;
            tempNearestCenter = j;
        } // Of if
    } // Of for j

    for (int k = 0; k < dataset.numAttributes() - 1; k++) {
        tempNewCenters[i][k] = dataset.instance(tempNearestCenter).value(k);
    } // Of for k
} // Of for i
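With this extra step every center is always a real instance from the dataset, which is close in spirit to the k-medoids variant of the algorithm.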