前几天偶然在网上看到一个关于k-means算法的实现,不过写错了。当我想回复时,帖子已经关闭,所以我就在此把正确的版本贴出来,如有错误还请指教:
package com.lele; import java.util.ArrayList; import java.util.List; public class Kmeans { private static int K = 5; // 类数(簇) 此程序为2 private static int TOTAL = 20; // 点个数 此程序为20 private int test = 0; private Point[] unknown = new Point[TOTAL]; // 点数组 private int[] type = new int[TOTAL]; // 每个点暂时的类(簇) private Point[] z = new Point[K];// 保存新的聚类中心 private Point[] z0 = new Point[K]; // 保存上一次的聚类中心 private Point sum=null; private double[] Distance=new double[TOTAL]; private int temp = 0; private int I = 0; // 迭代次数 /** Creates a new instance of Kmeans */ public Kmeans() { /** 进行聚类运算的20个点 */ unknown[0] = new Point(5,5); unknown[1] = new Point(7,9); unknown[2] = new Point(0,1); unknown[3] = new Point(1,1); unknown[4] = new Point(2,1); unknown[5] = new Point(1,2); unknown[6] = new Point(2,2); unknown[7] = new Point(3,2); unknown[8] = new Point(6,6); unknown[9] = new Point(7,6); unknown[10] = new Point(8,6); unknown[11] = new Point(6,7); unknown[12] = new Point(7,7); unknown[13] = new Point(8,7); unknown[14] = new Point(9,7); unknown[15] = new Point(7,8); unknown[16] = new Point(8,8); unknown[17] = new Point(9,8); unknown[18] = new Point(8,9); unknown[19] = new Point(9,9); for(int i = 0;i < TOTAL; i++){ type[i] = 0; } for(int i = 0; i < K; i++){ z[i] = unknown[i]; // 伪随机选取 z0[i] = new Point(0.0,0.0); } } /** 计算新的聚类中心 */ public Point newCenter(int m){ int n = 0; sum=new Point(0,0); for(int i = 0;i < TOTAL; i++){ if(type[i] == m){ sum.setX(sum.getX() + unknown[i].getX()); sum.setY(sum.getY() + unknown[i].getY()); n += 1; } } sum.setX(sum.getX() / n); sum.setY(sum.getY() / n); System.out.println("第"+m+"类有"+n+"个"); return sum; } /** 比较两个聚类中心是否相等 */ public boolean isEqual(Point p1,Point p2){ System.out.println(p1.getX()+"**********"+p2.getX()); System.out.println(p1.getY()+"**********"+p2.getY()); if(Double.doubleToLongBits(p1.getX()) == Double.doubleToLongBits(p2.getX()) && Double.doubleToLongBits( p1.getY()) == Double.doubleToLongBits(p2.getY())) return true; else return false; } 
/** 计算两点之间的欧式距离 */ public static double distance(Point p1,Point p2){ return (p1.getX() - p2.getX()) * (p1.getX() - p2.getX()) + (p1.getY() - p2.getY()) * (p1.getY() - p2.getY()); } /** 进行迭代,对TOTAL个样本根据聚类中心进行分类 */ public void order(){ int temp=0; for(int i = 0; i < TOTAL;i++){ for(int j = 0; j < K;j++) if(distance(unknown[i],z[temp])>distance(unknown[i],z[j])) temp=j; type[i]=temp; System.out.println(unknown[i].toString()+"被归为"+temp); } } public void main(){ System.out.println("共有如下个未知样本:"); for(int i = 0; i < TOTAL;i++){ System.out.println(unknown[i]); //System.out.println("初始时,第" + i + "类中心:" + z[i].toString()); } for(int i = 0; i < K;i++) System.out.println("初始时,第" + i + "类中心:" + z[i].toString()); while(test < K){ System.out.println("current test:"+test); order(); for(int i = 0; i < K;i ++){ z[i] = newCenter(i); System.out.println("第" + i + " 类新中心:" + z[i].toString()); if(isEqual(z[i],z0[i])) test += 1; else z0[i] = z[i]; } I += 1; System.out.println("已完成第" + I + "次迭代"); System.out.println("分类后有:"); for(int j = 0;j < K;j++){ System.out.println("第" + j + "类分类有: "); for(int i = 0;i < TOTAL;i++){ if(type[i] == j) System.out.println(unknown[i].toString()); } } } } /** * * @param args */ public static void main(String[] args){ new Kmeans().main(); } }
package com.lele; /** * * @author zhaole609 * define a point class * */ public class Point { private double x = 0; private double y = 0; /** Creates a new instance of Point */ public Point(double x,double y) { this.setX(x); this.setY(y); } public double getX() { return x; } public void setX(double x) { this.x = x; } public double getY() { return y; } public void setY(double y) { this.y = y; } public String toString(){ return "[" + x + "," + y + "]"; } /** public static void main(String[] args) { System.out.println(new Point(3,4).toString()); }*/ }
代码基本上和原来的一样,只是中间稍微变了一下。