k-means 是最常用的聚类算法之一。关于 k-means 算法的介绍随处可见,也比较容易理解;Mahout 中也实现了 k-means 算法。下面贴出的是我自己写的实现,目的是帮助大家更清楚地认识、更快地掌握 k-means 算法。
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * Minimal k-means clustering demo over a fixed set of sample points.
 *
 * <ol>
 *   <li>Randomly pick k points as the initial centroids.</li>
 *   <li>Compute the distance from every point to every centroid.</li>
 *   <li>Put each point into the cluster of its nearest centroid.</li>
 *   <li>Recompute the centroid of every cluster.</li>
 *   <li>Stop once the objective function no longer decreases; otherwise repeat steps 2-4.</li>
 * </ol>
 *
 * @author aturbo
 */
public class MyKmeans {

    /** Sample data set: nine 2-D points. */
    public static final double[][] points = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 },
            { 8, 8 }, { 9, 8 }, { 8, 9 }, { 9, 9 } };

    /**
     * Randomly selects k distinct entries of {@link #points} as initial centroids.
     *
     * @param k number of centroids to pick; must satisfy {@code 1 <= k <= points.length}
     * @return k centroids, each a copy of a distinct sample point
     * @throws IllegalArgumentException if k is outside the valid range
     */
    private static double[][] chooseinitK(int k) {
        if (k < 1 || k > points.length) {
            throw new IllegalArgumentException(
                    "k must be between 1 and " + points.length + ", got " + k);
        }

        int[] indices = new int[points.length];
        for (int i = 0; i < indices.length; i++) {
            indices[i] = i;
        }

        // Partial Fisher-Yates shuffle: after step i, indices[0..i] hold i + 1 distinct
        // uniformly chosen point indices, so the loop always finishes in exactly k steps.
        Random random = new Random();
        double[][] centers = new double[k][];
        for (int i = 0; i < k; i++) {
            int pick = i + random.nextInt(indices.length - i);
            int chosen = indices[pick];
            indices[pick] = indices[i];
            indices[i] = chosen;
            // Copy so that the centroids never alias the rows of the shared sample data.
            centers[i] = points[chosen].clone();
        }

        for (double[] center : centers) {
            System.out.println("随机选择的k个节点:" + center[0] + "\t" + center[1]);
        }
        return centers;
    }

    /**
     * Computes the Euclidean distance between two points.
     *
     * @param center     the (centroid) point; its length determines how many coordinates are compared
     * @param otherpoint the other point; must have at least {@code center.length} coordinates
     * @return the Euclidean distance
     */
    private static double eurDistance(double[] center, double[] otherpoint) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < center.length; i++) {
            double diff = center[i] - otherpoint[i];
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Objective function: the sum, over all clusters, of the Euclidean distance from each
     * member point to the centroid of its cluster.
     *
     * @param center  centroids; {@code center[i]} belongs to {@code cluster[i]}
     * @param cluster the current partition of the points
     * @return the total cost
     */
    private static double cost(double[][] center, List<double[]>[] cluster) {
        double total = 0.0;
        for (int i = 0; i < cluster.length; i++) {
            for (double[] point : cluster[i]) {
                // Iterate over the coordinates of the point, not over the number of
                // centroids, so the result is correct for any k and any dimension.
                total += eurDistance(center[i], point);
            }
        }
        return total;
    }

    /**
     * Assignment step: puts every point into the cluster of its nearest centroid.
     *
     * @param points  all points
     * @param centers current centroids
     * @param k       number of clusters to create; expected to equal {@code centers.length}
     * @return an array of k lists; list i holds the points nearest to {@code centers[i]}
     */
    private static List<double[]>[] returnCluster(double[][] points, double[][] centers, int k) {
        // Generic arrays cannot be created directly; the array only ever receives
        // ArrayList<double[]> instances below, so the unchecked conversion is safe.
        @SuppressWarnings("unchecked")
        List<double[]>[] cluster = new List[k];
        for (int i = 0; i < cluster.length; i++) {
            cluster[i] = new ArrayList<double[]>();
        }

        for (double[] point : points) {
            int nearest = 0;
            double minDistance = Double.MAX_VALUE;
            for (int c = 0; c < centers.length; c++) {
                double distance = eurDistance(centers[c], point);
                // Strict comparison: on ties the centroid with the lowest index wins.
                if (distance < minDistance) {
                    nearest = c;
                    minDistance = distance;
                }
            }
            cluster[nearest].add(point);
        }
        return cluster;
    }

    /**
     * Update step: computes the centroid (coordinate-wise mean) of every cluster.
     *
     * @param cluster the current partition of the points
     * @return one centroid per cluster; the centroid of an empty cluster is filled with
     *         {@code NaN} because its mean is undefined
     */
    private static double[][] countCenter(List<double[]>[] cluster) {
        // Take the dimension from the first available point; fall back to 2 (the
        // dimension of the sample data) when every cluster is empty.
        int dimension = 2;
        for (List<double[]> members : cluster) {
            if (!members.isEmpty()) {
                dimension = members.get(0).length;
                break;
            }
        }

        double[][] centers = new double[cluster.length][dimension];
        for (int i = 0; i < cluster.length; i++) {
            List<double[]> members = cluster[i];
            // Each row of centers starts at zero, so sums never leak between clusters.
            for (double[] point : members) {
                for (int d = 0; d < dimension; d++) {
                    centers[i][d] += point[d];
                }
            }
            for (int d = 0; d < dimension; d++) {
                centers[i][d] = members.isEmpty() ? Double.NaN : centers[i][d] / members.size();
            }
        }
        return centers;
    }

    /**
     * Runs k-means with k = 2 on {@link #points} and prints the resulting clusters.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int k = 2;
        double[][] centers = chooseinitK(k);
        double previousCost;
        double currentCost = Double.MAX_VALUE;
        List<double[]>[] cluster;

        do {
            previousCost = currentCost;
            cluster = returnCluster(points, centers, k);
            double[][] updated = countCenter(cluster);
            for (int i = 0; i < updated.length; i++) {
                // An empty cluster has no mean; keep its previous centroid instead.
                if (cluster[i].isEmpty()) {
                    updated[i] = centers[i];
                }
            }
            centers = updated;
            currentCost = cost(centers, cluster);
        } while (currentCost < previousCost); // stop once the objective no longer decreases

        for (int i = 0; i < cluster.length; i++) {
            StringBuilder members = new StringBuilder();
            for (double[] point : cluster[i]) {
                members.append(' ').append(java.util.Arrays.toString(point));
            }
            System.out.println("cluster " + i + ": center=" + java.util.Arrays.toString(centers[i])
                    + " members=" + members);
        }
    }
}