spark checkpoint 原理

1. 在 rdd 上调用 checkpoint() 方法,并没有立刻执行

只是在 rdd 上创建了一个 ReliableRDDCheckpointData 对象, 该对象包含 checkpoint 进度的 CheckpointState 枚举标记 , 初始化为 Initialized

// RDD.scala
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None

def checkpoint(): Unit = RDDCheckpointData.synchronized {
    if (context.checkpointDir.isEmpty) {
      throw new SparkException("Checkpoint directory has not been set in the SparkContext")
    } else if (checkpointData.isEmpty) {
      checkpointData = Some(new ReliableRDDCheckpointData(this))
    }
}

用户调用的 checkpoint() 方法做了两件事:
(1)如果没设置 checkpoint 目录,执行该方法直接报错
(2)设置了 checkpoint 目录,则 new 了一个包含该 rdd 引用的 ReliableRDDCheckpointData 对象,赋值给 RDD 的 checkpointData 属性 。该对象是否为空表示 rdd 是否需要 checkpoint;其 cpState 记录了checkpoint 的进度, 初始化为 Initialized

// 状态枚举类
private[spark] object CheckpointState extends Enumeration {
  type CheckpointState = Value
  val Initialized, CheckpointingInProgress, Checkpointed = Value
}

// ReliableRDDCheckpointData.scala
class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
     extends RDDCheckpointData[T](rdd) 

// RDDCheckpointData.scala
abstract class RDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
  // checkpoint 执行状态的枚举
  import CheckpointState._

  // checkpoint 状态初始化为 Initialized
  protected var cpState = Initialized

  // The RDD that contains our checkpointed data
  private var cpRDD: Option[CheckpointRDD[T]] = None
  ... ...
}

2. 何时执行真正的 checkPoint 动作?

结论:在碰到一个 action 算子时, 才会执行真正的checkPoint 动作。
首先,任何一个 action 算子, 都会执行 SparkContext 的 runJob() 方法

// RDD.scala
def foreach(f: T => Unit): Unit = withScope {
    val cleanF = sc.clean(f)
    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

runJob() 会先执行 rdd 的 dag 计算,计算完毕后再开启一个任务执行 checkpoin 操作。

def runJob[T, U: ClassTag](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      resultHandler: (Int, U) => Unit): Unit = {
    ...
    dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
    ...
    rdd.doCheckpoint()
  }

checkpoint 的具体流程写在 RDDCheckpointData 类的 checkpoint() 方法里

// RDDCheckpointData.scala
final def checkpoint(): Unit = {
    // 1. 判断状态是 Initialized 的话,才执行操作,更改状态为 CheckpointingInProgress
    RDDCheckpointData.synchronized {
      if (cpState == Initialized) {
        cpState = CheckpointingInProgress
      } else {
        return
      }
    }
  
   // 2. 写在子类 ReliableRDDCheckpointData 真正的实现
    val newRDD = doCheckpoint()

    // 3. 更新 rdd 的状态并清空 rdd 的血缘
    RDDCheckpointData.synchronized {
      cpRDD = Some(newRDD)
      cpState = Checkpointed
      // 该方法将 rdd 的 dependencies_ 属性置为 null
      //  private var dependencies_ : Seq[Dependency[_]]
      rdd.markCheckpointed()
    }
}

// 2. 写在子类 ReliableRDDCheckpointData.scala 真正的实现
protected override def doCheckpoint(): CheckpointRDD[T] = {
    val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
    ... ...
    newRDD
}

writeRDDToCheckpointDirectory() 方法

def writeRDDToCheckpointDirectory[T: ClassTag](
      originalRDD: RDD[T],
      checkpointDir: String,
      blockSize: Int = -1): ReliableCheckpointRDD[T] = {

    val sc = originalRDD.sparkContext

    // Create the output path for the checkpoint
    val checkpointDirPath = new Path(checkpointDir)
    val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
    ... ... 
    // 出现了 runJob, 可见是新启动一个任务, 按照 rdd 的 dependencies 血缘关系再次计算后写入
    sc.runJob(originalRDD,
      writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
    ... ...
    newRDD
}

为了解决 checkpoint 会导致重新根据血缘计算的问题,要先对 rdd 进行 cache(),这样即使新开任务,也不会重新计算 RDD,而是直接从缓存中取。这个直接取缓存的过程在 RDD.iterator() 方法.
我们知道,每个任务都会最终转化为 ShuffleMapTask 或 ResultTask , 这些 task 会调用获取 RDD.iterator() 获取每个 partition 的数据。

// RDD.scala
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      // 有缓存时直接获取
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
}

private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
    // 委托 BlockManager 中读取缓存
    val blockId = RDDBlockId(id, partition.index)
    var readCachedBlock = true
    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
       ... ...
    }) match {
      ... ...
    }
}

3. 总结带 checkpoint 的写法

rdd.persist     // cache
rdd.checkpoint()
rdd.foreach()  // action算子出发所有计算
rdd.unpersist()

你可能感兴趣的:(spark checkpoint 原理)