用java实现K-means算法,k-means聚类算法原理

1、从包含多个数据点的数据集D中随机取k个点,作为k个簇的各自的中心。

2、分别计算剩下的点到k个簇中心的相异度,将这些元素分别划归到相异度最低的簇。

   两个点之间的相异度大小采用欧氏距离公式衡量,对于两个点T0(x1,y2)和T1(x2,y2),

   T0和T1之间的欧氏距离为

       d = sqrt((x1-x2)^2+(y1-y2)^2)

  欧氏距离越小,说明相异度越小

 

3、根据聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所有点各自维度的算术平均数。

4、将D中全部点按照新的中心重新聚类。

5、重复第4步,直到聚类结果不再变化。

6、将结果输出。

 

举例说明,假设包含9个点数据D如下(见simple_k-means.txt),从D中随机取k个元素,

作为k个簇的各自的中心, 假设选k=2, 即将如下的9个点聚类成两个类(cluster)。

1  1

2  1

1  2

2  2

3  3

8  8

8  9

9  8

9  9

 

1.      假设选C0(1  1)和C1(2  1)前两个点作为两个类的簇心。

2.     分别计算剩下的点到k个簇中心的相异度,将这些元素分别划归到相

异度最低的簇。结果为:

C0 :1 1

C0: 的点为:1.0,2.0

C1:  2 1

C1:的点为:2.0, 2.0

C1:的点为:3.0, 3.0

C1:的点为:8.0, 8.0

C1:的点为:8.0, 9.0

C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

 

3.     根据2的聚类结果,重新计算k个簇各自的中心,计算方法是取簇中所

有元素各自维度的算术平均数。

C0 新的簇心为:1.0,1.5

C1新的簇心为:5.857142857142857,  5.714285714285714

4.将D中全部元素按照新的中心重新聚类。

    第2次迭代

    C0:的点为:1.0, 1.0

    C0:的点为:2.0, 1.0

    C0:的点为:1.0, 2.0

    C0:的点为:2.0, 2.0

    C0:的点为:3.0, 3.0

    C1:的点为:8.0, 8.0

    C1:的点为:8.0, 9.0

    C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

 

5.  重复第4步,直到聚类结果不再变化。

当每个簇心点前后移动的距离小于某个阈值t的时候,就认为聚类已经

结束了,不需要再迭代,这里的值选t=0.001距离计算采用欧氏距离

------------------------------------------------

C0的簇心为:1.6666666666666667,1.75

C1的簇心为:7.971428571428572, 7.942857142857143

各个簇心移动中最小的距离为,moveDistance=0.7120003121097943

第3次迭代

C0:的点为:1.0, 1.0

C0:的点为:2.0, 1.0

C0:的点为:1.0, 2.0

C0:的点为:2.0, 2.0

C0:的点为:3.0, 3.0

C1:的点为:8.0, 8.0

C1:的点为:8.0, 9.0

C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

------------------------------------------------

C0的簇心为:1.777777777777778,1.7916666666666667

C1的簇心为:8.394285714285715,8.388571428571428

各个簇心移动中最小的距离为,moveDistance  = 0.11866671868496578

第4次迭代

C0:的点为:1.0, 1.0

C0:的点为:2.0, 1.0

C0:的点为:1.0, 2.0

C0:的点为:2.0, 2.0

C0:的点为:3.0, 3.0

C1:的点为:8.0, 8.0

C1:的点为:8.0, 9.0

C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

------------------------------------------------

C0的簇心为:1.7962962962962965,1.7986111111111114

C1的簇心为:8.478857142857143,8.477714285714285

各个簇心移动中最小的距离为,moveDistance=0.019777786447494432

第5次迭代

C0:的点为:1.0, 1.0

C0:的点为:2.0, 1.0

C0:的点为:1.0, 2.0

C0:的点为:2.0, 2.0

C0:的点为:3.0, 3.0

C1:的点为:8.0, 8.0

C1:的点为:8.0, 9.0

C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

------------------------------------------------

C0的簇心为:1.799382716049383,1.7997685185185184

C1的簇心为:8.495771428571429,8.495542857142857

各个簇心移动中最小的距离为,moveDistance=0.003296297741248916

第6次迭代

C0:的点为:1.0, 1.0

C0:的点为:2.0, 1.0

C0:的点为:1.0, 2.0

C0:的点为:2.0, 2.0

C0:的点为:3.0, 3.0

C1:的点为:8.0, 8.0

C1:的点为:8.0, 9.0

C1:的点为:9.0, 8.0

C1:的点为:9.0, 9.0

------------------------------------------------

C0的簇心为:1.7998971193415638,1.7999614197530864

C1的簇心为:8.499154285714287,8.499108571428572

各个簇心移动中最小的距离为,moveDistance=5.49382956874724E-4

*************************************************************************************

K_means类,代码如下:

package Kmeans;

import java.util.ArrayList;  
import java.util.Random;  
   
public class k_means {  
    private int k;// 分成多少簇  
    private int m;// 迭代次数  
    private int dataSetLength;// 数据集元素个数,即数据集的长度  
    private ArrayList dataSet;// 数据集链表  
    private ArrayList center;// 中心链表  
    private ArrayList> cluster; // 簇  
    private ArrayList jc;// 误差平方和,k越接近dataSetLength,误差越小  
    private Random random;  
   
    public void setDataSet(ArrayList dataSet) {  
    	//设置需分组的原始数据集
        this.dataSet = dataSet;  
    }  

    public ArrayList> getCluster() {  
        return cluster;  
    }  
   
    public k_means(int k) {  
    	//传入需要分成的簇数量
        if (k <= 0) {  
            k = 1;  
        }  
        this.k = k;  
    }  
  
    private void init() { 
    	//初始化
        m = 0;  
        random = new Random();  
        if (dataSet == null || dataSet.size() == 0) {   
        	System.out.println("数据为空,请输入数据!!!!");
        } else{
        	dataSetLength = dataSet.size();  
        	if (k > dataSetLength) {  
        		k = dataSetLength;  
        	}  
        	center = initCenters();  
        	cluster = initCluster();  
        	jc = new ArrayList();  
        	}
    }  
  
    private ArrayList initCenters() {  
    	//初始化中心数据链表,分成多少簇就有多少个中心点
        ArrayList center = new ArrayList();  
        int[] randoms = new int[k];  
        boolean flag;  
        int temp = random.nextInt(dataSetLength);  
        randoms[0] = temp;  
        for (int i = 1; i < k; i++) {  
            flag = true;  
            while (flag) {  
                temp = random.nextInt(dataSetLength);  
                int j = 0;  
                while (j < i) {  
                    if (temp == randoms[j]) {  
                        break;  
                    }  
                    j++;  
                }  
                if (j == i) {  
                    flag = false;  
                }  
            }  
            randoms[i] = temp;  
        }   
        for (int i = 0; i < k; i++) {  
            center.add(dataSet.get(randoms[i]));// 生成初始化中心链表  
        }  
        return center;  
    }  
  
 
    private ArrayList> initCluster() {  
    	//初始化簇集合
        ArrayList> cluster = new ArrayList>();  
        for (int i = 0; i < k; i++) {  
            cluster.add(new ArrayList());  
        }  
  
        return cluster;  
    }  
  

    private float distance(float[] element, float[] center) {  
    	//计算两个点之间的距离
        float distance = 0.0f;  
        float x = element[0] - center[0];  
        float y = element[1] - center[1];  
        float z = x * x + y * y;  
        distance = (float) Math.sqrt(z);  
  
        return distance;  
    }  
  
  
    private int minDistance(float[] distance) {  
    	 //获取距离集合中最小距离的位置
        float minDistance = distance[0];  
        int minLocation = 0;  
        for (int i = 1; i < distance.length; i++) {  
            if (distance[i] < minDistance) {  
                minDistance = distance[i];  
                minLocation = i;  
            } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置  
            {  
                if (random.nextInt(10) < 5) {  
                    minLocation = i;  
                }  
            }  
        }  
  
        return minLocation;  
    }  
  
 
    private void clusterSet() {  
    	//将当前元素放到最小距离中心相关的簇中
        float[] distance = new float[k];  
        for (int i = 0; i < dataSetLength; i++) {  
            for (int j = 0; j < k; j++) {  
                distance[j] = distance(dataSet.get(i), center.get(j));  
            }  
            int minLocation = minDistance(distance);  
            cluster.get(minLocation).add(dataSet.get(i));  
  
        }  
    }  
  

    private float errorSquare(float[] element, float[] center) {
    	//求两点误差平方的方法 
        float x = element[0] - center[0];  
        float y = element[1] - center[1];  
  
        float errSquare = x * x + y * y;  
  
        return errSquare;  
    }  
  
  
    private void countRule() {  
    	//计算误差平方和准则函数方法
        float jcF = 0;  
        for (int i = 0; i < cluster.size(); i++) {  
            for (int j = 0; j < cluster.get(i).size(); j++) {  
                jcF += errorSquare(cluster.get(i).get(j), center.get(i));  
  
            }  
        }  
        jc.add(jcF);  
    }  
    private void setNewCenter() { 
    	//设置新的簇中心方法
        for (int i = 0; i < k; i++) {  
            int n = cluster.get(i).size();  
            if (n != 0) {  
                float[] newCenter = { 0, 0 };  
                for (int j = 0; j < n; j++) {  
                    newCenter[0] += cluster.get(i).get(j)[0];  
                    newCenter[1] += cluster.get(i).get(j)[1];  
                }  
                // 设置一个平均值  
                newCenter[0] = newCenter[0] / n;  
                newCenter[1] = newCenter[1] / n;  
                center.set(i, newCenter);  
            }  
        }  
    }  
    
    public void printDataArray(ArrayList dataArray,  
            String dataArrayName) { 
    	//打印数据
        for (int i = 0; i < dataArray.size(); i++) {  
            System.out.println("print:(" + dataArray.get(i)[0] + "," + dataArray.get(i)[1]+")");  
        }  
        System.out.println("===================================");  
    }  
   
    void kmeans() {  
        init();  
        // 循环分组,直到误差不变为止  
        while (true) {  
            clusterSet();  
            countRule();   
            // 误差不变了,分组完成  
            if (m != 0) {  
                if (jc.get(m) - jc.get(m - 1) == 0) {  
                    break;  
                }  
            }  
  
            setNewCenter();   
            m++;  
            cluster.clear();  
            cluster = initCluster();  
        }  
    }  
}  
ReadData类,代码如下:

package Kmeans;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

public class ReadData {
	//从文件中读取数据
public ArrayList read(String fileName){
	ArrayList arr=new ArrayList();
	try {
		BufferedReader reader = new BufferedReader(new FileReader(fileName));
		String line = null;
        while((line=reader.readLine())!=null){
        	String str[] = line.split("\\s+");
        	float[][] point1 = new float[1][2];
        	point1[0][0]=Float.parseFloat(str[0].trim());
        	point1[0][1]=Float.parseFloat(str[1].trim());
        	arr.add(point1[0]);
        }
	}catch (FileNotFoundException e) {
		e.printStackTrace();
	}catch (IOException e) {
		e.printStackTrace();
	}
    	
	return arr;
    	
    }
}
main类,代码如下:

package Kmeans;

import java.util.ArrayList;
import java.util.Scanner;

public class main {
	public  static void main(String[] args)  
    {  
        //初始化一个Kmean对象,将k置为3
		int num;
		System.out.println("输入要分为的类数:");			
		num=(new Scanner(System.in)).nextInt();
        k_means k=new k_means(num);  
        ArrayList dataSet=new ArrayList();  
        ReadData rd=new ReadData();
        String fileName="data/11.txt";
		dataSet=rd.read(fileName);
        //设置原始数据集  
        k.setDataSet(dataSet);  
        //执行算法  
        k.kmeans();
        //得到聚类结果  
        ArrayList> cluster=k.getCluster();  
        //查看结果  
        for(int i=0;i

以上代码可能存在一些冗余的部分,没有进行修改,如果需要可以自行拷贝和修改。

你可能感兴趣的:(java)