聚类算法评价指标:Purity、NMI、MI、entropy、precision、recall、F值、RI值

不得不感叹一下,Scala 真的是一门很神奇的语言,太强大了:只要好好思考,写出来的代码就可以非常漂亮,瞬间感觉 Java 太 low 了。

package Utils

import com.google.common.math.{DoubleMath, IntMath}

/**
  * Created by fhqplzj on 16-8-24 at 下午2:12.
  */
object Evaluation {
  /** Natural logarithm of 2, used to convert natural logarithms to base 2. */
  private val Ln2: Double = math.log(2.0)

  /** Base-2 logarithm of `x`. */
  private def log2(x: Double): Double = math.log(x) / Ln2

  /**
    * Validates a pair of label arrays.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @throws IllegalArgumentException if the arrays differ in length or hold fewer than 2 labels
    */
  private def labelChecker(labelsTrue: Array[Int], labelsPred: Array[Int]): Unit = {
    // Two separate checks so the message names the condition that actually failed.
    require(labelsTrue.length == labelsPred.length, "The length must be equal!")
    require(labelsTrue.length >= 2, "The size of labels must be greater than 1!")
  }

  /** Number of occurrences of each distinct label. */
  private def labelCounts(labels: Array[Int]): Map[Int, Int] =
    labels.groupBy(identity).map { case (label, group) => label -> group.length }

  /** Number of points for each distinct (true class, predicted cluster) combination. */
  private def jointCounts(labelsTrue: Array[Int], labelsPred: Array[Int]): Map[(Int, Int), Int] =
    labelsTrue.zip(labelsPred).groupBy(identity).map { case (pair, group) => pair -> group.length }

  /**
    * Purity: every predicted cluster is credited with the size of its most frequent true
    * class, and the credits are summed and divided by the number of points.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return purity in (0, 1]
    */
  def purity(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    // Group by the PREDICTED cluster: the majority class is chosen per cluster, not per class.
    val majoritySizes = jointCounts(labelsTrue, labelsPred)
      .groupBy { case ((_, cluster), _) => cluster }
      .values
      .map(_.values.max)
    majoritySizes.sum.toDouble / labelsTrue.length
  }

  /**
    * Mutual information, in bits, between the two labelings.
    * Inputs are expected to have been validated by the calling public method.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    */
  private def mutualInformation(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    val n = labelsTrue.length.toDouble
    val classSizes = labelCounts(labelsTrue)
    val clusterSizes = labelCounts(labelsPred)
    jointCounts(labelsTrue, labelsPred).iterator.map {
      case ((cls, cluster), common) =>
        // The product of the two marginals is taken in Double so it cannot overflow Int.
        val marginals = classSizes(cls).toDouble * clusterSizes(cluster)
        common / n * log2(n * common / marginals)
    }.sum
  }

  /**
    * Shannon entropy, in bits, of a labeling.
    *
    * @param labels label of each point
    * @return entropy of the empirical label distribution
    */
  private def entropy(labels: Array[Int]): Double = {
    val n = labels.length.toDouble
    labelCounts(labels).values.iterator.map { size =>
      val p = size / n
      -p * log2(p)
    }.sum
  }

  /**
    * Normalized Mutual Information: `2 * MI / (H(true) + H(pred))`.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return NMI; 1.0 when both labelings consist of a single group
    */
  def normalizedMutualInformation(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    val entropySum = entropy(labelsTrue) + entropy(labelsPred)
    // Both entropies are zero only when each labeling has one group; the partitions are then
    // identical, so report a perfect match instead of 0/0 = NaN.
    if (entropySum == 0.0) 1.0
    else 2 * mutualInformation(labelsTrue, labelsPred) / entropySum
  }

  /**
    * Pair-counting confusion matrix.
    *
    * @param TP pairs in the same class and the same cluster
    * @param FP pairs in different classes but the same cluster
    * @param FN pairs in the same class but different clusters
    * @param TN pairs in different classes and different clusters
    */
  case class Table(TP: Int, FP: Int, FN: Int, TN: Int)

  /** Pair counts held as Long; the number of pairs exceeds Int range above ~65k points. */
  private case class PairCounts(tp: Long, fp: Long, fn: Long, tn: Long) {
    def total: Long = tp + fp + fn + tn
  }

  /** Number of unordered pairs among `n` items, i.e. C(n, 2); zero for n < 2. */
  private def pairs(n: Int): Long = n.toLong * (n - 1) / 2

  /** Computes the pair counts for two labelings that have already been validated. */
  private def pairCounts(labelsTrue: Array[Int], labelsPred: Array[Int]): PairCounts = {
    def pairsWithin(groupSizes: Iterable[Int]): Long = groupSizes.iterator.map(pairs).sum

    val sameCluster = pairsWithin(labelCounts(labelsPred).values) // TP + FP
    val sameClass = pairsWithin(labelCounts(labelsTrue).values) // TP + FN
    val tp = pairsWithin(jointCounts(labelsTrue, labelsPred).values)
    val allPairs = pairs(labelsTrue.length)
    PairCounts(tp, sameCluster - tp, sameClass - tp, allPairs - sameCluster - sameClass + tp)
  }

  /** Narrows a pair count to Int, failing loudly instead of silently wrapping. */
  private def toIntChecked(count: Long): Int =
    if (count.isValidInt) count.toInt
    else throw new ArithmeticException(s"Pair count $count does not fit in an Int")

  /**
    * Builds the pair-counting confusion matrix.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return the four pair counts as a [[Table]]
    * @throws ArithmeticException if a count does not fit in an Int
    */
  private def contingencyTable(labelsTrue: Array[Int], labelsPred: Array[Int]): Table = {
    labelChecker(labelsTrue, labelsPred)
    val counts = pairCounts(labelsTrue, labelsPred)
    Table(
      toIntChecked(counts.tp),
      toIntChecked(counts.fp),
      toIntChecked(counts.fn),
      toIntChecked(counts.tn)
    )
  }

  /**
    * Rand Index: fraction of point pairs on which the two labelings agree.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return (TP + TN) / (TP + FP + FN + TN)
    */
  def randIndex(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    val counts = pairCounts(labelsTrue, labelsPred)
    (counts.tp + counts.tn).toDouble / counts.total
  }

  /**
    * Pairwise precision.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return TP / (TP + FP); NaN when no two points share a predicted cluster
    */
  def precision(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    val counts = pairCounts(labelsTrue, labelsPred)
    counts.tp.toDouble / (counts.tp + counts.fp)
  }

  /**
    * Pairwise recall.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @return TP / (TP + FN); NaN when no two points share a true class
    */
  def recall(labelsTrue: Array[Int], labelsPred: Array[Int]): Double = {
    labelChecker(labelsTrue, labelsPred)
    val counts = pairCounts(labelsTrue, labelsPred)
    counts.tp.toDouble / (counts.tp + counts.fn)
  }

  /**
    * F-measure: weighted harmonic mean of pairwise precision and recall.
    *
    * @param labelsTrue ground-truth class label of each point
    * @param labelsPred predicted cluster label of each point
    * @param beta       weight of recall relative to precision; defaults to 1.0 when no
    *                   implicit Double is in scope
    * @return (beta^2 + 1) * P * R / (beta^2 * P + R); NaN when precision or recall is NaN
    */
  def FMeasure(labelsTrue: Array[Int], labelsPred: Array[Int])(implicit beta: Double = 1.0): Double = {
    labelChecker(labelsTrue, labelsPred)
    val p = precision(labelsTrue, labelsPred)
    val r = recall(labelsTrue, labelsPred)
    val betaSquared = math.pow(beta, 2)
    (betaSquared + 1) * p * r / (betaSquared * p + r)
  }

  /** Prints the confusion matrix of a 17-point example; the result is Table(20,20,24,72). */
  def main(args: Array[String]): Unit = {
    val labelTrue = Array.fill(8)(1) ++ Array.fill(5)(2) ++ Array.fill(4)(3)
    val labelPred = Array(1, 1, 1, 1, 1, 2, 3, 3, 1, 2, 2, 2, 2, 2, 3, 3, 3)
    println(contingencyTable(labelTrue, labelPred))
  }
}


你可能感兴趣的:(clustering,scala)