Kmeans聚类

k-means聚类

之前看了好多关于kmeans聚类的资料,大致意思稍微了解了,但是自己写程序还是编不出来。最主要的原因是自己的编程能力太差了。本科时候没有怎么写过代码,java话说是使用率很高的编程语言,我只能说我只是能看懂简单的代码。C也就是个计算机二级、三级的水平吧。计算机四级考了两次都没过。。。废话少说扯正题。


算法流程:

1.随机选取k个点作为初始质心

2.计算所有数据到初始k个质心的距离,将数据分配到与某个质心最近的类,得到k个类

3.重新计算每个类的质心

4.重复2、3步直到质心不变为止。(或者到达迭代的最大次数)


下面举个例子更好的说明问题

Kmeans聚类_第1张图片


下面我就不手工计算了,上面是先手工写好,再公式编辑器写的,截的图。

关于k-means聚类的说明:

1.聚类中的k,要自己输入,在不知道聚多少类的情况下,只有自己多次测试,再检测聚类的效果了

2.聚类的好坏受到初始质心的影响,容易出现局部最优,可以多运行几次程序

java程序:

package cluster.kmeans;

import java.util.Random;
import java.util.ArrayList;

public class Kmeans {
	private int k;// 参数k

	private int m;// 迭代次数m

	private int dataSetLength; // 数据集长度,即数据集个数

	private ArrayList<Float[]> dataSet; // 数据集链表

	private ArrayList<Float[]> center;// 中心链表

	private ArrayList<ArrayList<Float[]>> cluster;

	private ArrayList<Float> jc;// 误差平方合

	private Random random;

	public Kmeans(int k) {
		this.k = k;
	}

	public void init() {
		m = 0;
		random = new Random();
		dataSet = initDataSet();// 初始化数据集

	}
	/**
	 * 初始化数据集方法
	 * @return
	 */

	public ArrayList<Float[]> initDataSet() {
		ArrayList<Float[]> dataSet = new ArrayList<Float[]>();
		//15
		Float[][] dataSetArray = new Float[][] {
				{8f, 2f}, {3f, 4f}, {2f, 5f}, {4f, 2f}, {7f, 3f},
				{6f, 2f}, {4f, 7f}, {6f, 3f}, {5f, 3f}, {6f, 3f},
				{6f, 9f}, {1f, 6f}, {3f, 9f}, {4f, 1f}, {8f, 6f}
		};
		dataSetLength=dataSetArray.length;
		for(int i = 0; i<dataSetArray.length;i++)
		{
			dataSet.add(dataSetArray[i]);
		}
		//System.out.println(dataSet.toArray());
		return dataSet;
	}
	/**
	 * 初始化中心方法
	 */
	public ArrayList<Float[]> initCenter()
	{
		ArrayList<Float[]> center = new ArrayList<Float[]>();//给center分配内存
		int[] randoms = new int[k];//定义一个一维数组randoms,数组length=3
		boolean flag ;
		int temp = random.nextInt(dataSetLength);//随机生成一个0到dataSetLength的int型数值,此处为0~15
		randoms[0] = temp;//把随机数传给randoms[0]
		for(int i = 1;i <k;i++)//执行2次 参数k个随机数,下面是判断使其互不相等
		{
			flag = true;
			while (flag)
			{
				temp=random.nextInt(dataSetLength);//生成一个0~15的随机数交给temp
				int j;
				for(j=0;j<i;j++)
				{
					if(temp == randoms[j])//如果此随机数等于randoms[0]
					{
						break;//跳出循环
					}
				}if(j==i)
				{
					flag=false;
				}
			}
			randoms[i] = temp;
		}
//		
//		//测试随机数组生成情况
//		for(int i = 0; i < k; i ++){
//			System.out.println("test1: randoms[" + i + "] = " + randoms[i]);
//		}
		System.out.println();
		
		for (int i = 0; i < k; i ++){
			center.add(dataSet.get(randoms[i])); //生成初始化中心链表
		}
		return center;
	}
	public ArrayList<Float> initJc()
	{
		
		ArrayList<Float> jc = new ArrayList<Float>();
		
		return jc;
	}
	/*
	 * 初始化簇集方法
	 */
	public ArrayList<ArrayList<Float[]>> initCluster(){
		ArrayList<ArrayList<Float[]>> cluster = new ArrayList<ArrayList<Float[]>>();
		for(int i = 0; i < k; i ++){
			cluster.add(new ArrayList<Float[]>());
		}
		return cluster;
	}
	
	/*
	 * 求距离方法
	 */
	public float distance(Float[] element, Float[] center){
		float distance = 0.0f;
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float z = x*x + y*y;
		distance = (float) Math.sqrt(z);
		return distance;
	}
	
	/*
	 * 求最小距离位置方法
	 */
	public int minDistance(float distance[]){
		float minDistance = distance[0];
		int minLocation = 0;
		for(int i = 1; i < distance.length; i ++){
			if(distance[i] < minDistance){
				minDistance = distance[i];
				minLocation = i;
			}else if(distance[i] == minDistance){ //如果和当前最短距离相等,则随机选取一个
				if(random.nextInt(10) < 5){
					minLocation = i;
				}
			}
		}
		return minLocation;
	}
	
	/**
	 * 生成簇集元素方法
	 */
	public void clusterSet(){
		float[] distance = new float[k];
		for(int i = 0; i < dataSetLength; i ++){
			for(int j = 0; j < k; j ++){
				distance[j] =distance(dataSet.get(i), center.get(j));
				//System.out.println("test2: " + "dataSet[" + i + "], centers[" + j + "], distance = " + distance[j]); //测试元素与中心距离
			}
			int minLocation = minDistance(distance);
			cluster.get(minLocation).add(dataSet.get(i)); //核心:将当前元素放到最小距离中心相关的簇中
		}
	}
	
	/*
	 * 求误差平方的方法
	 */
	public float errorSquare(Float[] element, Float[] center){
		float x = element[0] - center[0];
		float y = element[1] - center[1];
		float errorSquare = x*x + y*y;
		return errorSquare;
	}
	
	/**
	 * 计算误差平方和准则函数方法
	 */
	public void countRule(){
		float JcF = 0;
		for(int i = 0; i < cluster.size(); i ++){//即i<3
			for(int j = 0; j < cluster.get(i).size(); j ++){//执行的第i个簇的长度的次数
				JcF += errorSquare(cluster.get(i).get(j), center.get(i));//jcf=jcf+(聚类的点x-质心x)的平方+(聚类的点y-质心y)的平方
			}
		}
		jc.add(JcF);
	}
	
	/**
	 * 计算新的簇中心方法
	 */
	public void findNewCenter(){
		for(int i = 0; i < k; i ++){//k=3依次运算3个质心
			int n = cluster.get(i).size();//得到第i个簇的数量 (n=5),(一共3个蔟)
			if(n != 0){//如果该蔟的点的个书不为空
				Float[] newCenter = {0.0f, 0.0f};//声明一个新的质心点(默认坐标为0,0)
				for(int j = 0; j < n; j ++){//依次对n个点做循环
					newCenter[0] += cluster.get(i).get(j)[0];//新的质心的x坐标为第i个蔟的所有点的x坐标的和
					newCenter[1] += cluster.get(i).get(j)[1];//新的质心的y坐标为第i个蔟的所有点的y坐标的和
				}
				newCenter[0] = newCenter[0] / n;//求质心
				newCenter[1] = newCenter[1] / n;//求质心
				center.set(i, newCenter);//重新设第i个蔟的质心为新的newCenter
			}
		}
	}
	
	/*
	 * 打印数据数组
	 */
	public void printDataArray(ArrayList<Float[]> dataArray, String dataArrayName){
		for(int i = 0; i < dataArray.size(); i ++){
			System.out.println("print: " + dataArrayName + "[" + i + "] = {" + dataArray.get(i)[0] + ", " + dataArray.get(i)[1] + "}");
		}
		System.out.println();
	}
	
	
	/*
	 * Kmeans算法核心过程方法 
	 */
	public void kmeans(){
		init(); //初始化
		center=initCenter();
		cluster=initCluster();
		jc=initJc();
		printDataArray(dataSet, "initDataSet"); //输出初始化数据集
		printDataArray(center, "initCenter"); //输出初始化中心
		while(true){
			clusterSet(); //生成簇集元素
			for (int i = 0; i < cluster.size(); i++) {
				printDataArray(cluster.get(i), "cluster[" + i + "]"); //输出簇集生成结果
			}
			countRule(); //计算误差平方和
			System.out.println("count:" + "Jc[" + m + "] = " + jc.get(m)); //输出误差平方和
			System.out.println();
			//判断退出迭代条件
			if(m != 0){//如果m不为0而且jc的变化为0则跳出循环
				if (jc.get(m) - jc.get(m - 1) == 0) {
					break;
				}
			}
			findNewCenter(); //计算新的中心
			printDataArray(center, "newCenter"); //输出新的中心
			m++;//m+1
			cluster.clear(); //簇集清空
			cluster = initCluster(); //簇集初始化
		}
		System.out.println("note: the times of repeat: m = " + m); //输出迭代次数
	}

	/*
	 * 主函数
	 */
	public static void main(String[] args){
		long startTime = System.currentTimeMillis(); //获取开始时间
		System.out.println("note: program begins.");
		Kmeans myKmeans = new Kmeans(3);
		myKmeans.kmeans(); //调用Kmeans核心方法
		long endTime = System.currentTimeMillis(); //获取结束时间
		System.out.println("note: running time = " +(endTime - startTime) + "ms.");
		System.out.println("note: program ends.");
	}
}

程序来源于网络,自己编程能力太差,实在写不出来,信息与计算科学专业的孩子就只能用笔算算简单的数据了。。。。。。




你可能感兴趣的:(数据挖掘,机器学习,kmeans聚类)