更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
K-Means又名为K均值算法,他是一个聚类算法,这里的K就是聚簇中心的个数,代表数据中存在多少数据簇。K-Means在聚类算法中算是非常简单的一个算法了。有点类似于KNN算法,都用到了距离矢量度量,用欧式距离作为小分类的标准。
(1)、设定数字k,从n个初始数据中随机的设置k个点为聚类中心点。
(2)、针对n个点的每个数据点,遍历计算到k个聚类中心点的距离,最后按照离哪个中心点最近,就划分到那个类别中。
(3)、对每个已经划分好类别的n个点,对同个类别的点求均值,作为此类别新的中心点。
(4)、循环(2),(3)直到最终中心点收敛。
以上的计算过程将会在下面我的程序实现中有所体现。
输入数据:
3 3 4 10 9 6 14 8 18 11 21 7主实现类:
package DataMining_KMeans; import java.io.BufferedReader; import java.io.File; import java.io.FileReader; import java.io.IOException; import java.text.MessageFormat; import java.util.ArrayList; import java.util.Collections; /** * k均值算法工具类 * * @author lyq * */ public class KMeansTool { // 输入数据文件地址 private String filePath; // 分类类别个数 private int classNum; // 类名称 private ArrayList<String> classNames; // 聚类坐标点 private ArrayList<Point> classPoints; // 所有的数据左边点 private ArrayList<Point> totalPoints; public KMeansTool(String filePath, int classNum) { this.filePath = filePath; this.classNum = classNum; readDataFile(); } /** * 从文件中读取数据 */ private void readDataFile() { File file = new File(filePath); ArrayList<String[]> dataArray = new ArrayList<String[]>(); try { BufferedReader in = new BufferedReader(new FileReader(file)); String str; String[] tempArray; while ((str = in.readLine()) != null) { tempArray = str.split(" "); dataArray.add(tempArray); } in.close(); } catch (IOException e) { e.getStackTrace(); } classPoints = new ArrayList<>(); totalPoints = new ArrayList<>(); classNames = new ArrayList<>(); for (int i = 0, j = 1; i < dataArray.size(); i++) { if (j <= classNum) { classPoints.add(new Point(dataArray.get(i)[0], dataArray.get(i)[1], j + "")); classNames.add(i + ""); j++; } totalPoints .add(new Point(dataArray.get(i)[0], dataArray.get(i)[1])); } } /** * K均值聚类算法实现 */ public void kMeansClustering() { double tempX = 0; double tempY = 0; int count = 0; double error = Integer.MAX_VALUE; Point temp; while (error > 0.01 * classNum) { for (Point p1 : totalPoints) { // 将所有的测试坐标点就近分类 for (Point p2 : classPoints) { p2.computerDistance(p1); } Collections.sort(classPoints); // 取出p1离类坐标点最近的那个点 p1.setClassName(classPoints.get(0).getClassName()); } error = 0; // 按照均值重新划分聚类中心点 for (Point p1 : classPoints) { count = 0; tempX = 0; tempY = 0; for (Point p : totalPoints) { if (p.getClassName().equals(p1.getClassName())) { count++; tempX += p.getX(); tempY += p.getY(); } } tempX /= count; tempY /= count; error += Math.abs((tempX - p1.getX())); error += Math.abs((tempY - p1.getY())); // 计算均值 p1.setX(tempX); p1.setY(tempY); } for (int i = 0; i < classPoints.size(); i++) { temp = classPoints.get(i); System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", (i + 1), temp.getX(), temp.getY())); } System.out.println("----------"); } System.out.println("结果值收敛"); for (int i = 0; i < classPoints.size(); i++) { temp = classPoints.get(i); System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}", (i + 1), temp.getX(), temp.getY())); } } }坐标点类:
package DataMining_KMeans; /** * 坐标点类 * * @author lyq * */ public class Point implements Comparable<Point>{ // 坐标点横坐标 private double x; // 坐标点纵坐标 private double y; //以此点作为聚类中心的类的类名称 private String className; // 坐标点之间的欧式距离 private Double distance; public Point(double x, double y) { this.x = x; this.y = y; } public Point(String x, String y) { this.x = Double.parseDouble(x); this.y = Double.parseDouble(y); } public Point(String x, String y, String className) { this.x = Double.parseDouble(x); this.y = Double.parseDouble(y); this.className = className; } /** * 距离目标点p的欧几里得距离 * * @param p */ public void computerDistance(Point p) { if (p == null) { return; } this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y) * (this.y - p.y); } public double getX() { return x; } public void setX(double x) { this.x = x; } public double getY() { return y; } public void setY(double y) { this.y = y; } public String getClassName() { return className; } public void setClassName(String className) { this.className = className; } public double getDistance() { return distance; } public void setDistance(double distance) { this.distance = distance; } @Override public int compareTo(Point o) { // TODO Auto-generated method stub return this.distance.compareTo(o.distance); } }调用类:
/** * K-means(K均值)算法调用类 * @author lyq * */ public class Client { public static void main(String[] args){ String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt"; //聚类中心数量设定 int classNum = 3; KMeansTool tool = new KMeansTool(filePath, classNum); tool.kMeansClustering(); } }
测试输出结果:
聚类中心点1,x=15.5,y=8 聚类中心点2,x=4,y=10 聚类中心点3,x=3,y=3 ---------- 聚类中心点1,x=17.667,y=8.667 聚类中心点2,x=6.5,y=8 聚类中心点3,x=3,y=3 ---------- 聚类中心点1,x=17.667,y=8.667 聚类中心点2,x=6.5,y=8 聚类中心点3,x=3,y=3 ---------- 结果值收敛 聚类中心点1,x=17.667,y=8.667 聚类中心点2,x=6.5,y=8 聚类中心点3,x=3,y=3
1、首先优点当然是算法简单,快速,易懂,没有涉及到特别复杂的数据结构。
2、缺点1是最开始K的数量值以及K个聚类中心点的设置不好定,往往开始时不同的k个中心点的设置对后面迭代计算的走势会有比较大的影响,这时候可以考虑根据类的自动合并和分裂来确定这个k。
3、缺点2由于计算是迭代式的,而且计算距离的时候需要完全遍历一遍中心点,当数据规模比较大的时候,开销就显得比较大了。