The K-medoids algorithm chooses cluster centers differently from k-means, which simply takes the mean of the cluster's points. In K-medoids, the center selected after each iteration is always one of the actual sample points in the cluster, chosen so that promoting it to the new center improves the clustering quality, i.e., makes the cluster more compact. The algorithm measures how compact a cluster is with the absolute-error criterion.
If making some sample point the center would yield a smaller absolute error than the current center does, K-medoids regards that point as a valid replacement; when a cluster's center is recomputed in an iteration, the sample point with the smallest absolute error becomes the new center. This largely removes the sensitivity to outliers and noisy data, but the time complexity rises to O(k(m-k)^2) per iteration (for m samples and k clusters), which is clearly more expensive than k-means, so the algorithm is generally suitable only for small data sets.
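As a standalone illustration of the absolute-error criterion (a toy example, not repository code; the names MedoidDemo, absoluteError and updateMedoid are invented for it), the sketch below keeps the current medoid of a single cluster of 2-D points unless some other member of the cluster yields a strictly smaller absolute error:
import java.util.List;

public class MedoidDemo {

    // Manhattan distance between two 2-D points represented as double[]{x, y}.
    static double manhattan(double[] a, double[] b) {
        return Math.abs(a[0] - b[0]) + Math.abs(a[1] - b[1]);
    }

    // Absolute error of the cluster if 'medoid' were chosen as its center:
    // the sum of distances from every member point to that medoid.
    static double absoluteError(List<double[]> cluster, double[] medoid) {
        double error = 0.0;
        for (double[] p : cluster) {
            error += manhattan(p, medoid);
        }
        return error;
    }

    // Keep the current medoid unless some other member of the cluster
    // produces a strictly smaller absolute error.
    static double[] updateMedoid(List<double[]> cluster, double[] current) {
        double best = absoluteError(cluster, current);
        double[] medoid = current;
        for (double[] candidate : cluster) {
            double e = absoluteError(cluster, candidate);
            if (e < best) {
                best = e;
                medoid = candidate;
            }
        }
        return medoid;
    }

    public static void main(String[] args) {
        List<double[]> cluster = List.of(
                new double[] {1, 1}, new double[] {2, 1},
                new double[] {1, 2}, new double[] {10, 10}); // (10, 10) is an outlier
        double[] medoid = updateMedoid(cluster, cluster.get(3));
        // Prints 1.0, 1.0: the outlier never becomes the medoid, whereas a mean
        // center would be pulled toward it.
        System.out.println(medoid[0] + ", " + medoid[1]);
    }
}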
Bisecting k-means is a direct extension of basic k-means. It rests on a simple idea: to obtain K clusters, split the set of all points into two clusters, pick one of the resulting clusters and split it again, and repeat until K clusters have been produced.
The main steps of bisecting k-means are: start with all points in a single cluster and split it in two; then, at each step, split the cluster whose division lowers the clustering cost function (the sum of squared errors, SSE) the most; continue until the number of clusters equals the user-specified k.
The principle behind this is that the SSE measures clustering quality: the smaller it is, the closer the data points are to their centroids and the better the clustering. We therefore split the cluster with the largest SSE next, because a large SSE indicates a poorly clustered group, most likely several natural clusters being treated as one, so that cluster should be divided first.
The pseudocode for bisecting k-means is as follows:
Treat all data points as a single cluster
While the number of clusters is less than k:
    For each cluster:
        Compute the total error
        Run k-means (k = 2) on that cluster
        Compute the total error after splitting it in two
    Perform the split that yields the smallest total error
Below is a simple Java implementation; for simplicity the points are two-dimensional. KMeansCluster implements plain k-means and KMediodsCluster implements k-medoids; a sketch of the bisecting variant follows after the PointCluster class.
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.StringTokenizer;

import org.apache.commons.io.IOUtils;

// Point and AbstractCluster (which provides the distance helpers and
// printClusters) are defined elsewhere in the repository.
public class KMeansCluster extends AbstractCluster {

    public static final double THRESHOLD = 1.0;

    // Load the sample points from kmeans1.txt: one point per line,
    // x and y separated by whitespace; reading stops at the first empty line.
    public List<Point> initData() {
        List<Point> points = new ArrayList<>();
        InputStream in = null;
        BufferedReader br = null;
        try {
            in = KMeansCluster.class.getClassLoader().getResourceAsStream("kmeans1.txt");
            br = new BufferedReader(new InputStreamReader(in));
            String line = br.readLine();
            while (null != line && !"".equals(line)) {
                StringTokenizer tokenizer = new StringTokenizer(line);
                double x = Double.parseDouble(tokenizer.nextToken());
                double y = Double.parseDouble(tokenizer.nextToken());
                points.add(new Point(x, y));
                line = br.readLine();
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            IOUtils.closeQuietly(in);
            IOUtils.closeQuietly(br);
        }
        return points;
    }

    // Randomly pick k points as the initial centers and build the initial clusters.
    // (Note: the same point may be picked more than once; a production
    // implementation should sample without replacement.)
    public List<PointCluster> genInitCluster(List<Point> points, int k) {
        List<PointCluster> clusters = new ArrayList<>();
        Random random = new Random();
        for (int i = 0, len = points.size(); i < k; i++) {
            PointCluster cluster = new PointCluster();
            Point center = points.get(random.nextInt(len));
            cluster.setCenter(center);
            cluster.getPoints().add(center);
            clusters.add(cluster);
        }
        return clusters;
    }

    // Assign every point to the nearest cluster, then recompute the centers.
    public void handleCluster(List<Point> points, List<PointCluster> clusters) {
        for (Point point : points) {
            PointCluster minCluster = null;
            double minDistance = Double.MAX_VALUE;
            for (PointCluster cluster : clusters) {
                Point center = cluster.getCenter();
                double distance = euclideanDistance(point, center);
                // double distance = manhattanDistance(point, center);
                if (distance < minDistance) {
                    minDistance = distance;
                    minCluster = cluster;
                }
            }
            if (null != minCluster) {
                minCluster.getPoints().add(point);
            }
        }
        // Termination condition: the distance between each old center and its new
        // center falls below a threshold (one could also require that the old and
        // new centers be exactly equal).
        boolean flag = true;
        for (PointCluster cluster : clusters) {
            Point center = cluster.getCenter();
            System.out.println("center: " + center);
            Point newCenter = cluster.computeMeansCenter();
            System.out.println("new center: " + newCenter);
            // if (!center.equals(newCenter)) {
            double distance = euclideanDistance(center, newCenter);
            System.out.println("distance: " + distance);
            if (distance > THRESHOLD) {
                flag = false;
                cluster.setCenter(newCenter);
            }
        }
        // Not converged yet: clear the assignments and iterate again.
        if (!flag) {
            for (PointCluster cluster : clusters) {
                cluster.getPoints().clear();
            }
            handleCluster(points, clusters);
        }
    }

    public List<PointCluster> cluster(List<Point> points, int k) {
        List<PointCluster> clusters = genInitCluster(points, k);
        handleCluster(points, clusters);
        return clusters;
    }

    public void build() {
        List<Point> points = initData();
        List<PointCluster> clusters = cluster(points, 4);
        printClusters(clusters);
    }

    public static void main(String[] args) {
        KMeansCluster builder = new KMeansCluster();
        builder.build();
    }
}
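The input file kmeans1.txt is not shown here; from initData it is expected to contain one point per line, with the x and y coordinates separated by whitespace. A hypothetical example:
1.0 2.0
1.5 1.8
8.0 8.2
7.6 9.1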
KMediodsCluster (identical to KMeansCluster except that the convergence threshold is 2.0 and the new center is computed as the medoid via computeMediodsCenter):
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.StringTokenizer;

import org.apache.commons.io.IOUtils;

// Point and AbstractCluster are defined elsewhere in the repository.
public class KMediodsCluster extends AbstractCluster {

    public static final double THRESHOLD = 2.0;

    // Load the sample points from kmeans1.txt: one point per line,
    // x and y separated by whitespace; reading stops at the first empty line.
    public List<Point> initData() {
        List<Point> points = new ArrayList<>();
        InputStream in = null;
        BufferedReader br = null;
        try {
            in = KMediodsCluster.class.getClassLoader().getResourceAsStream("kmeans1.txt");
            br = new BufferedReader(new InputStreamReader(in));
            String line = br.readLine();
            while (null != line && !"".equals(line)) {
                StringTokenizer tokenizer = new StringTokenizer(line);
                double x = Double.parseDouble(tokenizer.nextToken());
                double y = Double.parseDouble(tokenizer.nextToken());
                points.add(new Point(x, y));
                line = br.readLine();
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            IOUtils.closeQuietly(in);
            IOUtils.closeQuietly(br);
        }
        return points;
    }

    // Randomly pick k points as the initial centers and build the initial clusters.
    public List<PointCluster> genInitCluster(List<Point> points, int k) {
        List<PointCluster> clusters = new ArrayList<>();
        Random random = new Random();
        for (int i = 0, len = points.size(); i < k; i++) {
            PointCluster cluster = new PointCluster();
            Point center = points.get(random.nextInt(len));
            cluster.setCenter(center);
            cluster.getPoints().add(center);
            clusters.add(cluster);
        }
        return clusters;
    }

    // Assign every point to the nearest cluster, then recompute the medoids.
    public void handleCluster(List<Point> points, List<PointCluster> clusters) {
        for (Point point : points) {
            PointCluster minCluster = null;
            double minDistance = Double.MAX_VALUE;
            for (PointCluster cluster : clusters) {
                Point center = cluster.getCenter();
                double distance = euclideanDistance(point, center);
                // double distance = manhattanDistance(point, center);
                if (distance < minDistance) {
                    minDistance = distance;
                    minCluster = cluster;
                }
            }
            if (null != minCluster) {
                minCluster.getPoints().add(point);
            }
        }
        // Termination condition: the distance between each old center and its new
        // center falls below a threshold (one could also require that the old and
        // new centers be exactly equal).
        boolean flag = true;
        for (PointCluster cluster : clusters) {
            Point center = cluster.getCenter();
            System.out.println("center: " + center);
            Point newCenter = cluster.computeMediodsCenter();
            System.out.println("new center: " + newCenter);
            // if (!center.equals(newCenter)) {
            double distance = euclideanDistance(center, newCenter);
            System.out.println("distance: " + distance);
            if (distance > THRESHOLD) {
                flag = false;
                cluster.setCenter(newCenter);
            }
        }
        // Not converged yet: clear the assignments and iterate again.
        if (!flag) {
            for (PointCluster cluster : clusters) {
                cluster.getPoints().clear();
            }
            handleCluster(points, clusters);
        }
    }

    public List<PointCluster> cluster(List<Point> points, int k) {
        List<PointCluster> clusters = genInitCluster(points, k);
        handleCluster(points, clusters);
        return clusters;
    }

    public void build() {
        List<Point> points = initData();
        List<PointCluster> clusters = cluster(points, 4);
        printClusters(clusters);
    }

    public static void main(String[] args) {
        KMediodsCluster builder = new KMediodsCluster();
        builder.build();
    }
}
PointCluster, the cluster data structure shared by both implementations:
import java.util.ArrayList;
import java.util.List;

// A cluster of 2-D points together with its current center.
public class PointCluster {

    private Point center = null;
    private List<Point> points = null;

    public Point getCenter() {
        return center;
    }

    public void setCenter(Point center) {
        this.center = center;
    }

    public List<Point> getPoints() {
        if (null == points) {
            points = new ArrayList<>();
        }
        return points;
    }

    public void setPoints(List<Point> points) {
        this.points = points;
    }

    // Mean of all member points (used by KMeansCluster).
    public Point computeMeansCenter() {
        int len = getPoints().size();
        double a = 0.0, b = 0.0;
        for (Point point : getPoints()) {
            a += point.getX();
            b += point.getY();
        }
        return new Point(a / len, b / len);
    }

    // Medoid: the member point with the smallest total Manhattan distance
    // to all other members (used by KMediodsCluster).
    public Point computeMediodsCenter() {
        Point targetPoint = null;
        double distance = Double.MAX_VALUE;
        for (Point point : getPoints()) {
            double d = 0.0;
            for (Point temp : getPoints()) {
                d += manhattanDistance(point, temp);
            }
            if (d < distance) {
                distance = d;
                targetPoint = point;
            }
        }
        return targetPoint;
    }

    // Sum of squared Euclidean distances from each member to the center.
    public double computeSSE() {
        double result = 0.0;
        for (Point point : getPoints()) {
            double d = euclideanDistance(point, center);
            result += d * d;
        }
        return result;
    }

    // Manhattan distance between two points.
    protected double manhattanDistance(Point a, Point b) {
        return Math.abs(a.getX() - b.getX()) + Math.abs(a.getY() - b.getY());
    }

    // Euclidean distance between two points.
    protected double euclideanDistance(Point a, Point b) {
        double sum = Math.pow(a.getX() - b.getX(), 2) + Math.pow(a.getY() - b.getY(), 2);
        return Math.sqrt(sum);
    }
}
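The classes above implement plain k-means and k-medoids only; the bisecting variant from the pseudocode is not included. Below is a rough sketch of how it could be built on top of KMeansCluster and PointCluster, using the "split the cluster with the largest SSE" criterion described earlier (the class name BisectingKMeansCluster is invented here, the Point class comes from the repository, and degenerate clusters with fewer than two points are not handled):
import java.util.ArrayList;
import java.util.List;

// Hypothetical class, not part of the repository: bisecting k-means built on
// KMeansCluster and PointCluster as defined above.
public class BisectingKMeansCluster {

    public List<PointCluster> cluster(List<Point> points, int k) {
        // Start with a single cluster that holds every point.
        PointCluster initial = new PointCluster();
        initial.getPoints().addAll(points);
        initial.setCenter(initial.computeMeansCenter());

        List<PointCluster> clusters = new ArrayList<>();
        clusters.add(initial);

        KMeansCluster kmeans = new KMeansCluster();
        while (clusters.size() < k) {
            // Split the cluster with the largest SSE: a large SSE suggests that
            // several natural clusters are still being treated as one.
            PointCluster worst = clusters.get(0);
            for (PointCluster c : clusters) {
                if (c.computeSSE() > worst.computeSSE()) {
                    worst = c;
                }
            }
            // Run plain k-means with k = 2 on that cluster and replace it
            // with the two resulting sub-clusters.
            List<PointCluster> split = kmeans.cluster(worst.getPoints(), 2);
            clusters.remove(worst);
            clusters.addAll(split);
        }
        return clusters;
    }
}
Running it would mirror KMeansCluster.build(): load the points with KMeansCluster.initData() and call cluster(points, 4) on a BisectingKMeansCluster instance.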
Code repository: https://github.com/fighting-one-piece/repository-datamining.git