[Machine Learning] The silhouette coefficient for clustering algorithms, a Java implementation

This time we implement the silhouette coefficient (wiki, baidu). Its purpose is to evaluate how good a clustering result is.

I recommend reading the wiki description; parts of the Baidu material are not explained very clearly. For example, this sentence from Baidu Baike is hard to parse: (compute b(i) = min (the average distance from vector i to the points of every cluster other than its own)).

Below is the wiki description of the silhouette coefficient; roughly, my understanding is:

a(i) is the average distance from point i to all the other points in its own cluster.

b(i) is the minimum, over all other clusters, of the average distance from point i to the points of that cluster; the denominator of the formula below is the maximum of the two. (Note that the Java code further down approximates both quantities with distances to the cluster center rather than pairwise point distances.)

Assume the data have been clustered via any technique, such as k-means, into k clusters. For each datum i, let a(i) be the average dissimilarity of i with all other data within the same cluster. We can interpret a(i) as how well i is assigned to its cluster (the smaller the value, the better the assignment). We then define the average dissimilarity of point i to a cluster c as the average of the distance from i to all points in c.

Let b(i) be the lowest average dissimilarity of i to any other cluster, of which i is not a member. The cluster with this lowest average dissimilarity is said to be the "neighbouring cluster" of i because it is the next best fit cluster for point i. We now define a silhouette:

s(i) = ( b(i) - a(i) ) / max{ a(i), b(i) }

s(i) ranges from -1 to 1; a value close to 1 means point i is well matched to its own cluster and far from its neighbouring cluster.

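As a quick worked example on the sample data used below: for the point (0.0, 0.0, 0.0), a(i) is the mean of its distances to (0.1, 0.1, 0.1) and (0.2, 0.2, 0.2), i.e. (0.1√3 + 0.2√3) / 2 ≈ 0.26, and b(i) is its mean distance to the three points near (9, 9, 9), i.e. 9.1√3 ≈ 15.76, so s(i) ≈ (15.76 - 0.26) / 15.76 ≈ 0.98, correctly signalling a very good assignment.
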
The code is as follows:

package com.mj.datamining.test;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public class KMeansTest {

	public static void main(String[] args) {
		init(kMeansPre());
	}
	
	public static JavaSparkContext kMeansPre() {
		SparkConf conf = new SparkConf().setAppName("Kmeans").setMaster("local[2]");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		return jsc;
	}
	
	/**
	 * Sample data:
	 *  0.0 0.0 0.0
	 *  0.1 0.1 0.1
	 *  0.2 0.2 0.2
	 *  9.0 9.0 9.0
	 *  9.1 9.1 9.1
	 *  9.2 9.2 9.2
	 * @param jsc
	 */
	public static void init(JavaSparkContext jsc) {
		double[] data1 = {0.0,0.0,0.0};
		double[] data2 = {0.1,0.1,0.1};
		double[] data3 = {0.2,0.2,0.2};
		double[] data4 = {9.0,9.0,9.0};
		double[] data5 = {9.1,9.1,9.1};
		double[] data6 = {9.2,9.2,9.2};
		
		List<Vector> preData = new ArrayList<>();
		Vector v1 = Vectors.dense(data1);
		Vector v2 = Vectors.dense(data2);
		Vector v3 = Vectors.dense(data3);
		Vector v4 = Vectors.dense(data4);
		Vector v5 = Vectors.dense(data5);
		Vector v6 = Vectors.dense(data6);
		preData.add(v1);
		preData.add(v2);
		preData.add(v3);
		preData.add(v4);
		preData.add(v5);
		preData.add(v6);
		JavaRDD<Vector> data = jsc.parallelize(preData);

	    // Cluster the data into two classes using KMeans
	    int numClusters = 2;
	    int numIterations = 20;
	    KMeansModel clusters = KMeans.train(data.rdd(), numClusters, numIterations);

	    // predict the cluster index of every point
	    JavaRDD<Integer> clusterResult = clusters.predict(data);

	    // silhouette coefficient for cluster 0, using its center as the reference point
	    double coef = silhouetteCoefficient(data.collect(), clusterResult.collect(), 0,
	            clusters.clusterCenters()[0], clusters.clusterCenters().length);

	    System.out.println("Silhouette coefficient = " + coef);

	    jsc.stop();

	}
	
	    private static double euclideanDistance(double[] data, double[] center) {
	    	// guard against null/empty input before dereferencing
	    	if(data == null || data.length == 0 || center == null || center.length == 0) {
	    		return 0.0;
	    	} else if(center.length != data.length) {
	    		throw new RuntimeException("Data length and center length differ.");
	    	}
	    	
	    	double sum = 0.0;
	    	
	    	for(int i = 0; i < data.length; i++) {
	    		sum += Math.pow(data[i] - center[i] , 2);
	    	}
	    	
	    	return Math.sqrt(sum);
	    		
	    }
	   
	    /**
	     * s = (b(i) - a(i)) / max(a(i), b(i))
	     * a(i): the average distance within the same cluster
	     * b(i): the minimum average distance to any other cluster
	     * @param data    all points
	     * @param result  cluster index assigned to each point
	     * @param flag    index of the cluster being evaluated
	     * @param center  the center of that cluster
	     * @param centerSize  total number of clusters
	     * @return the (center-based) silhouette coefficient
	     */
	    private static double silhouetteCoefficient(List<Vector> data, List<Integer> result,
	    		int flag, Vector center, int centerSize) {
	    	double sameClusterSum = 0.0;
	    	int sameClusterCount = 0;
	    	double min = Double.MAX_VALUE;

	    	// b: the smallest average distance from this cluster's center
	    	// to the points of any other cluster
	    	for(int j = 0; j < centerSize; j++) {
	    		if(j != flag) {
	    			double otherClusterSum = 0.0;
	    			int otherClusterCount = 0;
	    			for(int i = 0; i < data.size(); i++) {
	    				if(result.get(i) == j) {
	    					otherClusterSum += euclideanDistance(data.get(i).toArray(), center.toArray());
	    					otherClusterCount++;
	    				}
	    			}
	    			if(otherClusterCount > 0) {
	    				min = Math.min(min, otherClusterSum / otherClusterCount);
	    			}
	    		}
	    	}

	    	// a: the average distance from this cluster's center to its own points
	    	for(int i = 0; i < data.size(); i++) {
	    		if(result.get(i) == flag) {
	    			sameClusterSum += euclideanDistance(data.get(i).toArray(), center.toArray());
	    			sameClusterCount++;
	    		}
	    	}
	    	double a = sameClusterCount > 0 ? sameClusterSum / sameClusterCount : 0.0;
	    	double b = min;

	    	return (b - a) / Math.max(a, b);
	    }
}
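
Note that the method above measures every distance against the cluster center, which is a simplified, center-based approximation. The textbook (wiki) definition uses average pairwise distances between points and then averages s(i) over all points. Below is a minimal sketch of that pointwise version, reusing the euclideanDistance helper from the class above; the method name meanSilhouette is my own, not part of any library:

	    // Textbook silhouette: for each point i, a(i) is its mean distance to the
	    // other points of its own cluster, b(i) is the smallest mean distance to
	    // the points of any other cluster; the result is the mean of all s(i).
	    private static double meanSilhouette(List<Vector> data, List<Integer> result, int k) {
	    	double total = 0.0;
	    	for(int i = 0; i < data.size(); i++) {
	    		double[] sum = new double[k];  // distance sums per cluster
	    		int[] count = new int[k];      // point counts per cluster (excluding i)
	    		for(int j = 0; j < data.size(); j++) {
	    			if(j == i) continue;
	    			int c = result.get(j);
	    			sum[c] += euclideanDistance(data.get(i).toArray(), data.get(j).toArray());
	    			count[c]++;
	    		}
	    		int own = result.get(i);
	    		// a(i): mean distance within own cluster (singleton clusters would
	    		// need special-casing; the wiki defines s(i) = 0 for them)
	    		double a = count[own] > 0 ? sum[own] / count[own] : 0.0;
	    		double b = Double.MAX_VALUE;
	    		for(int c = 0; c < k; c++) {
	    			if(c != own && count[c] > 0) {
	    				b = Math.min(b, sum[c] / count[c]);
	    			}
	    		}
	    		total += (b - a) / Math.max(a, b);
	    	}
	    	return total / data.size();
	    }

Called as meanSilhouette(data.collect(), clusterResult.collect(), numClusters) inside init, it should report roughly 0.98 for the six sample points, matching the worked example earlier.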


There may well be things I have not fully considered; thanks for reading.
