Standard Kmean Cluster的实现[Java]

Kmean Cluster是一种机器学习中常用的无监督分析方法,例如,在最近的项目中,要从数以百万、千万计的高维图像特征中提取具有代表性的视觉词,就用到了此类技术。

Kmean并不是一种高效的算法,理论可以证明,在欧几里得空间中的Kmean问题是NP-Hard(即使聚类数仅为2)。假设单个向量维度为d,向量数为n,目标聚类数为k,则算法的时间复杂度=n^(dk+1)*logn。

kmean的示意图如下:

Standard Kmean Cluster的实现[Java]

一些启发式的算法对Standard Kmean的效率进行了优化,常见的包括:
  • 基于最大期望的算法(EM algorithm):采用概率的方式将输入向量分配到各个聚类当中(而非Standard Kmean中的绝对分配),并且采用高斯分布代替Standard Kmean中的算术平均值计算聚类中心。
  • Kmean++:通过选取初始聚类集合来达到更好的效果
  • filtering algorithm:通过使用kd树来加速Kmean效率
  • 其它优化算法:诸如coresets、triangle inequ以及Escape local optima等算法

尽管如此,经典的Standard Kmean仍然是使用频率较高的聚类算法,在数据量不大或者demo阶段被广为使用,下面给出java实现的代码:

public Cluster[] cluster(final List<Clusterable> values, int numClusters) {
    Cluster[] clusters = calculateInitialClusters(values, numClusters);
    
    boolean recalculateClusters = true;

    int numIterations = 0;
    while (recalculateClusters) {
        // add all items to nearest cluster
	clusters = assignClusters(clusters, values);

	// see if the cluster distance hasn't moved
	recalculateClusters = mChecker.recalculateClusters(clusters);

	// if it needs to be run again, set up new clusters on the updated
	// centers
	if (recalculateClusters) {
	    if (numIterations > mMaxReclustering) {
	        recalculateClusters = false;
	    }

	    clusters = getNewClusters(clusters);

	    numIterations++;
        }
    }

    return clusters;
}


import java.util.ArrayList;
import java.util.List;
import java.util.Random;



public class KMeansClusterer extends AbstractKClusterer {
	public KMeansClusterer() {
		super();
	}
	
	protected Cluster[] assignClusters(Cluster[] clusters,final List<Clusterable> values){
		assignClustersByDistance(clusters, values);
		return clusters;
	}
	
	protected void assignClustersByDistance(Cluster[] clusters, List<Clusterable> values){
		for ( int j = 0; j < values.size(); j++ ){
			Clusterable val = values.get(j);
			Cluster nearestCluster = null;
			double minDistance = Float.MAX_VALUE;
			for ( int i = 0; i < clusters.length; i++ ){
				Cluster cluster =  clusters[i];
				double distance = ClusterUtils.getEuclideanDistance(val,cluster);
				//System.out.println("cluster " + i + ", point " + j + ",distance: " + distance);
				if ( distance < minDistance ){
					nearestCluster = cluster;
					minDistance = distance;
				}
			}
			nearestCluster.addItem(val);
		}
	}
	
	protected Cluster[] getNewClusters(Cluster[] clusters){
		for ( int i = 0; i < clusters.length; i++ ){
			if ( clusters[i].getItems().size() > 0 )
				clusters[i] = new Cluster(clusters[i].getClusterMean(),i);
		}
		return clusters;
	}
	
	public static void main(String args[]){
		Random random = new Random(System.currentTimeMillis());
		int numPoints = 50;
		List<Clusterable> points = new ArrayList<Clusterable>(numPoints);
		for ( int i = 0; i < numPoints; i++ ){
			int x = random.nextInt(1000) - 500;
			int y = random.nextInt(1000) - 500;
			points.add(new Point((float)x,(float)y));
		}
		KClusterer clusterer = new KMeansClusterer();
		Cluster[] clusters = clusterer.cluster(points,10);
		for ( Cluster c : clusters ){
			System.out.println(c.getItems().size());
		}
	}
	
	public static boolean hasBadValue(double[] values){
		for ( double value : values ){
			if ( !(value < 1 && value > -1) ){
				System.out.println(value + " is 'bad'");
				return true;
			}
		}
		return false;
	}
}


Standard Kmean在大数据量时其表现往往不尽如人意,后续我会附上kd random forest的优化算法

你可能感兴趣的:(java,C++,c,算法,J#)