Spark源码分析 – Checkpoint

CP的步骤

1. 首先如果RDD需要CP, 调用RDD.checkpoint()来mark
注释里说明了, 需要在该RDD上的Job被执行之前完成mark(原因见后文), 并且最好先persist这个RDD, 否则在写CP文件时需要重新compute RDD的内容
并且当RDD被CP后, 所有dependencies都会被清除, 因为既然RDD已经被CP, 那么就可以直接从文件读取, 没有必要保留之前的parents的dependencies(保留只是为了replay)

2. 在SparkContext.runJob中, 最后会调用rdd.doCheckpoint()
如果前面已经mark过, 那么这里就会将rdd真正CP到文件中去, 这也是前面为什么说, mark必须在run job之前完成

  /**
   * Runs a job over `rdd` by delegating to the DAG scheduler, then gives the RDD a chance to
   * checkpoint itself.
   *
   * @param rdd           the RDD the job runs on; `doCheckpoint()` is invoked on it afterwards
   * @param func          function applied to a task context and a partition's iterator
   * @param partitions    forwarded to the scheduler (presumably the partition indices to run — confirm)
   * @param allowLocal    forwarded to the scheduler unchanged
   * @param resultHandler forwarded to the scheduler; receives an Int and a value of type U
   */
  def runJob[T, U: ClassManifest](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit) {
    // NOTE(review): callSite is not defined in this excerpt; presumably it is computed earlier
    // in the full method — confirm against the complete source.
    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
      localProperties.get)
    // Runs only after the job has finished. RDDCheckpointData.doCheckpoint acts only on an RDD
    // already in the MarkedForCheckpoint state, so checkpoint() must have been called beforehand.
    rdd.doCheckpoint()
    // Procedure syntax (no `=` before the body) makes this method return Unit, so this
    // trailing expression is evaluated and discarded.
    result
  }

3. 在RDDCheckpointData.doCheckpoint中

会调用rdd.markCheckpointed(newRDD), 清除dependencies信息
并最终将状态设为Checkpointed, 以表示完成CP

4. Checkpoint如何使用, 在RDD.computeOrReadCheckpoint中, 看到如果已经完成CP, 会直接从firstParent中读数据, 刚开始会觉得比较奇怪

    if (isCheckpointed) {
      firstParent[T].iterator(split, context)
    }

RDD.firstParent的定义如下, 就是从dependencies中取第一个dependency的rdd

 dependencies.head.rdd.asInstanceOf[RDD[U]]

而RDD.dependencies的定义如下, 可以看到在完成CP的情况下, 从dependencies中读到的其实就是CP RDD, 所以可以直接用

 checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {dependencies_}

RDD

  /** An Option holding our checkpoint RDD, if we are checkpointed */
  private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
  // Checkpoint bookkeeping for this RDD. Starts as None; checkpoint() populates it with an
  // RDDCheckpointData instance, which is then consulted by doCheckpoint() and checkpointRDD.
  private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None

  // Set on the first doCheckpoint() call so that every later call on this RDD is a no-op;
  // this keeps the recursive walk over parent RDDs from handling the same RDD repeatedly.
  private var doCheckpointCalled = false
 
  /**
   * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
   * directory set with SparkContext.setCheckpointDir() and all references to its parent
   * RDDs will be removed. This function must be called before any job has been
   * executed on this RDD. It is strongly recommended that this RDD is persisted in
   * memory, otherwise saving it on a file will require recomputation.
   */
  def checkpoint() {
    // Refuse to mark anything when the SparkContext has no checkpoint directory configured.
    if (context.checkpointDir.isEmpty) {
      throw new Exception("Checkpoint directory has not been set in the SparkContext")
    }
    // Only the first call has an effect; later calls find checkpointData already populated.
    if (checkpointData.isEmpty) {
      // Create the holder for all checkpoint-related state of this RDD, then move it into
      // the "marked" state so that a later doCheckpoint() will act on it.
      val data = new RDDCheckpointData(this)
      checkpointData = Some(data)
      data.markForCheckpoint()
    }
  }
  /**
   * Performs the checkpointing of this RDD by saving it to a file. It is called from
   * SparkContext.runJob after a job using this RDD has completed (therefore the RDD has been
   * materialized and potentially stored in memory). doCheckpoint() is called recursively on
   * the parent RDDs.
   */
  private[spark] def doCheckpoint() {
    // The flag turns every call after the first into a no-op, which bounds the recursion below.
    if (!doCheckpointCalled) {
      doCheckpointCalled = true
      checkpointData match {
        // Checkpoint state exists for this RDD: delegate to RDDCheckpointData.doCheckpoint.
        case Some(data) => data.doCheckpoint()
        // No checkpoint state here: give each parent RDD the chance to checkpoint instead.
        case None => dependencies.foreach(dep => dep.rdd.doCheckpoint())
      }
    }
  }
  /**
   * Changes the dependencies of this RDD from its original parents to a new RDD (`checkpointRDD`)
   * created from the checkpoint file, and forgets its old dependencies and partitions.
   */
  private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
    // NOTE(review): the checkpointRDD argument is never read in this body; the new parent is
    // presumably resolved through checkpointData when dependencies are next requested — confirm.
    clearDependencies()  // per the original annotation this resets dependencies_ to null — confirm in clearDependencies
    partitions_ = null  // drop the stored partitions
    deps = null    // Forget the constructor argument for dependencies too
  }
 
  /**
   * Compute an RDD partition or read it from a checkpoint if the RDD has been checkpointed.
   */
  private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
    // Without a completed checkpoint the partition is computed as usual. Otherwise it is read
    // through the first parent, which is expected to be the RDD reloaded from the checkpoint file.
    if (!isCheckpointed) compute(split, context)
    else firstParent[T].iterator(split, context)
  }

 

RDDCheckpointData

 

/**
 * Enumeration to manage state transitions of an RDD through checkpointing
 * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
 */
private[spark] object CheckpointState extends Enumeration {
  type CheckpointState = Value

  // The states an RDD passes through while being checkpointed, declared in lifecycle order.
  // Each Value call takes the next id, so the declaration order must not change.
  val Initialized = Value
  val MarkedForCheckpoint = Value
  val CheckpointingInProgress = Value
  val Checkpointed = Value
}

/**
 * This class contains all the information related to RDD checkpointing. Each instance of this class
 * is associated with an RDD. It manages the process of checkpointing the associated RDD, and also
 * manages the post-checkpoint state by providing the updated partitions, iterator and preferred
 * locations of the checkpointed RDD.
 */
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
  extends Logging with Serializable {

  import CheckpointState._

  // The checkpoint state of the associated RDD; every instance starts out as Initialized.
  var cpState = Initialized

  // The file to which the associated RDD has been checkpointed to
  @transient var cpFile: Option[String] = None

  // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
  var cpRDD: Option[RDD[T]] = None

  // Mark the RDD for checkpointing. Only an Initialized instance moves to MarkedForCheckpoint,
  // so calling this again later has no effect.
  def markForCheckpoint() {
    // The state change happens under the companion-object lock so that concurrent callers
    // cannot both act on the same state.
    RDDCheckpointData.synchronized {
      if (cpState == Initialized) cpState = MarkedForCheckpoint
    }
  }

  // Is the RDD already checkpointed
  def isCheckpointed: Boolean = {
    RDDCheckpointData.synchronized { cpState == Checkpointed }
  }

  // Get the file to which this RDD was checkpointed to as an Option
  def getCheckpointFile: Option[String] = {
    RDDCheckpointData.synchronized { cpFile }
  }

  /**
   * Do the checkpointing of the RDD. Called after the first job using that RDD is over.
   *
   * Proceeds only when the state is MarkedForCheckpoint; in any other state it returns
   * without doing anything. On success the RDD is written under
   * `<checkpointDir>/rdd-<id>`, reloaded as a CheckpointRDD, and the state becomes Checkpointed.
   *
   * @throws SparkException if the checkpoint directory cannot be created
   */
  def doCheckpoint() {
    // If it is marked for checkpointing AND checkpointing is not already in progress,
    // then set it to be in progress, else return
    RDDCheckpointData.synchronized {
      if (cpState == MarkedForCheckpoint) {
        cpState = CheckpointingInProgress
      } else {
        return
      }
    }

    // Create the output path for the checkpoint
    val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
    // Resolve the file system with the SparkContext's Hadoop configuration rather than a
    // fresh default one, so the checkpoint is written with the same settings that
    // CheckpointRDD uses when it reads the directory back.
    val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
    if (!fs.mkdirs(path)) {
      throw new SparkException("Failed to create checkpoint path " + path)
    }

    // Save to file, and reload it as an RDD
    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
    val newRDD = new CheckpointRDD[T](rdd.context, path.toString)

    // Change the dependencies and partitions of the RDD
    RDDCheckpointData.synchronized {
      cpFile = Some(path.toString)
      cpRDD = Some(newRDD)
      rdd.markCheckpointed(newRDD)   // Drop the RDD's old dependencies and partitions
      cpState = Checkpointed         // Checkpointing is complete from here on
      RDDCheckpointData.clearTaskCaches()
      logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
    }
  }

  // Get preferred location of a split after checkpointing.
  // Only valid once cpRDD is set; before that, cpRDD.get throws NoSuchElementException.
  def getPreferredLocations(split: Partition): Seq[String] = {
    RDDCheckpointData.synchronized {
      cpRDD.get.preferredLocations(split)
    }
  }

  // Get the partitions of the checkpoint RDD.
  // Only valid once cpRDD is set; before that, cpRDD.get throws NoSuchElementException.
  def getPartitions: Array[Partition] = {
    RDDCheckpointData.synchronized {
      cpRDD.get.partitions
    }
  }

  // The RDD reloaded from the checkpoint file, or None if checkpointing has not completed.
  def checkpointRDD: Option[RDD[T]] = {
    RDDCheckpointData.synchronized {
      cpRDD
    }
  }
}

 

CheckpointRDD

 

// Partition of a CheckpointRDD; it carries nothing but its index.
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition

/**
 * This RDD represents a RDD checkpoint file (similar to HadoopRDD).
 */
private[spark]
class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
  extends RDD[T](sc, Nil) {

  // File system of the checkpoint directory, resolved with the SparkContext's Hadoop
  // configuration. Marked @transient, so it is not part of the serialized form.
  @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)

  /**
   * Builds one partition per "part-" file found in the checkpoint directory; the number of
   * such files is the number of partitions. A missing directory yields zero partitions.
   *
   * @throws SparkException if the first or last file (in sorted order) does not carry the
   *                        name expected for its position
   */
  override def getPartitions: Array[Partition] = {
    val cpath = new Path(checkpointPath)
    val numPartitions =
    // listStatus can throw exception if path does not exist.
    if (fs.exists(cpath)) {
      val dirContents = fs.listStatus(cpath)
      // Match on the file name only. Testing the full path string would accept every entry
      // whenever the checkpoint directory's own path happens to contain "part-".
      val partitionFiles = dirContents.map(_.getPath)
        .filter(_.getName.startsWith("part-"))
        .map(_.toString)
        .sorted
      val numPart =  partitionFiles.size
      // Sanity check: the first and last sorted files must be the ones expected for
      // partition 0 and partition numPart-1.
      if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
          ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
        throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
      }
      numPart
    } else 0

    Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
  }

  // Give this RDD its own checkpoint data that already points at the checkpoint file.
  checkpointData = Some(new RDDCheckpointData[T](this))
  checkpointData.get.cpFile = Some(checkpointPath)

  /**
   * Hosts holding the first block of the partition's checkpoint file, excluding "localhost".
   */
  override def getPreferredLocations(split: Partition): Seq[String] = {
    val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
  }

  /**
   * Computing a partition simply means reading its contents back from the checkpoint file.
   */
  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
    CheckpointRDD.readFromFile(file, context)
  }

  override def checkpoint() {
    // Do nothing. CheckpointRDD should not be checkpointed.
  }
}

你可能感兴趣的:(spark)