Scala Source Code Walkthrough

Today we start with the simplest clustering model in org.apache.spark.mllib.clustering: KMeans, defined in the file KMeans.scala.

KMeans is one of the simpler clustering algorithms, and MLlib's implementation of it is correspondingly straightforward.

Definition of the KMeans class
class KMeans private (
    private var k: Int,                        // number of clusters
    private var maxIterations: Int,            // maximum number of iterations
    private var initializationMode: String,    // algorithm used to initialize the cluster centers
    private var initializationSteps: Int,      // defaults to 2; per the source comments this rarely needs tuning
    private var epsilon: Double,               // distance threshold for deciding that the centers have converged
    private var seed: Long,                    // random seed
    private var distanceMeasure: String)       // distance measure used between points
  extends Serializable with Logging {

  // The initial cluster centers can also be supplied by hand
  private var initialModel: Option[KMeansModel] = None

  def setInitialModel(model: KMeansModel): this.type = {
    require(model.k == k, "mismatched cluster count")
    initialModel = Some(model)
    this
  }

  // ...
}
    
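Although this constructor is private, MLlib exposes a public no-arg constructor together with fluent setters. A minimal construction sketch (the setter names below are the actual public API; the parameter values are arbitrary):

import org.apache.spark.mllib.clustering.KMeans

val kmeans = new KMeans()              // public no-arg constructor with MLlib's defaults
  .setK(3)                             // number of clusters
  .setMaxIterations(50)
  .setInitializationMode("k-means||")  // or "random"
  .setSeed(42L)
  .setEpsilon(1e-4)
// kmeans.run(trainingData) would then fit the model on an RDD[Vector]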

The algorithm for initializing the cluster centers is either "random" or "k-means||" (a parallel variant of k-means++, which is where initializationSteps comes into play).
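For the "random" mode, initialization simply samples k points from the dataset. This is roughly what initRandom looks like in the Spark source (quoted from memory, so treat it as a sketch; XORShiftRandom is Spark's internal random number generator):

private def initRandom(data: RDD[VectorWithNorm]): Array[VectorWithNorm] = {
  // Sample k points without replacement; deduplicate, since the sample
  // may contain identical vectors, then re-wrap them with their norms
  data.takeSample(false, k, new XORShiftRandom(this.seed).nextInt())
    .map(_.vector).distinct.map(new VectorWithNorm(_))
}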

The code that drives the KMeans model to convergence is as follows:

private[spark] def run(
      data: RDD[Vector],
      instr: Option[Instrumentation]): KMeansModel = {

    // KMeans iterates over the data many times, so the input should be cached
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data is not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Compute the L2 norm (not the squared norm) of every vector and cache it;
    // the norms are reused later when computing distances to the centers
    val norms = data.map(Vectors.norm(_, 2.0))
    norms.persist()                        // persist() with no arguments stores the norms in memory
    // Pair each data point with its norm
    val zippedData = data.zip(norms).map { case (v, norm) =>
      new VectorWithNorm(v, norm)
    }
    val model = runAlgorithm(zippedData, instr)
    norms.unpersist()

    // Warn at the end of the run as well, for increased visibility.
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data was not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }
    model
  }
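Precomputing the L2 norms pays off when searching for the closest center: in the Euclidean case the squared distance expands as ||a - b||^2 = ||a||^2 + ||b||^2 - 2(a·b), so once the norms are cached, each point-to-center distance essentially costs one dot product.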

A note on RDD persistence in Spark: when a dataset is used multiple times, persisting it avoids recomputation and extra I/O, saving overall compute time. For details see https://blog.csdn.net/asd136912/article/details/80885136

There are two persistence methods, cache() and persist(). The difference is that cache() always stores to memory (it is shorthand for persist(StorageLevel.MEMORY_ONLY)), while persist() lets you specify the storage level.
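A small illustration of the difference (rdd1 and rdd2 are placeholders for any existing RDDs):

import org.apache.spark.storage.StorageLevel

rdd1.cache()                                  // equivalent to rdd1.persist(StorageLevel.MEMORY_ONLY)
rdd2.persist(StorageLevel.MEMORY_AND_DISK)    // spill partitions to disk when memory runs out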

The core of KMeans

private def runAlgorithm(
      data: RDD[VectorWithNorm],
      instr: Option[Instrumentation]): KMeansModel = {

    val sc = data.sparkContext

    val initStartTime = System.nanoTime()

    // Distance measure used to compute point-to-point distances
    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
    // Generate the initial cluster centers
    val centers = initialModel match {
      case Some(kMeansCenters) =>
        kMeansCenters.clusterCenters.map(new VectorWithNorm(_))  // centers supplied by the user
      case None =>
        if (initializationMode == KMeans.RANDOM) {
          initRandom(data)                                       // pick k random points as centers
        } else {
          initKMeansParallel(data, distanceMeasureInstance)      // generate centers with k-means||
        }
    }
    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

    var converged = false
    var cost = 0.0
    var iteration = 0

    val iterationStartTime = System.nanoTime()

    instr.foreach(_.logNumFeatures(centers.head.vector.size))

    // Execute iterations of Lloyd's algorithm until converged
    while (iteration < maxIterations && !converged) {
      val costAccum = sc.doubleAccumulator   // accumulator that collects the total cost of this iteration
      val bcCenters = sc.broadcast(centers)  // broadcast the current centers to the executors

      // Find the new centers
      val collected = data.mapPartitions { points =>
        val thisCenters = bcCenters.value
        val dims = thisCenters.head.vector.size

        // One zero vector per center: sums(j) accumulates the element-wise
        // sum of all points assigned to center j within this partition
        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
        val counts = Array.fill(thisCenters.length)(0L)  // per-center point counts

        points.foreach { point =>
          // Find the closest center to this point and the cost (distance) to it
          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
          costAccum.add(cost)
          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
          counts(bestCenter) += 1  // count the point towards its cluster
        }
        // Emit (center index, (sum of assigned points, count)) for non-empty clusters only
        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
        axpy(1.0, sum2, sum1)
        (sum1, count1 + count2)
      }.collectAsMap()
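      // collected maps each non-empty cluster index j to the element-wise
      // sum of the points assigned to j and the number of such points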

      if (iteration == 0) {
        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
      }

      val newCenters = collected.mapValues { case (sum, count) =>
        distanceMeasureInstance.centroid(sum, count)
      }

      bcCenters.destroy()  // the centers have been recomputed, so the old broadcast is no longer needed

      // Update the cluster centers and costs
      converged = true
      newCenters.foreach { case (j, newCenter) =>
        // If any new center moved farther than epsilon from its previous
        // position, the algorithm has not converged yet
        if (converged &&
            !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
          converged = false
        }
        centers(j) = newCenter
      }

      cost = costAccum.value
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    logInfo(s"The cost is $cost.")

    new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
  }
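Finally, an end-to-end usage sketch through the public API (assumes an existing SparkContext sc; the data points are made up for illustration):

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Two well-separated groups of points
val points = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
)).cache()                            // cache the input, since KMeans makes many passes over it

val model = KMeans.train(points, k = 2, maxIterations = 20)
model.clusterCenters.foreach(println) // the two learned centers
println(model.computeCost(points))    // sum of squared distances to the closest centers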

 
