数据挖掘笔记-聚类-SpectralClustering-原理与简单实现

谱聚类(Spectral Clustering, SC)是一种基于图论的聚类方法——将带权无向图划分为两个或两个以上的最优子图,使子图内部尽量相似,而子图间距离尽量距离较远,以达到常见的聚类的目的。其中的最优是指最优目标函数不同,可以是Min Cut、Nomarlized Cut、Ratio Cut等。谱聚类能够识别任意形状的样本空间且收敛于全局最优解,其基本思想是利用样本数据的相似矩阵(拉普拉斯矩阵)进行特征分解后得到的特征向量进行聚类。

Spectral Clustering 算法步骤:

1)根据数据构造一个GraphGraph的每一个节点对应一个数据点,将相似的点连接起来,并且边的权重用于表示数据之间的相似度。把这个Graph用邻接矩阵的形式表示出来,记为 W

2)把W的每一列元素活者行元素加起来得到N个数,把它们放在对角线上(其他地方都是零),组成一个N*N的度矩阵,记为

3)根据度矩阵与邻接矩阵得出拉普拉斯矩阵 L = D - W 

4)求出拉普拉斯矩阵L的前k个特征值(除非特殊说明,否则k指按照特征值的大小从小到大的顺序)以及对应的特征向量。

5)把这k个特征(列)向量排列在一起组成一个N*k的矩阵,将其中每一行看作k维空间中的一个向量,并使用 K-Means算法进行聚类。聚类的结果中每一行所属的类别就是原来Graph中的节点亦即最初的N个数据点分别所属的类别。

示例

数据挖掘笔记-聚类-SpectralClustering-原理与简单实现_第1张图片

Spectral Clustering 和传统的聚类方法(如 K-Means等)对比:
1)和 K-Medoids 类似,Spectral Clustering 只需要数据之间的相似度矩阵就可以了,而不必像K-means那样要求数据必须是 维欧氏空间中的向量。Spectral Clustering 所需要的所有信息都包含在 中。不过一般 W 并不总是等于最初的相似度矩阵——回忆一下,W 是我们构造出来的 Graph 的邻接矩阵表示,通常我们在构造 Graph 的时候为了方便进行聚类,更加强到“局部”的连通性,亦即主要考虑把相似的点连接在一起,比如:我们可以设置一个阈值,如果两个点的相似度小于这个阈值,就把他们看作是不连接的。另一种构造 Graph 邻接的方法是将 n 个与节点最相似的点与其连接起来。
2)由于抓住了主要矛盾,忽略了次要的东西,因此比传统的聚类算法更加健壮一些,对于不规则的误差数据不是那么敏感,而且性能也要好一些。许多实验都证明了这一点。事实上,在各种现代聚类算法的比较中,K-means 通常都是作为 baseline 而存在的。实际上 Spectral Clustering 是在用特征向量的元素来表示原来的数据,并在这种“更好的表示形式”上进行 K-Means 。实际上这种“更好的表示形式”是用 Laplacian Eig进行降维的后的结果。而降维的目的正是“抓住主要矛盾,忽略次要的东西”。
3计算复杂度比 K-means 要小。这个在高维数据上表现尤为明显。例如文本数据,通常排列起来是维度非常高(比如几千或者几万)的稀疏矩阵,对稀疏矩阵求特征值和特征向量有很高效的办法,得到的结果是一些 k 维的向量(通常 k 不会很大),在这些低维的数据上做 K-Means 运算量非常小。但是对于原始数据直接做 K-Means 的话,虽然最初的数据是稀疏矩阵,但是 K-Means 中有一个求 Centroid 的运算,就是求一个平均值:许多稀疏的向量的平均值求出来并不一定还是稀疏向量,事实上,在文本数据里,很多情况下求出来的 Centroid 向量是非常稠密,这时再计算向量之间的距离的时候,运算量就变得非常大,直接导致普通的 K-Means 巨慢无比,而 Spectral Clustering 等工序更多的算法则迅速得多的结果。


Java简单实现代码如下:

public class SpectralClusteringBuilder {

	public static int DIMENSION = 30;
	
	public static double THRESHOLD = 0.01;

	public Data getInitData() {
		Data data = new Data();
		try {
			String path = SpectralClustering.class.getClassLoader()
					.getResource("测试").toURI().getPath();
			DocumentSet documentSet = DocumentLoader.loadDocumentSet(path);
			List<Document> documents = documentSet.getDocuments();
			DocumentUtils.calculateTFIDF_0(documents);
			DocumentUtils.calculateSimilarity(documents, new CosineDistance());
			Map<String, Map<String, Double>> nmap = new HashMap<String, Map<String, Double>>();
			Map<String, String> cmap = new HashMap<String, String>();
			for (Document document : documents) {
				String name = document.getName();
				cmap.put(name, document.getCategory());
				Map<String, Double> similarities = nmap.get(name);
				if (null == similarities) {
					similarities = new HashMap<String, Double>();
					nmap.put(name, similarities);
				}
				for (DocumentSimilarity similarity : document.getSimilarities()) {
					if (similarity.getDoc2().getName().equalsIgnoreCase(similarity.getDoc1().getName())) {
						similarities.put(similarity.getDoc2().getName(), 0.0);
					} else {
						similarities.put(similarity.getDoc2().getName(), similarity.getDistance());
					}
				}
			}
			String[] docnames = nmap.keySet().toArray(new String[0]);
			data.setRow(docnames);
			data.setColumn(docnames);
			data.setDocnames(docnames);
			int len = docnames.length;
			double[][] original = new double[len][len];
			for (int i = 0; i < len; i++) {
				Map<String, Double> similarities = nmap.get(docnames[i]);
				for (int j = 0; j < len; j++) {
					double distance = similarities.get(docnames[j]);
					original[i][j] = distance;
				}
			}
			data.setOriginal(original);
			data.setCmap(cmap);
			data.setNmap(nmap);
		} catch (Exception e) {
			e.printStackTrace();
		}
		return data;
	}

	/**
	 * 获取距离阀值在一定范围内的点
	 * @param data
	 * @return
	 */
	public double[][] getWByDistance(Data data) {
		Map<String, Map<String, Double>> nmap = data.getNmap();
		String[] docnames = data.getDocnames();
		int len = docnames.length;
		double[][] w = new double[len][len];
		for (int i = 0; i < len; i++) {
			Map<String, Double> similarities = nmap.get(docnames[i]);
			for (int j = 0; j < len; j++) {
				double distance = similarities.get(docnames[j]);
				w[i][j] = distance < THRESHOLD ? 1 : 0;
			}
		}
		return w;
	}
	
	/**
	 * 获取距离最近的K个点
	 * @param data
	 * @return
	 */
	public double[][] getWByKNearestNeighbors(Data data) {
		Map<String, Map<String, Double>> nmap = data.getNmap();
		String[] docnames = data.getDocnames();
		int len = docnames.length;
		double[][] w = new double[len][len];
		for (int i = 0; i < len; i++) {
			List<Map.Entry<String, Double>> similarities = 
					new ArrayList<Map.Entry<String, Double>>(nmap.get(docnames[i]).entrySet());
			sortSimilarities(similarities, DIMENSION);
			for (int j = 0; j < len; j++) {
				String name = docnames[j];
				boolean flag = false;
				for (Map.Entry<String, Double> entry : similarities) {
					if (name.equalsIgnoreCase(entry.getKey())) {
						flag = true;
						break;
					}
				}
				w[i][j] = flag ? 1 : 0;
			}
		}
		return w;
	}

	/**
	 * 垂直求和
	 * @param W
	 * @return
	 */
	public double[][] getVerticalD(double[][] W) {
		int row = W.length;
		int column = W[0].length;
		double[][] d = new double[row][column];
		for (int j = 0; j < column; j++) {
			double sum = 0;
			for (int i = 0; i < row; i++) {
				sum += W[i][j];
			}
			d[j][j] = sum;
		}
		return d;
	}

	/**
	 * 水平求和
	 * @param W
	 * @return
	 */
	public double[][] getHorizontalD(double[][] W) {
		int row = W.length;
		int column = W[0].length;
		double[][] d = new double[row][column];
		for (int i = 0; i < row; i++) {
			double sum = 0;
			for (int j = 0; j < column; j++) {
				sum += W[i][j];
			}
			d[i][i] = sum;
		}
		return d;
	}
	
	/**
	 * 相似度排序,并取前K个,倒叙
	 * @param similarities
	 * @param k
	 */
	public void sortSimilarities(List<Map.Entry<String, Double>> similarities, int k) {
		Collections.sort(similarities, new Comparator<Map.Entry<String, Double>>() {
			@Override
			public int compare(Entry<String, Double> o1,
					Entry<String, Double> o2) {
				return o2.getValue().compareTo(o1.getValue());
			}
		});
		while (similarities.size() > k) {
			similarities.remove(similarities.size() - 1);
		}
	}

	public void print(double[][] values) {
		for (int i = 0, il = values.length; i < il; i++) {
			for (int j = 0, jl = values[0].length; j < jl; j++) {
				System.out.print(values[i][j] + "  ");
			}
			System.out.println("\n");
		}
	}

	// 随机生成中心点,并生成初始的K个聚类
	public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) {
		List<DataPointCluster> clusters = new ArrayList<DataPointCluster>();
		Random random = new Random();
		Set<String> categories = new HashSet<String>();
		while (clusters.size() < k) {
			DataPoint center = points.get(random.nextInt(points.size()));
			String category = center.getCategory();
			if (categories.contains(category))
				continue;
			categories.add(category);
			DataPointCluster cluster = new DataPointCluster();
			cluster.setCenter(center);
			cluster.getDataPoints().add(center);
			clusters.add(cluster);
		}
		return clusters;
	}

	// 将点归入到聚类中
	public void handleCluster(List<DataPoint> points,
			List<DataPointCluster> clusters, int iterNum) {
		for (DataPoint point : points) {
			DataPointCluster maxCluster = null;
			double maxDistance = Integer.MIN_VALUE;
			for (DataPointCluster cluster : clusters) {
				DataPoint center = cluster.getCenter();
				double distance = DistanceUtils.cosine(point.getValues(),
						center.getValues());
				if (distance > maxDistance) {
					maxDistance = distance;
					maxCluster = cluster;
				}
			}
			if (null != maxCluster) {
				maxCluster.getDataPoints().add(point);
			}
		}
		// 终止条件定义为原中心点与新中心点距离小于一定阀值
		// 当然也可以定义为原中心点等于新中心点
		boolean flag = true;
		for (DataPointCluster cluster : clusters) {
			DataPoint center = cluster.getCenter();
			DataPoint newCenter = cluster.computeMediodsCenter();
			double distance = DistanceUtils.cosine(newCenter.getValues(),
					center.getValues());
			if (distance > 0.5) {
				flag = false;
				cluster.setCenter(newCenter);
			}
		}
		if (!flag && iterNum < 25) {
			for (DataPointCluster cluster : clusters) {
				cluster.getDataPoints().clear();
			}
			handleCluster(points, clusters, ++iterNum);
		}
	}

	/**
	 * KMeans方法
	 * @param dataPoints
	 */
	public void kmeans(List<DataPoint> dataPoints) {
		List<DataPointCluster> clusters = genInitCluster(dataPoints, 4);
		handleCluster(dataPoints, clusters, 0);
		int success = 0, failure = 0;
		for (DataPointCluster cluster : clusters) {
			String category = cluster.getCenter().getCategory();
			for (DataPoint dataPoint : cluster.getDataPoints()) {
				String dpCategory = dataPoint.getCategory();
				if (category.equals(dpCategory)) {
					success++;
				} else {
					failure++;
				}
			}
		}
		System.out.println("total: " + (success + failure) + " success: "
				+ success + " failure: " + failure);
	}

	public void build() {
		Data data = getInitData();
		double[][] w = getWByKNearestNeighbors(data);
		double[][] d = getHorizontalD(w);
		Matrix W = new Matrix(w);
		Matrix D = new Matrix(d);
		Matrix L = D.minus(W);
		EigenvalueDecomposition eig = L.eig();
		double[][] v = eig.getV().getArray();
		double[][] vs = new double[v.length][DIMENSION];
		for (int i = 0, li = v.length; i < li; i++) {
			for (int j = 1, lj = DIMENSION; j <= lj; j++) {
				vs[i][j-1] = v[i][j];
			}
		}
		Matrix V = new Matrix(vs);
		Matrix O = new Matrix(data.getOriginal());
	    double[][] t = O.times(V).getArray();
	    List<DataPoint> dataPoints = new ArrayList<DataPoint>();
		for (int i = 0; i < t.length; i++) {
			DataPoint dataPoint = new DataPoint();
			dataPoint.setCategory(data.getCmap().get(data.getColumn()[i]));
			dataPoint.setValues(t[i]);
			dataPoints.add(dataPoint);
		}
		for (int n = 0; n < 10; n++) {
			kmeans(dataPoints);
		}
	}

	public static void main(String[] args) {
		new SpectralClusteringBuilder().build();
	}
}

代码托管:https://github.com/fighting-one-piece/repository-datamining.git



你可能感兴趣的:(数据挖掘,聚类,谱聚类,Spectral,拉普拉斯)