[Scala] NDCG 的 Scala 实现

一、关于 NDCG

[LTR] 信息检索评价指标(RP/MAP/DCG/NDCG/RR/ERR)

二、代码实现

1、训练数据的加载解析

import scala.io.Source

/*
* 训练行数据
* */
case class TrainDataRow(target: Int, qid: Int, features: Array[Double])

object TrainDataRow {
  // 加载文件数据
  // 格式:
  //  .=.  qid: : : ... : # 
  //  .=. 
  //  .=. 
  //  .=. 
  //  .=. 
  //  .=. 
  def loadFile(file: String): List[TrainDataRow] = {
    Source.fromFile(file).getLines.toList.par.map(x => {
      val strArray = x.split(' ')
      val label = strArray(0).toInt
      val qid = strArray(1).split(':')(1).toInt
      val fValArray = strArray.drop(2).map(x => x.split(':')(1).toDouble)
      new TrainDataRow(label, qid, fValArray)
    }).toList
  }
}

2、NDCG 的实现

object NDCG {
  /*
  * 计算 NDCG 分值
  * */
  def score(rows: List[TrainDataRow], k: Int): Double = {
    val size = k.min(rows.length - 1)
    // 理想 DCG
    var idealDcg: Double = 0
    val sortedList = rows.sortWith((x, y) => x.target > y.target)
    for (i <- 0 to size) {
      // 计算累计效益
      val gain = (1 << sortedList(i).target) - 1
      // 计算折扣因子
      val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
      idealDcg += gain * discount
    }
    if (idealDcg > 0) {
      var dcg: Double = 0
      for (i <- 0 to size) {
        // 计算累计效益
        val gain = (1 << rows(i).target) - 1
        // 计算折扣因子
        val discount = 1.0 / (Math.log(i + 2) / Math.log(2))
        dcg += gain * discount
      }
      dcg / idealDcg
    }
    else 0
  }
}

3、训练数据集的 NDCG 计算

def calcNDCG(trainDataFile: String, k: Int): Double = {
  println("开始计算...")
  val start = System.nanoTime()
  val data = TrainDataRow.loadFile(trainDataFile) // 加载训练数据文件
  println("数据量:" + data.length + ",用时:" + (System.nanoTime() - start) / 1000000 + " ms")
  val grpData: Map[Int, List[TrainDataRow]] = data.groupBy(_.qid) // 根据 qid 分组
  val resultNDCG = grpData.map(x => NDCG.score(x._2, k)).sum / grpData.size
  println(s"NDCG@$k: $resultNDCG")
  val end = System.nanoTime()
  println("计算运行时间:" + (end - start) / 1000000 + " ms")
  resultNDCG
}

 

by. Memento

转载于:https://www.cnblogs.com/memento/p/8675800.html

你可能感兴趣的:(scala)