Purity、NMI、RI、Precision、Recall、F值,聚类指标的计算JAVA实现。

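关于聚类的这六个评价指标(evaluation metrics),可以参考 Stanford《Introduction to Information Retrieval》中的 “Evaluation of clustering” 一节(http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html),讲得很好了,我就不赘述了,直接上代码: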

第一个:计算NMI的:

package clusters;

import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Computes the normalized mutual information (NMI) between a clustering and the
 * gold-standard classes of the clustered points.
 *
 * <p>Created: 16-6-18, 10:00 AM.
 *
 * <p>Reference:
 * http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
 */
public class NormalizedMutualInformation {
    /** Input file read by {@link #loadData(List)}; each line holds the class labels of one cluster. */
    public static String path = "/home/fhqplzj/IdeaProjects/Vein/src/main/resources/nmi_data";

    /** Natural logarithm of 2, used to convert {@link Math#log(double)} results to base 2. */
    private static final double LOG_2 = Math.log(2);

    /**
     * Reads the file at {@link #path} and appends one list of class labels per non-blank line.
     *
     * <p>I/O failures (including a missing file) are reported with a stack trace and whatever was
     * read before the failure stays in {@code lists}. The reader is always closed.
     *
     * @param lists destination list; one element is appended per cluster
     * @throws NumberFormatException if a token on a line is not a decimal integer
     */
    public static void loadData(List<List<Integer>> lists) {
        try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), java.nio.charset.StandardCharsets.UTF_8))) {
            String line;
            while ((line = bufferedReader.readLine()) != null) {
                List<Integer> labels = parseLine(line);
                // A blank line is not a cluster; skipping it avoids an empty cluster downstream.
                if (!labels.isEmpty()) {
                    lists.add(labels);
                }
            }
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so one handler covers both cases.
            e.printStackTrace();
        }
    }

    /**
     * Parses one whitespace-separated line of integer class labels.
     *
     * @param line a single line of input; leading and trailing whitespace is ignored
     * @return the labels in order of appearance, or an empty list for a blank line
     * @throws NumberFormatException if a token is not a decimal integer
     */
    public static List<Integer> parseLine(String line) {
        List<Integer> labels = new ArrayList<>();
        String trimmed = line.trim();
        if (trimmed.isEmpty()) {
            return labels;
        }
        // Trimming first prevents split() from yielding a leading empty token.
        for (String token : trimmed.split("\\s+")) {
            labels.add(Integer.parseInt(token));
        }
        return labels;
    }

    /**
     * Computes NMI = 2 * I(clusters; classes) / (H(clusters) + H(classes)), with base-2 logarithms.
     *
     * @param lists one list per cluster, each holding the class label of every point in it;
     *              empty clusters contribute nothing
     * @return the NMI value; {@code NaN} when both entropies are zero (for example no data, or a
     *         single cluster containing a single class)
     */
    public static double computeNmi(List<List<Integer>> lists) {
        int total = 0;
        Map<Integer, Integer> classSizes = new HashMap<>();
        for (List<Integer> cluster : lists) {
            total += cluster.size();
            for (Integer label : cluster) {
                classSizes.merge(label, 1, Integer::sum);
            }
        }

        double clusterEntropy = 0;
        for (List<Integer> cluster : lists) {
            clusterEntropy += entropyTerm(cluster.size(), total);
        }

        double classEntropy = 0;
        for (int classSize : classSizes.values()) {
            classEntropy += entropyTerm(classSize, total);
        }

        double mutualInformation = 0;
        Map<Integer, Integer> overlap = new HashMap<>();
        for (List<Integer> cluster : lists) {
            int wk = cluster.size();
            overlap.clear();
            for (Integer label : cluster) {
                overlap.merge(label, 1, Integer::sum);
            }
            for (Map.Entry<Integer, Integer> entry : overlap.entrySet()) {
                int cj = classSizes.get(entry.getKey());
                int nkj = entry.getValue();
                // wk * cj is computed in double: as an int product it overflows on large inputs.
                double ratio = 1.0 * total * nkj / ((double) wk * cj);
                mutualInformation += (1.0 * nkj / total) * (Math.log(ratio) / LOG_2);
            }
        }
        return 2 * mutualInformation / (clusterEntropy + classEntropy);
    }

    /** Returns -p*log2(p) for p = count/total, using the convention 0*log(0) = 0. */
    private static double entropyTerm(int count, int total) {
        if (count == 0) {
            return 0;
        }
        double p = 1.0 * count / total;
        return -(p * (Math.log(p) / LOG_2));
    }

    /**
     * Loads the clusters from {@link #path} and prints their NMI with two decimals.
     *
     * @param args ignored; may be {@code null}
     */
    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        loadData(lists);
        double nmi = computeNmi(lists);
        System.out.println(String.format("nmi = %.2f", nmi));
    }
}

第二个,一些工具类:

package clusters;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Pair-counting helpers for external clustering evaluation (Rand index, precision, recall and
 * F-measure).
 *
 * <p>Notation used below: {@code n_kj} is the number of points of class {@code j} in cluster
 * {@code k}. All counts are returned as {@code int}; results that do not fit raise
 * {@link ArithmeticException} instead of silently wrapping.
 *
 * <p>Created: 16-6-18, 11:07 AM.
 */
public class ClusterUtils {
    /**
     * Computes the binomial coefficient C(n, k).
     *
     * @param n size of the set
     * @param k number of elements chosen
     * @return C(n, k), or 0 when {@code k < 0} or {@code k > n}
     * @throws ArithmeticException if the result does not fit in an {@code int}
     */
    public static int combination(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        // C(n, k) == C(n, n - k); the smaller index keeps the loop short.
        int m = Math.min(k, n - k);
        long result = 1;
        for (int i = 1; i <= m; i++) {
            // result holds C(n - m + i - 1, i - 1); the division below is always exact.
            result = result * (n - m + i) / i;
            // Partial results never decrease, so exceeding int range here means the final
            // value would as well.
            if (result > Integer.MAX_VALUE) {
                throw new ArithmeticException("C(" + n + ", " + k + ") overflows int");
            }
        }
        return (int) result;
    }

    /**
     * Counts the pairs of points that share a cluster (TP + FP).
     *
     * @param clusters size of each cluster
     * @return sum of C(size, 2) over all clusters
     */
    public static int computeTPAndFP(int[] clusters) {
        int result = 0;
        for (int cluster : clusters) {
            result = Math.addExact(result, combination(cluster, 2));
        }
        return result;
    }

    /**
     * Counts the true positives: pairs of points that share both cluster and class.
     *
     * @param mapList one map per cluster, from class label to {@code n_kj}
     * @return sum of C(n_kj, 2) over all clusters and classes
     */
    public static int computeTP(List<Map<Integer, Integer>> mapList) {
        int tp = 0;
        for (Map<Integer, Integer> map : mapList) {
            for (int count : map.values()) {
                tp = Math.addExact(tp, combination(count, 2));
            }
        }
        return tp;
    }

    /**
     * Counts the false positives: pairs of points that share a cluster but not a class.
     *
     * <p>Earlier versions of this method returned the sum of C(n_kj, 2), which is the
     * true-positive count. Callers that derive TP as
     * {@code computeTPAndFP(...) - computeFP(...)} now get the correct split.
     *
     * @param mapList one map per cluster, from class label to {@code n_kj}
     * @return for each cluster, C(cluster size, 2) minus its same-class pairs, summed
     */
    public static int computeFP(List<Map<Integer, Integer>> mapList) {
        int fp = 0;
        for (Map<Integer, Integer> map : mapList) {
            int clusterSize = 0;
            int sameClassPairs = 0;
            for (int count : map.values()) {
                clusterSize = Math.addExact(clusterSize, count);
                sameClassPairs = Math.addExact(sameClassPairs, combination(count, 2));
            }
            fp = Math.addExact(fp, combination(clusterSize, 2) - sameClassPairs);
        }
        return fp;
    }

    /**
     * Counts, for one class, the pairs of its points that were put into different clusters.
     *
     * @param list the number of points of this class in each cluster that contains it
     * @return the sum of {@code list[i] * list[j]} over all {@code i < j}
     */
    public static int computeOneClass(List<Integer> list) {
        long seen = 0;
        long result = 0;
        for (int count : list) {
            // Each new count pairs with everything seen so far; same sum as the double loop.
            result = Math.addExact(result, Math.multiplyExact(seen, (long) count));
            seen += count;
        }
        return Math.toIntExact(result);
    }

    /**
     * Counts the false negatives: pairs of points that share a class but not a cluster.
     *
     * @param lists one list per class, as accepted by {@link #computeOneClass(List)}
     * @return the sum of {@link #computeOneClass(List)} over all classes
     */
    public static int computeFN(List<List<Integer>> lists) {
        int result = 0;
        for (List<Integer> list : lists) {
            result = Math.addExact(result, computeOneClass(list));
        }
        return result;
    }

    /**
     * Computes the F-measure (beta^2 + 1) * P * R / (beta^2 * P + R).
     *
     * @param P precision
     * @param R recall
     * @param beta weight of recall relative to precision
     * @return the F-measure, or 0 when the denominator is zero (P and R both zero)
     */
    public static double computeFValue(double P, double R, double beta) {
        double denominator = beta * beta * P + R;
        if (denominator == 0) {
            return 0.0;
        }
        return (beta * beta + 1) * P * R / denominator;
    }

    /** Small demo: prints the cross-cluster pair count for the per-cluster counts 1, 4, 0. */
    public static void main(String[] args) {
        List<Integer> list = Arrays.asList(1, 4, 0);
        System.out.println("computeOneClass(list) = " + computeOneClass(list));
    }
}

第三个,计算RI、P、R、F以及Purity的,顺便调用了NMI,一起打印输出,beta取1和5,如stanford文章所述,计算F1和F5,上代码:

package clusters;

import java.util.*;

/**
 * Prints purity, NMI, Rand index, precision, recall and the F-measure (beta = 1 and beta = 5)
 * for the clustering stored in the file read by {@code NormalizedMutualInformation.loadData}.
 *
 * <p>Created: 16-6-18, 11:05 AM.
 */
public class RandIndex {
    /**
     * Loads the clusters, counts the four kinds of point pairs and prints every metric.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        List<List<Integer>> lists = new ArrayList<>();
        NormalizedMutualInformation.loadData(lists);
        int K = lists.size();
        int N = 0;
        int[] clusters = new int[K];
        for (int i = 0; i < K; i++) {
            clusters[i] = lists.get(i).size();
            N += clusters[i];
        }
        // Pairs that share a cluster, regardless of class.
        int TPAndFP = ClusterUtils.computeTPAndFP(clusters);

        // Per cluster: class label -> number of points of that class in the cluster (n_kj).
        List<Map<Integer, Integer>> mapList = new ArrayList<>();
        Set<Integer> classLabels = new HashSet<>();
        for (List<Integer> list : lists) {
            Map<Integer, Integer> map = new HashMap<>();
            for (Integer integer : list) {
                map.put(integer, map.getOrDefault(integer, 0) + 1);
            }
            mapList.add(map);
            classLabels.addAll(map.keySet());
        }

        // TP: pairs that share both cluster and class, i.e. the sum of C(n_kj, 2).
        // FP is whatever remains of the same-cluster pairs.
        int TP = 0;
        for (Map<Integer, Integer> map : mapList) {
            for (int count : map.values()) {
                TP += ClusterUtils.combination(count, 2);
            }
        }
        int FP = TPAndFP - TP;

        // Per class: the counts of that class in each cluster containing it. Pairs drawn
        // from two different clusters are the false negatives.
        List<List<Integer>> lists1 = new ArrayList<>();
        for (Integer label : classLabels) {
            List<Integer> list = new ArrayList<>();
            for (Map<Integer, Integer> map : mapList) {
                if (map.containsKey(label)) {
                    list.add(map.get(label));
                }
            }
            lists1.add(list);
        }
        int FN = ClusterUtils.computeFN(lists1);
        int TN = ClusterUtils.combination(N, 2) - TPAndFP - FN;

        // Sums are widened to long so that adding the four counts cannot wrap.
        double RI = 1.0 * ((long) TP + TN) / ((long) TP + FP + FN + TN);

        // Purity: each cluster is credited with its most frequent class.
        int totalMax = 0;
        for (Map<Integer, Integer> map : mapList) {
            // orElse(0): an empty cluster contributes nothing instead of throwing.
            totalMax += map.values().stream().mapToInt(Integer::intValue).max().orElse(0);
        }
        double purity = 1.0 * totalMax / N;
        System.out.println(String.format("purity = %.2f", purity));

        // NMI is printed by its own entry point, which re-reads the input file.
        NormalizedMutualInformation.main(null);
        System.out.println(String.format("RI = %.2f", RI));

        // Precision, recall and F-measure for beta = 1 and beta = 5.
        double P = 1.0 * TP / (TP + FP);
        double R = 1.0 * TP / (TP + FN);
        double beta = 1;
        System.out.println(String.format("P = %.2f", P));
        System.out.printf("R = %.3f\n", R);
        System.out.println(String.format("beta = 1, F = %.2f", ClusterUtils.computeFValue(P, R, beta)));
        beta = 5;
        System.out.println(String.format("beta = 5, F = %.3f", ClusterUtils.computeFValue(P, R, beta)));
    }
}
输入数据就是stanford文中的3个类簇(每行是一个类簇,行内每个整数是该簇中一个样本的真实类别标签,以空白字符分隔):

1 1 1 1 1 2
1 2 2 2 2 3
1 1 3 3 3

顺便贴个运行图:

(图片:程序运行结果截图)
OK,完毕!


你可能感兴趣的:(clustering)