1、从包含多个数据点的数据集D中随机取k个点,作为k个簇的各自的中心。
2、分别计算剩下的点到k个簇中心的相异度,将这些元素分别划归到相异度最低的簇。
两个点之间的相异度大小采用欧氏距离公式衡量,对于两个点T0(x1,y2)和T1(x2,y2),
T0和T1之间的欧氏距离为
d = sqrt((x1-x2)^2+(y1-y2)^2)
欧氏距离越小,说明相异度越小
3、根据聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所有点各自维度的算术平均数。
4、将D中全部点按照新的中心重新聚类。
5、重复第4步,直到聚类结果不再变化。
6、将结果输出。
举例说明,假设包含9个点数据D如下(见simple_k-means.txt),从D中随机取k个元素,
作为k个簇的各自的中心, 假设选k=2, 即将如下的9个点聚类成两个类(cluster)。
1 1
2 1
1 2
2 2
3 3
8 8
8 9
9 8
9 9
1. 假设选C0(1 1)和C1(2 1)前两个点作为两个类的簇心。
2. 分别计算剩下的点到k个簇中心的相异度,将这些元素分别划归到相
异度最低的簇。结果为:
C0 :1 1
C0: 的点为:1.0,2.0
C1: 2 1
C1:的点为:2.0, 2.0
C1:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
3. 根据2的聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所
有元素各自维度的算术平均数。
C0 新的簇心为:1.0,1.5
C1新的簇心为:5.857142857142857, 5.714285714285714
4.将D中全部元素按照新的中心重新聚类。
第2次迭代
C0:的点为:1.0, 1.0
C0:的点为:2.0, 1.0
C0:的点为:1.0, 2.0
C0:的点为:2.0, 2.0
C0:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
5. 重复第4步,直到聚类结果不再变化。
当每个簇心点前后移动的距离小于某个阈值t的时候,就认为聚类已经
结束了,不需要再迭代,这里的值选t=0.001,距离计算采用欧氏距离
------------------------------------------------
C0的簇心为:1.6666666666666667,1.75
C1的簇心为:7.971428571428572, 7.942857142857143
各个簇心移动中最小的距离为,moveDistance=0.7120003121097943
第3次迭代
C0:的点为:1.0, 1.0
C0:的点为:2.0, 1.0
C0:的点为:1.0, 2.0
C0:的点为:2.0, 2.0
C0:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
------------------------------------------------
C0的簇心为:1.777777777777778,1.7916666666666667
C1的簇心为:8.394285714285715,8.388571428571428
各个簇心移动中最小的距离为,moveDistance = 0.11866671868496578
第4次迭代
C0:的点为:1.0, 1.0
C0:的点为:2.0, 1.0
C0:的点为:1.0, 2.0
C0:的点为:2.0, 2.0
C0:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
------------------------------------------------
C0的簇心为:1.7962962962962965,1.7986111111111114
C1的簇心为:8.478857142857143,8.477714285714285
各个簇心移动中最小的距离为,moveDistance=0.019777786447494432
第5次迭代
C0:的点为:1.0, 1.0
C0:的点为:2.0, 1.0
C0:的点为:1.0, 2.0
C0:的点为:2.0, 2.0
C0:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
------------------------------------------------
C0的簇心为:1.799382716049383,1.7997685185185184
C1的簇心为:8.495771428571429,8.495542857142857
各个簇心移动中最小的距离为,moveDistance=0.003296297741248916
第6次迭代
C0:的点为:1.0, 1.0
C0:的点为:2.0, 1.0
C0:的点为:1.0, 2.0
C0:的点为:2.0, 2.0
C0:的点为:3.0, 3.0
C1:的点为:8.0, 8.0
C1:的点为:8.0, 9.0
C1:的点为:9.0, 8.0
C1:的点为:9.0, 9.0
------------------------------------------------
C0的簇心为:1.7998971193415638,1.7999614197530864
C1的簇心为:8.499154285714287,8.499108571428572
各个簇心移动中最小的距离为,moveDistance=5.49382956874724E-4
*************************************************************************************
K_means类,代码如下:
package Kmeans;
import java.util.ArrayList;
import java.util.Random;
public class k_means {
private int k;// 分成多少簇
private int m;// 迭代次数
private int dataSetLength;// 数据集元素个数,即数据集的长度
private ArrayList dataSet;// 数据集链表
private ArrayList center;// 中心链表
private ArrayList> cluster; // 簇
private ArrayList jc;// 误差平方和,k越接近dataSetLength,误差越小
private Random random;
public void setDataSet(ArrayList dataSet) {
//设置需分组的原始数据集
this.dataSet = dataSet;
}
public ArrayList> getCluster() {
return cluster;
}
public k_means(int k) {
//传入需要分成的簇数量
if (k <= 0) {
k = 1;
}
this.k = k;
}
private void init() {
//初始化
m = 0;
random = new Random();
if (dataSet == null || dataSet.size() == 0) {
System.out.println("数据为空,请输入数据!!!!");
} else{
dataSetLength = dataSet.size();
if (k > dataSetLength) {
k = dataSetLength;
}
center = initCenters();
cluster = initCluster();
jc = new ArrayList();
}
}
private ArrayList initCenters() {
//初始化中心数据链表,分成多少簇就有多少个中心点
ArrayList center = new ArrayList();
int[] randoms = new int[k];
boolean flag;
int temp = random.nextInt(dataSetLength);
randoms[0] = temp;
for (int i = 1; i < k; i++) {
flag = true;
while (flag) {
temp = random.nextInt(dataSetLength);
int j = 0;
while (j < i) {
if (temp == randoms[j]) {
break;
}
j++;
}
if (j == i) {
flag = false;
}
}
randoms[i] = temp;
}
for (int i = 0; i < k; i++) {
center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
}
return center;
}
private ArrayList> initCluster() {
//初始化簇集合
ArrayList> cluster = new ArrayList>();
for (int i = 0; i < k; i++) {
cluster.add(new ArrayList());
}
return cluster;
}
private float distance(float[] element, float[] center) {
//计算两个点之间的距离
float distance = 0.0f;
float x = element[0] - center[0];
float y = element[1] - center[1];
float z = x * x + y * y;
distance = (float) Math.sqrt(z);
return distance;
}
private int minDistance(float[] distance) {
//获取距离集合中最小距离的位置
float minDistance = distance[0];
int minLocation = 0;
for (int i = 1; i < distance.length; i++) {
if (distance[i] < minDistance) {
minDistance = distance[i];
minLocation = i;
} else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
{
if (random.nextInt(10) < 5) {
minLocation = i;
}
}
}
return minLocation;
}
private void clusterSet() {
//将当前元素放到最小距离中心相关的簇中
float[] distance = new float[k];
for (int i = 0; i < dataSetLength; i++) {
for (int j = 0; j < k; j++) {
distance[j] = distance(dataSet.get(i), center.get(j));
}
int minLocation = minDistance(distance);
cluster.get(minLocation).add(dataSet.get(i));
}
}
private float errorSquare(float[] element, float[] center) {
//求两点误差平方的方法
float x = element[0] - center[0];
float y = element[1] - center[1];
float errSquare = x * x + y * y;
return errSquare;
}
private void countRule() {
//计算误差平方和准则函数方法
float jcF = 0;
for (int i = 0; i < cluster.size(); i++) {
for (int j = 0; j < cluster.get(i).size(); j++) {
jcF += errorSquare(cluster.get(i).get(j), center.get(i));
}
}
jc.add(jcF);
}
private void setNewCenter() {
//设置新的簇中心方法
for (int i = 0; i < k; i++) {
int n = cluster.get(i).size();
if (n != 0) {
float[] newCenter = { 0, 0 };
for (int j = 0; j < n; j++) {
newCenter[0] += cluster.get(i).get(j)[0];
newCenter[1] += cluster.get(i).get(j)[1];
}
// 设置一个平均值
newCenter[0] = newCenter[0] / n;
newCenter[1] = newCenter[1] / n;
center.set(i, newCenter);
}
}
}
public void printDataArray(ArrayList dataArray,
String dataArrayName) {
//打印数据
for (int i = 0; i < dataArray.size(); i++) {
System.out.println("print:(" + dataArray.get(i)[0] + "," + dataArray.get(i)[1]+")");
}
System.out.println("===================================");
}
void kmeans() {
init();
// 循环分组,直到误差不变为止
while (true) {
clusterSet();
countRule();
// 误差不变了,分组完成
if (m != 0) {
if (jc.get(m) - jc.get(m - 1) == 0) {
break;
}
}
setNewCenter();
m++;
cluster.clear();
cluster = initCluster();
}
}
}
ReadData类,代码如下:
package Kmeans;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
public class ReadData {
//从文件中读取数据
public ArrayList read(String fileName){
ArrayList arr=new ArrayList();
try {
BufferedReader reader = new BufferedReader(new FileReader(fileName));
String line = null;
while((line=reader.readLine())!=null){
String str[] = line.split("\\s+");
float[][] point1 = new float[1][2];
point1[0][0]=Float.parseFloat(str[0].trim());
point1[0][1]=Float.parseFloat(str[1].trim());
arr.add(point1[0]);
}
}catch (FileNotFoundException e) {
e.printStackTrace();
}catch (IOException e) {
e.printStackTrace();
}
return arr;
}
}
main类,代码如下:
package Kmeans;
import java.util.ArrayList;
import java.util.Scanner;
public class main {
public static void main(String[] args)
{
//初始化一个Kmean对象,将k置为3
int num;
System.out.println("输入要分为的类数:");
num=(new Scanner(System.in)).nextInt();
k_means k=new k_means(num);
ArrayList dataSet=new ArrayList();
ReadData rd=new ReadData();
String fileName="data/11.txt";
dataSet=rd.read(fileName);
//设置原始数据集
k.setDataSet(dataSet);
//执行算法
k.kmeans();
//得到聚类结果
ArrayList> cluster=k.getCluster();
//查看结果
for(int i=0;i