k-means算法的java实现

      k-means是一种最常用的聚类算法。关于k-means算法的介绍到处都能找到,并且比较容易理解。mahout里面也实现了k-means算法。下面贴出的是自己写的实现。目的是帮助大家能更清楚的认识和更快的掌握k-means算法。

   

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * 
 * @author aturbo
 * 1、随机选择k个点作为中心点(centroid)
 * 2、计算点到各个类中心的距离;
 * 3、将点放入最近的中心点所在的类
 * 4、重新计算中心点
 * 5、判断目标函数是否收敛,收敛停止,否则循环2-4步
 * 
 */
public class MyKmeans {

	public static final double[][] points = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 8, 8 }, { 9, 8 },
			{ 8, 9 }, { 9, 9 } };
	
	/**
	 * 随机选择k个点作为中心点
	 * @param k 
	 * @return k个中心点
	 */
	private static double[][] chooseinitK(int k){
		double[][] cluster = new double[k][];
		Set set = new HashSet<Integer>();
		for(int i = 0;i<points.length;i++){
			set.add(i);
		}
		//在set中剩下的序列点就为随机选择的
		for(int i = 0;i<(points.length-k);){
			Random random = new Random();
			int a = random.nextInt(points.length);
			if(!set.contains(a))
				continue;
			   set.remove(a);
			System.out.println("a"+a);
			i++;
		}
		Iterator<Integer> iterator = set.iterator();
		int j =0;
		while(iterator.hasNext()){
			cluster[j]=points[iterator.next()];
			j++;
		}
		for(int i = 0;i<cluster.length;i++){
			System.out.println("随机选择的k个节点:"+cluster[i][0]+"\t"+cluster[i][1]);
		}
		return cluster;
	}
	/**
	 * 欧式距离计算公式
	 * @param center (中心)点
	 * @param otherpoint 
	 * @return 欧式距离
	 */
	private static double eurDistance(double[] center,double[] otherpoint){
		double distance=0.0;
		for(int i = 0;i<center.length;i++){
			distance += ((center[i]-otherpoint[i])*(center[i]-otherpoint[i]));
		}
		distance = Math.sqrt(distance);
		return distance;
	}
	/**
	 * 目标函数——也就是每个聚类中的点到它中心点的距离和
	 * @param center 中心点
	 * @param cluster 划分的(中间)聚类
	 * @return cost
	 */
	private static double cost(double[][] center,List<double[]>[] cluster){
		double cost = 0.0;
		for(int i = 0;i<cluster.length;i++){
			for(int j = 0;j<cluster[i].size();j++){
				double tempCost = 0.0;
				for(int k = 0;k<center.length;k++){
					System.out.println(cluster[i].get(j)[k]);
					tempCost += (cluster[i].get(j)[k]-center[i][k])*(cluster[i].get(j)[k]-center[i][k]);
				}
				cost+=Math.sqrt(tempCost);
			}
		}
		return cost;
	}
	/**
	 * 聚类算法——将所有点和各中心点计算距离,将点放入最近距离点的类中
	 * @param points 所有点
	 * @param centers 中心点
	 * @param k
	 * @return 聚类
	 */
	private static List<double[]>[] returnCluster(double[][] points,double[][] centers,int k){
		List[] cluster = new ArrayList[k];
		for(int i = 0;i<cluster.length;i++){
			cluster[i] = new ArrayList<Double[]>();
		}
		for(double[] point:points){
			double min_distance = Double.MAX_VALUE;
			int clusterNum = 0;
			int flag =  0;
			double distance =0.0;
			for(double[] center:centers){
				distance = eurDistance(center, point);
				if(distance<min_distance){
					flag = clusterNum;
					min_distance = distance;
				}		
				clusterNum++;
			}
			cluster[flag].add(point);
		}
		return cluster;
	}
	/**
	 * 计算类的中心点的坐标
	 * @param cluster 聚类
	 * @return 
	 */
	private static double[][] countCenter(List<double[]>[] cluster){
		double x = 0.0;
		double y = 0.0;
		int k = cluster.length;
		double[][] initk =new  double[k][2];
		for(int i = 0;i<cluster.length;i++){
		   for(int j = 0;j<cluster[i].size();j++){
			   x += cluster[i].get(j)[0];
			   y += cluster[i].get(j)[1];
		   }
		   x = x/cluster[i].size();
		   y = y/cluster[i].size();
		   initk[i][0]=x;
		   initk[i][1]=y;
		}
		return initk;
	}
	
	public static void main(String[] args){
		int k = 2;
		double[][] initk = chooseinitK(k);
		
		double minCost = Double.MAX_VALUE;
		double tempCost = Double.MAX_VALUE;
		List[] cluster;
		do{
			minCost = tempCost;
			cluster = returnCluster(points, initk, k);
			initk = countCenter(cluster);
			tempCost = cost(initk, cluster);
		}while(tempCost<minCost);//当目标函数收敛后,停止
		
		
	}
}

你可能感兴趣的:(java,数据挖掘,聚类,k-means)