下图为 k-means 算法指定 2 个聚类中心、迭代 10 次得到的结果。
package com.yc;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
/**
 * Minimal k-means clustering over points read from "points.txt".
 *
 * <p>Each non-blank line of the input file is one point; its coordinates are separated by
 * commas, semicolons, '|' characters or whitespace. {@link #main} picks {@code k} distinct
 * input points at random as the initial centers, runs at most {@code n} refinement
 * iterations and writes every point followed by its 1-based cluster number to "result.txt".
 *
 * <p>All state lives in static fields, so only one clustering run may be in progress at a
 * time and the class is not safe for concurrent use.
 */
public class KMeans {
    private static final String fileName = "points.txt";
    private static final String resultFileName = "result.txt";
    /** Number of clusters. */
    private static final int k = 2;
    /** Maximum number of refinement iterations performed by {@code cluster}. */
    private static final int n = 10;
    /** Current cluster centers; {@code center[i]} is the center of cluster {@code i}. */
    private static Location[] center = new Location[k];
    /** Current assignment; {@code microCluster.get(i)} holds the points of cluster {@code i}. */
    private static List<List<Point>> microCluster = new CopyOnWriteArrayList<>();

    /** A cluster center, identified purely by its coordinates. */
    static class Location {
        private List<Double> location; // coordinates of the center

        public Location(List<Double> location) {
            this.location = location;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + ((location == null) ? 0 : location.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Location other = (Location) obj;
            if (location == null) {
                return other.location == null;
            }
            return location.equals(other.location);
        }

        @Override
        public String toString() {
            return "Location [location=" + location + "]";
        }
    }

    /** An input point: its sequence number in the file, its coordinates and a status flag. */
    static class Point {
        private int id; // sequence number of the point (0-based, in file order)
        private List<Double> location; // coordinates of the point
        private int flag; // 0 = not processed, 1 = processed (not read by the clustering code)

        Point(int id, List<Double> location, int flag) {
            this.id = id;
            this.location = location;
            this.flag = flag;
        }

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        public List<Double> getLocation() {
            return location;
        }

        public void setLocation(List<Double> location) {
            this.location = location;
        }

        public int getFlag() {
            return flag;
        }

        public void setFlag(int flag) {
            this.flag = flag;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + flag;
            result = prime * result + id;
            result = prime * result + ((location == null) ? 0 : location.hashCode());
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Point other = (Point) obj;
            if (flag != other.flag || id != other.id) {
                return false;
            }
            if (location == null) {
                return other.location == null;
            }
            return location.equals(other.location);
        }

        @Override
        public String toString() {
            return "Point [id=" + id + ", location=" + location + ", flag=" + flag + "]";
        }
    }

    /**
     * Reads the input points from "points.txt".
     *
     * <p>Blank lines are skipped. Coordinates may be separated by any run of commas,
     * semicolons, '|' characters or whitespace. Points are numbered from 0 in file order
     * (skipped lines do not consume an id) and start with flag 0.
     *
     * @return the points in file order
     * @throws Exception if the file cannot be read or a coordinate is not a number
     */
    public static List<Point> readFile() throws Exception {
        List<Point> points = new ArrayList<>();
        // try-with-resources: the reader is closed even when parsing throws.
        try (BufferedReader file = new BufferedReader(new FileReader(fileName))) {
            int id = 0;
            String line;
            while ((line = file.readLine()) != null) {
                List<Double> location = new ArrayList<>();
                // '+' collapses runs of separators (e.g. several spaces) into one split.
                for (String token : line.trim().split("[;,|\\s]+")) {
                    if (!token.isEmpty()) { // a leading separator yields one empty token
                        location.add(Double.parseDouble(token));
                    }
                }
                if (!location.isEmpty()) { // blank line, e.g. trailing newline at EOF
                    points.add(new Point(id++, location, 0));
                }
            }
        }
        return points;
    }

    /**
     * Writes the current clustering to "result.txt": one line per point containing its
     * coordinates, each followed by a comma, and then its 1-based cluster number.
     *
     * @throws Exception if the file cannot be written
     */
    public static void saveResult() throws Exception {
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(resultFileName))) {
            int clusterNo = 1;
            for (List<Point> mic : microCluster) {
                for (Point p : mic) {
                    StringBuilder sb = new StringBuilder();
                    for (Double coordinate : p.location) {
                        sb.append(coordinate).append(',');
                    }
                    bw.write(sb.append(clusterNo).toString());
                    bw.newLine();
                }
                clusterNo++;
            }
        }
    }

    /**
     * Adds random integers drawn uniformly from the half-open range [min, max) to
     * {@code set} until it holds at least {@code n} elements.
     *
     * <p>Does nothing when {@code max < min} or when the range cannot supply {@code n}
     * distinct values ({@code n > max - min}).
     *
     * @param min inclusive lower bound
     * @param max exclusive upper bound
     * @param n   required number of elements in {@code set}
     * @param set receives the drawn values; duplicates are absorbed by the set
     */
    public static void randomSet(int min, int max, int n, Set<Integer> set) {
        // long arithmetic so extreme bounds cannot overflow the range computation.
        long range = (long) max - min;
        // max itself is never drawn, so at most (max - min) distinct values exist; a looser
        // bound would make the loop below spin forever.
        if (max < min || n > range) {
            return;
        }
        // Re-check the set's real size after every draw, so duplicate draws simply cause
        // another attempt instead of leaving the set short.
        while (set.size() < n) {
            set.add((int) (min + (long) (Math.random() * range)));
        }
    }

    /**
     * Computes the centroid (per-dimension arithmetic mean) of {@code points}.
     *
     * @param points the members of one cluster
     * @return a new center with as many coordinates as each point has
     * @throws IllegalArgumentException if {@code points} is empty or the points do not all
     *     have the same number of coordinates
     */
    public static Location getNewCenter(List<Point> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("cannot compute the center of an empty cluster");
        }
        int dims = points.get(0).location.size();
        double[] sum = new double[dims];
        for (Point p : points) {
            if (p.location.size() != dims) {
                throw new IllegalArgumentException("point " + p.id + " has "
                        + p.location.size() + " coordinates, expected " + dims);
            }
            for (int i = 0; i < dims; i++) {
                sum[i] += p.location.get(i);
            }
        }
        List<Double> loc = new ArrayList<>(dims);
        for (double s : sum) {
            loc.add(s / points.size());
        }
        return new Location(loc);
    }

    /**
     * Returns the Euclidean distance between a point and a center.
     *
     * @throws IllegalArgumentException if the two have a different number of coordinates
     */
    public static double getDistance(Point point1, Location loc) {
        int wide = point1.location.size(); // number of dimensions
        if (loc.location.size() != wide) {
            throw new IllegalArgumentException("dimension mismatch: point " + point1.id
                    + " has " + wide + " coordinates, center has " + loc.location.size());
        }
        double sum = 0;
        for (int i = 0; i < wide; i++) {
            double diff = point1.location.get(i) - loc.location.get(i);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /**
     * Assigns every point to its nearest current center and stores the resulting
     * {@code k} clusters (some possibly empty) in {@code microCluster}.
     */
    public static void micCluster(List<Point> points) {
        List<List<Point>> clusters = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            clusters.add(new ArrayList<Point>());
        }
        for (Point p : points) {
            List<Double> dis = new ArrayList<>(k);
            for (int i = 0; i < k; i++) {
                dis.add(getDistance(p, center[i]));
            }
            clusters.get(getMin(dis)).add(p);
        }
        microCluster = clusters;
    }

    /**
     * Runs k-means from the current centers.
     *
     * <p>It does an initial assignment and then at most {@code n} rounds of
     * "move each center to its cluster's centroid, then reassign all points". It stops early
     * once no center moves, because a further reassignment would reproduce the same clusters.
     */
    public static void cluster(List<Point> points) {
        micCluster(points);
        for (int iteration = 0; iteration < n; iteration++) {
            boolean moved = false;
            for (int j = 0; j < k; j++) {
                List<Point> members = microCluster.get(j);
                // An empty cluster has no centroid; keep its previous center so it can
                // still attract points in the next assignment step.
                if (members.isEmpty()) {
                    continue;
                }
                Location updated = getNewCenter(members); // new centroid
                if (!updated.equals(center[j])) {
                    center[j] = updated;
                    moved = true;
                }
            }
            if (!moved) {
                break; // converged
            }
            micCluster(points);
        }
    }

    /**
     * Returns the index of the smallest value in {@code dis}; ties go to the lowest index.
     *
     * @throws IllegalArgumentException if {@code dis} is empty
     */
    public static int getMin(List<Double> dis) {
        if (dis.isEmpty()) {
            throw new IllegalArgumentException("no distances to compare");
        }
        int loc = 0;
        for (int i = 1; i < dis.size(); i++) {
            if (dis.get(i) < dis.get(loc)) {
                loc = i;
            }
        }
        return loc;
    }

    /**
     * Reads the points and seeds {@code k} distinct random points as initial centers. It then
     * clusters, writes the result file and prints the elapsed time in nanoseconds.
     */
    public static void main(String[] args) throws Exception {
        long start = System.nanoTime();
        List<Point> points = readFile();
        if (points.size() < k) {
            throw new IllegalStateException("need at least " + k + " points to form " + k
                    + " clusters, got " + points.size());
        }
        Set<Integer> centers = new HashSet<>();
        randomSet(0, points.size(), k, centers);
        List<Integer> centerIds = new ArrayList<>(centers);
        for (int i = 0; i < k; i++) {
            // Copy the coordinates so a center never aliases a point's own list.
            center[i] = new Location(new ArrayList<>(points.get(centerIds.get(i)).getLocation()));
        }
        cluster(points);
        saveResult();
        System.out.println("use time:" + (System.nanoTime() - start));
    }
}