谱聚类(Spectral Clustering, SC)是一种基于图论的聚类方法——将带权无向图划分为两个或两个以上的最优子图,使子图内部尽量相似,而子图间距离尽量距离较远,以达到常见的聚类的目的。其中的最优是指最优目标函数不同,可以是Min Cut、Nomarlized Cut、Ratio Cut等。谱聚类能够识别任意形状的样本空间且收敛于全局最优解,其基本思想是利用样本数据的相似矩阵(拉普拉斯矩阵)进行特征分解后得到的特征向量进行聚类。
Spectral Clustering 算法步骤:
1)根据数据构造一个Graph,Graph的每一个节点对应一个数据点,将相似的点连接起来,并且边的权重用于表示数据之间的相似度。把这个Graph用邻接矩阵的形式表示出来,记为 W。
2)把W的每一列元素活者行元素加起来得到N个数,把它们放在对角线上(其他地方都是零),组成一个N*N的度矩阵,记为D 。
3)根据度矩阵与邻接矩阵得出拉普拉斯矩阵 L = D - W 。
4)求出拉普拉斯矩阵L的前k个特征值(除非特殊说明,否则“前k个”指按照特征值的大小从小到大的顺序)以及对应的特征向量。
5)把这k个特征(列)向量排列在一起组成一个N*k的矩阵,将其中每一行看作k维空间中的一个向量,并使用 K-Means算法进行聚类。聚类的结果中每一行所属的类别就是原来Graph中的节点亦即最初的N个数据点分别所属的类别。
示例
Spectral Clustering 和传统的聚类方法(如 K-Means等)对比:
1)和 K-Medoids 类似,Spectral Clustering 只需要数据之间的相似度矩阵就可以了,而不必像K-means那样要求数据必须是 N 维欧氏空间中的向量。Spectral Clustering 所需要的所有信息都包含在 W 中。不过一般 W 并不总是等于最初的相似度矩阵——回忆一下,W 是我们构造出来的 Graph 的邻接矩阵表示,通常我们在构造 Graph 的时候为了方便进行聚类,更加强到“局部”的连通性,亦即主要考虑把相似的点连接在一起,比如:我们可以设置一个阈值,如果两个点的相似度小于这个阈值,就把他们看作是不连接的。另一种构造 Graph 邻接的方法是将 n 个与节点最相似的点与其连接起来。
2)由于抓住了主要矛盾,忽略了次要的东西,因此比传统的聚类算法更加健壮一些,对于不规则的误差数据不是那么敏感,而且性能也要好一些。许多实验都证明了这一点。事实上,在各种现代聚类算法的比较中,K-means 通常都是作为 baseline 而存在的。实际上 Spectral Clustering 是在用特征向量的元素来表示原来的数据,并在这种“更好的表示形式”上进行 K-Means 。实际上这种“更好的表示形式”是用 Laplacian Eig进行降维的后的结果。而降维的目的正是“抓住主要矛盾,忽略次要的东西”。
3)计算复杂度比 K-means 要小。这个在高维数据上表现尤为明显。例如文本数据,通常排列起来是维度非常高(比如几千或者几万)的稀疏矩阵,对稀疏矩阵求特征值和特征向量有很高效的办法,得到的结果是一些 k 维的向量(通常 k 不会很大),在这些低维的数据上做 K-Means 运算量非常小。但是对于原始数据直接做 K-Means 的话,虽然最初的数据是稀疏矩阵,但是 K-Means 中有一个求 Centroid 的运算,就是求一个平均值:许多稀疏的向量的平均值求出来并不一定还是稀疏向量,事实上,在文本数据里,很多情况下求出来的 Centroid 向量是非常稠密,这时再计算向量之间的距离的时候,运算量就变得非常大,直接导致普通的 K-Means 巨慢无比,而 Spectral Clustering 等工序更多的算法则迅速得多的结果。
Java简单实现代码如下:
public class SpectralClusteringBuilder { public static int DIMENSION = 30; public static double THRESHOLD = 0.01; public Data getInitData() { Data data = new Data(); try { String path = SpectralClustering.class.getClassLoader() .getResource("测试").toURI().getPath(); DocumentSet documentSet = DocumentLoader.loadDocumentSet(path); List<Document> documents = documentSet.getDocuments(); DocumentUtils.calculateTFIDF_0(documents); DocumentUtils.calculateSimilarity(documents, new CosineDistance()); Map<String, Map<String, Double>> nmap = new HashMap<String, Map<String, Double>>(); Map<String, String> cmap = new HashMap<String, String>(); for (Document document : documents) { String name = document.getName(); cmap.put(name, document.getCategory()); Map<String, Double> similarities = nmap.get(name); if (null == similarities) { similarities = new HashMap<String, Double>(); nmap.put(name, similarities); } for (DocumentSimilarity similarity : document.getSimilarities()) { if (similarity.getDoc2().getName().equalsIgnoreCase(similarity.getDoc1().getName())) { similarities.put(similarity.getDoc2().getName(), 0.0); } else { similarities.put(similarity.getDoc2().getName(), similarity.getDistance()); } } } String[] docnames = nmap.keySet().toArray(new String[0]); data.setRow(docnames); data.setColumn(docnames); data.setDocnames(docnames); int len = docnames.length; double[][] original = new double[len][len]; for (int i = 0; i < len; i++) { Map<String, Double> similarities = nmap.get(docnames[i]); for (int j = 0; j < len; j++) { double distance = similarities.get(docnames[j]); original[i][j] = distance; } } data.setOriginal(original); data.setCmap(cmap); data.setNmap(nmap); } catch (Exception e) { e.printStackTrace(); } return data; } /** * 获取距离阀值在一定范围内的点 * @param data * @return */ public double[][] getWByDistance(Data data) { Map<String, Map<String, Double>> nmap = data.getNmap(); String[] docnames = data.getDocnames(); int len = docnames.length; double[][] w = new double[len][len]; for (int i = 0; i < len; i++) { Map<String, Double> similarities = nmap.get(docnames[i]); for (int j = 0; j < len; j++) { double distance = similarities.get(docnames[j]); w[i][j] = distance < THRESHOLD ? 1 : 0; } } return w; } /** * 获取距离最近的K个点 * @param data * @return */ public double[][] getWByKNearestNeighbors(Data data) { Map<String, Map<String, Double>> nmap = data.getNmap(); String[] docnames = data.getDocnames(); int len = docnames.length; double[][] w = new double[len][len]; for (int i = 0; i < len; i++) { List<Map.Entry<String, Double>> similarities = new ArrayList<Map.Entry<String, Double>>(nmap.get(docnames[i]).entrySet()); sortSimilarities(similarities, DIMENSION); for (int j = 0; j < len; j++) { String name = docnames[j]; boolean flag = false; for (Map.Entry<String, Double> entry : similarities) { if (name.equalsIgnoreCase(entry.getKey())) { flag = true; break; } } w[i][j] = flag ? 1 : 0; } } return w; } /** * 垂直求和 * @param W * @return */ public double[][] getVerticalD(double[][] W) { int row = W.length; int column = W[0].length; double[][] d = new double[row][column]; for (int j = 0; j < column; j++) { double sum = 0; for (int i = 0; i < row; i++) { sum += W[i][j]; } d[j][j] = sum; } return d; } /** * 水平求和 * @param W * @return */ public double[][] getHorizontalD(double[][] W) { int row = W.length; int column = W[0].length; double[][] d = new double[row][column]; for (int i = 0; i < row; i++) { double sum = 0; for (int j = 0; j < column; j++) { sum += W[i][j]; } d[i][i] = sum; } return d; } /** * 相似度排序,并取前K个,倒叙 * @param similarities * @param k */ public void sortSimilarities(List<Map.Entry<String, Double>> similarities, int k) { Collections.sort(similarities, new Comparator<Map.Entry<String, Double>>() { @Override public int compare(Entry<String, Double> o1, Entry<String, Double> o2) { return o2.getValue().compareTo(o1.getValue()); } }); while (similarities.size() > k) { similarities.remove(similarities.size() - 1); } } public void print(double[][] values) { for (int i = 0, il = values.length; i < il; i++) { for (int j = 0, jl = values[0].length; j < jl; j++) { System.out.print(values[i][j] + " "); } System.out.println("\n"); } } // 随机生成中心点,并生成初始的K个聚类 public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) { List<DataPointCluster> clusters = new ArrayList<DataPointCluster>(); Random random = new Random(); Set<String> categories = new HashSet<String>(); while (clusters.size() < k) { DataPoint center = points.get(random.nextInt(points.size())); String category = center.getCategory(); if (categories.contains(category)) continue; categories.add(category); DataPointCluster cluster = new DataPointCluster(); cluster.setCenter(center); cluster.getDataPoints().add(center); clusters.add(cluster); } return clusters; } // 将点归入到聚类中 public void handleCluster(List<DataPoint> points, List<DataPointCluster> clusters, int iterNum) { for (DataPoint point : points) { DataPointCluster maxCluster = null; double maxDistance = Integer.MIN_VALUE; for (DataPointCluster cluster : clusters) { DataPoint center = cluster.getCenter(); double distance = DistanceUtils.cosine(point.getValues(), center.getValues()); if (distance > maxDistance) { maxDistance = distance; maxCluster = cluster; } } if (null != maxCluster) { maxCluster.getDataPoints().add(point); } } // 终止条件定义为原中心点与新中心点距离小于一定阀值 // 当然也可以定义为原中心点等于新中心点 boolean flag = true; for (DataPointCluster cluster : clusters) { DataPoint center = cluster.getCenter(); DataPoint newCenter = cluster.computeMediodsCenter(); double distance = DistanceUtils.cosine(newCenter.getValues(), center.getValues()); if (distance > 0.5) { flag = false; cluster.setCenter(newCenter); } } if (!flag && iterNum < 25) { for (DataPointCluster cluster : clusters) { cluster.getDataPoints().clear(); } handleCluster(points, clusters, ++iterNum); } } /** * KMeans方法 * @param dataPoints */ public void kmeans(List<DataPoint> dataPoints) { List<DataPointCluster> clusters = genInitCluster(dataPoints, 4); handleCluster(dataPoints, clusters, 0); int success = 0, failure = 0; for (DataPointCluster cluster : clusters) { String category = cluster.getCenter().getCategory(); for (DataPoint dataPoint : cluster.getDataPoints()) { String dpCategory = dataPoint.getCategory(); if (category.equals(dpCategory)) { success++; } else { failure++; } } } System.out.println("total: " + (success + failure) + " success: " + success + " failure: " + failure); } public void build() { Data data = getInitData(); double[][] w = getWByKNearestNeighbors(data); double[][] d = getHorizontalD(w); Matrix W = new Matrix(w); Matrix D = new Matrix(d); Matrix L = D.minus(W); EigenvalueDecomposition eig = L.eig(); double[][] v = eig.getV().getArray(); double[][] vs = new double[v.length][DIMENSION]; for (int i = 0, li = v.length; i < li; i++) { for (int j = 1, lj = DIMENSION; j <= lj; j++) { vs[i][j-1] = v[i][j]; } } Matrix V = new Matrix(vs); Matrix O = new Matrix(data.getOriginal()); double[][] t = O.times(V).getArray(); List<DataPoint> dataPoints = new ArrayList<DataPoint>(); for (int i = 0; i < t.length; i++) { DataPoint dataPoint = new DataPoint(); dataPoint.setCategory(data.getCmap().get(data.getColumn()[i])); dataPoint.setValues(t[i]); dataPoints.add(dataPoint); } for (int n = 0; n < 10; n++) { kmeans(dataPoints); } } public static void main(String[] args) { new SpectralClusteringBuilder().build(); } }