Spark DAGScheduler: Functionality and Source Code Analysis

In Spark, DAGScheduler's main job is to split a Job into a series of TaskSets (also called Stages) according to the RDD dependency graph, and then, taking the current cache state and data locality into account, submit each Stage to the TaskScheduler.
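
To make the splitting concrete, here is a minimal spark-shell sketch (my own illustration, not from the Spark source): reduceByKey introduces a ShuffleDependency, so the lineage below is cut into a ShuffleMapStage (textFile → flatMap → map) and a final ResultStage that consumes the shuffled output.

// Illustrative spark-shell sketch: one shuffle boundary => two stages ("README.md" is a placeholder path)
val counts = sc.textFile("README.md")   // narrow dependencies: these transformations stay in the ShuffleMapStage
  .flatMap(_.split("\\s+"))
  .map(word => (word, 1))
  .reduceByKey(_ + _)                   // ShuffleDependency here marks the stage boundary
counts.collect()                        // action: triggers the job; the ResultStage reads the shuffle output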

private[spark]
class DAGScheduler(
    private[scheduler] val sc: SparkContext,
    private[scheduler] val taskScheduler: TaskScheduler,
    listenerBus: LiveListenerBus,
    mapOutputTracker: MapOutputTrackerMaster,
    blockManagerMaster: BlockManagerMaster,
    env: SparkEnv,
    clock: Clock = new SystemClock())
  extends Logging

From the class definition we can see its collaborators: SparkContext, the entry point of a Spark application; TaskScheduler, which executes the tasks; MapOutputTrackerMaster, which tracks the map output produced during RDD computation; and BlockManagerMaster, which manages block storage.

RDD actions such as count and reduce trigger SparkContext.runJob, which ultimately ends up calling DAGScheduler.submitJob.
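
A rough sketch of that chain, runnable in spark-shell (simplified; the arrows summarize the path through the Spark 1.x code quoted in this article):

// rdd.count() -> SparkContext.runJob -> DAGScheduler.runJob -> DAGScheduler.submitJob (sketch)
val rdd = sc.parallelize(1 to 100, numSlices = 4)   // 4 partitions => 4 result tasks
rdd.count()                                         // blocks until the returned JobWaiter reports completion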

// DAGScheduler.submitJob
def submitJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit,
    properties: Properties): JobWaiter[U] = {
  // Check to make sure we are not launching a task on a partition that does not exist.
  val maxPartitions = rdd.partitions.length
  partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
    throw new IllegalArgumentException(
      "Attempting to access a non-existent partition: " + p + ". " +
        "Total number of partitions: " + maxPartitions)
  }

  val jobId = nextJobId.getAndIncrement()
  if (partitions.size == 0) {
    return new JobWaiter[U](this, jobId, 0, resultHandler)
  }

  assert(partitions.size > 0)
  val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
  val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
  // post only enqueues the event; a separate thread drains the queue and handles it
  eventProcessLoop.post(JobSubmitted(
    jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter,
    SerializationUtils.clone(properties)))
  waiter
}

JobSubmitted is a case class that extends the DAGSchedulerEvent trait; every event type that DAGScheduler can handle is wrapped as a DAGSchedulerEvent.
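
Schematically, the hierarchy follows the usual sealed-trait / case-class pattern. The sketch below is abbreviated and only lists the two events that appear in the onReceive snippet further down; it assumes the surrounding Spark scheduler package so internal types such as JobListener resolve.

// Sketch of the event hierarchy (field list taken from the onReceive match below)
import java.util.Properties
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.CallSite

private[scheduler] sealed trait DAGSchedulerEvent

private[scheduler] case class JobSubmitted(
    jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    allowLocal: Boolean,
    callSite: CallSite,
    listener: JobListener,
    properties: Properties = null)
  extends DAGSchedulerEvent

private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent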

eventProcessLoop is an instance of DAGSchedulerEventProcessLoop, a private inner class of DAGScheduler that extends EventLoop; it processes the queued events on a single thread by calling onReceive.

Notice: EventLoop.post only puts the event into the queue; the actual handling is done by the single eventThread, which drains the queue and calls EventLoop.onReceive(event) on each event it takes. Different threads can therefore post events concurrently without conflict, but there is no guarantee that an event is handled immediately.
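
The post-then-consume mechanism boils down to a few lines. The sketch below is a generic stand-in for EventLoop (my own simplification, backed by a LinkedBlockingDeque like Spark's implementation), not the actual Spark class.

// Generic single-consumer event loop (sketch): post() only enqueues; the daemon thread
// drains the queue and calls onReceive, so producers never block on event handling.
import java.util.concurrent.LinkedBlockingDeque

abstract class SimpleEventLoop[E](name: String) {
  private val eventQueue = new LinkedBlockingDeque[E]()
  @volatile private var stopped = false

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped) {
          onReceive(eventQueue.take())   // take() blocks until an event is available
        }
      } catch {
        case _: InterruptedException =>  // stop() interrupts a blocked take()
      }
    }
  }

  def start(): Unit = eventThread.start()
  def stop(): Unit = { stopped = true; eventThread.interrupt() }
  def post(event: E): Unit = eventQueue.put(event)   // safe to call from any thread

  protected def onReceive(event: E): Unit
}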

DAGSchedulerEventProcessLoop overrides the parent's onReceive method; as shown below, a JobSubmitted event is dispatched to DAGScheduler.handleJobSubmitted.

// DAGSchedulerEventProcessLoop.onReceive
override def onReceive(event: DAGSchedulerEvent): Unit = event match {
  case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
    dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
      listener, properties)

  case StageCancelled(stageId) =>
    dagScheduler.handleStageCancellation(stageId)

  ...
}
// DAGScheduler.handleJobSubmitted
private[scheduler] def handleJobSubmitted(jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    allowLocal: Boolean,
    callSite: CallSite,
    listener: JobListener,
    properties: Properties) {
  var finalStage: ResultStage = null
  try {
    // New stage creation may throw an exception if, for example, jobs are run on a
    // HadoopRDD whose underlying HDFS files have been deleted.
    finalStage = newResultStage(finalRDD, partitions.size, jobId, callSite)
  } catch {
    case e: Exception =>
      logWarning("Creating new stage failed due to exception - job: " + jobId, e)
      listener.jobFailed(e)
      return
  }
  if (finalStage != null) {
    val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
    clearCacheLocs()
    logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format(
      job.jobId, callSite.shortForm, partitions.length, allowLocal))
    logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")")
    logInfo("Parents of final stage: " + finalStage.parents)
    logInfo("Missing parents: " + getMissingParentStages(finalStage))
    val shouldRunLocally =
      localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
    val jobSubmissionTime = clock.getTimeMillis()
    if (shouldRunLocally) {
      // Compute very short actions like first() or take() with no parent stages locally.
      listenerBus.post(
        SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
      runLocally(job)
    } else {
      jobIdToActiveJob(jobId) = job
      activeJobs += job
      finalStage.resultOfJob = Some(job)
      val stageIds = jobIdToStageIds(jobId).toArray
      val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
      listenerBus.post(
        SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
      submitStage(finalStage)
    }
  }
  submitWaitingStages()
}

DAGScheduler.handleJobSubmitted contains most of the processing logic. Let me break it down step by step.

STEP 1: Call DAGScheduler.newResultStage to create the ResultStage; the number of tasks equals the number of partitions.

// DAGScheduler.newResultStage
private def newResultStage(
    rdd: RDD[_],
    numTasks: Int,
    jobId: Int,
    callSite: CallSite): ResultStage = {
  val (parentStages: List[Stage], id: Int) = getParentStagesAndId(rdd, jobId)
  val stage: ResultStage = new ResultStage(id, rdd, numTasks, parentStages, jobId, callSite)

  // stageIdToStage is a HashMap[Int, Stage]
  stageIdToStage(id) = stage
  updateJobIdStageIdMaps(jobId, stage)
  stage
}

ResultStage is a simple subclass of Stage, used to represent the final stage of a job.

A Stage is a set of tasks that share the same shuffle dependencies. It is either a shuffle map stage, whose output serves as input to another stage, or a result stage, whose results are returned directly.

private[spark] abstract class Stage(
    val id: Int,
    val rdd: RDD[_],
    val numTasks: Int,
    val parents: List[Stage],
    val jobId: Int,
    val callSite: CallSite)
  extends Logging {

  val numPartitions = rdd.partitions.size

  /** Set of jobs that this stage belongs to. */
  val jobIds = new HashSet[Int]

  var pendingTasks = new HashSet[Task[_]]

  private var nextAttemptId: Int = 0

  val name = callSite.shortForm
  val details = callSite.longForm

  // StageInfo holds the RDDInfo of every RDD belonging to this Stage
  var latestInfo: StageInfo = StageInfo.fromStage(this)

  /** Return a new attempt id, starting with 0. */
  def newAttemptId(): Int = {
    val id = nextAttemptId
    nextAttemptId += 1
    id
  }

  def attemptId: Int = nextAttemptId

  override final def hashCode(): Int = id
  override final def equals(other: Any): Boolean = other match {
    case stage: Stage => stage != null && stage.id == id
    case _ => false
  }
}

DAGScheduler.getParentStages collects the ShuffleMapStages found among the dependencies; these shuffle boundaries are what separate one stage from another.

// DAGScheduler.getParentStagesAndId
private def getParentStagesAndId(rdd: RDD[_], jobId: Int): (List[Stage], Int) = {
  val parentStages = getParentStages(rdd, jobId)
  val id = nextStageId.getAndIncrement()
  (parentStages, id)
}

// DAGScheduler.getParentStages
private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
  val parents = new HashSet[Stage]
  val visited = new HashSet[RDD[_]]
  // We are manually maintaining a stack here to prevent StackOverflowError
  // caused by recursively visiting
  val waitingForVisit = new Stack[RDD[_]]
  def visit(r: RDD[_]) {
    if (!visited(r)) {
      visited += r
      // Kind of ugly: need to register RDDs with the cache here since
      // we can't do it in its constructor because # of partitions is unknown
      for (dep <- r.dependencies) {
        dep match {
          case shufDep: ShuffleDependency[_, _, _] =>
            parents += getShuffleMapStage(shufDep, jobId)
          case _ =>
            waitingForVisit.push(dep.rdd)
        }
      }
    }
  }
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    visit(waitingForVisit.pop())
  }
  parents.toList
}

// DAGScheduler.getShuffleMapStage
// A ShuffleMapStage can be shared by different jobs
private def getShuffleMapStage(
    shuffleDep: ShuffleDependency[_, _, _],
    jobId: Int): ShuffleMapStage = {
  shuffleToMapStage.get(shuffleDep.shuffleId) match {
    case Some(stage) => stage
    case None =>
      // Register all ancestor shuffle dependencies
      registerShuffleDependencies(shuffleDep, jobId)
      // Register the current shuffle dependency
      val stage = newOrUsedShuffleStage(shuffleDep, jobId)
      shuffleToMapStage(shuffleDep.shuffleId) = stage

      stage
  }
}
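
One detail worth pulling out of getParentStages: it walks the lineage with an explicit Stack rather than direct recursion, so a very long chain of narrow dependencies cannot blow the JVM call stack. Stripped of the stage bookkeeping, the traversal pattern looks like this (a generic sketch; parentsOf and the shuffle-edge flag are hypothetical stand-ins for r.dependencies and the ShuffleDependency check):

// Iterative DFS with an explicit stack (sketch of the traversal pattern only)
import scala.collection.mutable.{HashSet, Stack}

def collectShuffleParents[N](root: N, parentsOf: N => Seq[(N, Boolean)]): Set[N] = {
  val found = HashSet[N]()
  val visited = HashSet[N]()
  val waitingForVisit = Stack[N](root)
  while (waitingForVisit.nonEmpty) {
    val node = waitingForVisit.pop()
    if (visited.add(node)) {
      for ((parent, isShuffleEdge) <- parentsOf(node)) {
        if (isShuffleEdge) found += parent        // shuffle edge: stage boundary, stop descending
        else waitingForVisit.push(parent)         // narrow edge: keep walking up the lineage
      }
    }
  }
  found.toSet
}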

STEP 2: Create an ActiveJob.

private[spark] class ActiveJob(
    val jobId: Int,
    val finalStage: ResultStage,
    val func: (TaskContext, Iterator[_]) => _,
    val partitions: Array[Int],
    val callSite: CallSite,
    val listener: JobListener,
    val properties: Properties) {

  val numPartitions = partitions.length
  val finished = Array.fill[Boolean](numPartitions)(false)
  var numFinished = 0
}
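
finished and numFinished are the driver-side completion bookkeeping: when the ResultTask for output partition i succeeds, DAGScheduler flips finished(i) and bumps numFinished, and the job is complete once every partition is marked. A minimal sketch of that bookkeeping (my own illustration, not the actual handleTaskCompletion code):

// Sketch: mark an output partition finished and report whether the whole job is done
def markPartitionFinished(job: ActiveJob, outputId: Int): Boolean = {
  if (!job.finished(outputId)) {          // ignore duplicate completions (e.g. speculative tasks)
    job.finished(outputId) = true
    job.numFinished += 1
  }
  job.numFinished == job.numPartitions    // true once every output partition has a result
}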

STEP 3: Decide whether the job should run locally: localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1

STEP 4: An RDD with a single partition and no dependencies can run locally. LiveListenerBus.post is called to deliver a SparkListenerJobStart event.

LiveListenerBus extends AsynchronousListenerBus. Much like EventLoop, post only puts the SparkListenerEvent into a queue; a separate thread drains the queue and delivers each event to the corresponding SparkListener.
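
Application code can tap into the same bus by registering its own SparkListener; a minimal sketch using the public Spark 1.x listener API:

// Sketch: a listener that logs the job events LiveListenerBus delivers asynchronously
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart}

class JobLoggingListener extends SparkListener {
  override def onJobStart(jobStart: SparkListenerJobStart): Unit =
    println(s"Job ${jobStart.jobId} started with ${jobStart.stageInfos.size} stages")
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
    println(s"Job ${jobEnd.jobId} finished: ${jobEnd.jobResult}")
}

// register on the driver before running jobs:  sc.addSparkListener(new JobLoggingListener)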

DAGScheduler.runLocally is then called:

// DAGScheduler.runLocally
protected def runLocally(job: ActiveJob) {
  logInfo("Computing the requested partition locally")
  // Run in a new thread so that a long-running job does not block other DAGScheduler operations
  new Thread("Local computation of job " + job.jobId) {
    override def run() {
      runLocallyWithinThread(job)
    }
  }.start()
}

// DAGScheduler.runLocallyWithinThread
protected def runLocallyWithinThread(job: ActiveJob) {
  var jobResult: JobResult = JobSucceeded
  try {
    val rdd = job.finalStage.rdd
    val split = rdd.partitions(job.partitions(0))
    val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
    val taskContext =
      new TaskContextImpl(
        job.finalStage.id,
        job.partitions(0),
        taskAttemptId = 0,
        attemptNumber = 0,
        taskMemoryManager = taskMemoryManager,
        runningLocally = true)
    TaskContext.setTaskContext(taskContext)
    try {
      val result = job.func(taskContext, rdd.iterator(split, taskContext))
      job.listener.taskSucceeded(0, result)
    } finally {
      taskContext.markTaskCompleted()
      TaskContext.unset()
      // Note: this memory freeing logic is duplicated in Executor.run(); when changing this,
      // make sure to update both copies.
      val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (freedMemory > 0) {
        if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
          throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes")
        } else {
          logError(s"Managed memory leak detected; size = $freedMemory bytes")
        }
      }
    }
  } catch {
    case e: Exception =>
      val exception = new SparkDriverExecutionException(e)
      jobResult = JobFailed(exception)
      job.listener.jobFailed(exception)
    case oom: OutOfMemoryError =>
      val exception = new SparkException("Local job aborted due to out of memory error", oom)
      jobResult = JobFailed(exception)
      job.listener.jobFailed(exception)
  } finally {
    val s = job.finalStage
    // clean up data structures that were populated for a local job,
    // but that won't get cleaned up via the normal paths through
    // completion events or stage abort
    stageIdToStage -= s.id
    jobIdToStageIds -= job.jobId
    listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult))
  }
}

STEP 5: For a stage that cannot run locally, call submitStage.

// DAGScheduler.submitStage
private def submitStage(stage: Stage) {
  val jobId = activeJobForStage(stage)
  if (jobId.isDefined) {
    logDebug("submitStage(" + stage + ")")
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      // Check for unfinished ancestor ShuffleMapStages; if any exist, submit them first
      val missing = getMissingParentStages(stage).sortBy(_.id)
      logDebug("missing: " + missing)
      if (missing.isEmpty) {
        logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
        submitMissingTasks(stage, jobId.get)
      } else {
        for (parent <- missing) {
          submitStage(parent)
        }
        waitingStages += stage
      }
    }
  } else {
    abortStage(stage, "No active job for stage " + stage.id)
  }
}
// DAGScheduler.submitMissingTasks
private def submitMissingTasks(stage: Stage, jobId: Int) {
  logDebug("submitMissingTasks(" + stage + ")")
  // Get our pending tasks and remember them in our pendingTasks entry
  stage.pendingTasks.clear()

  // Figure out which partitions still need to be computed
  val partitionsToCompute: Seq[Int] = {
    stage match {
      // For a ShuffleMapStage: partitions with no output location (MapStatus) yet
      case stage: ShuffleMapStage =>
        (0 until stage.numPartitions).filter(id => stage.outputLocs(id).isEmpty)
      // For a ResultStage: partitions whose results the job has not received yet
      case stage: ResultStage =>
        val job = stage.resultOfJob.get
        (0 until job.numPartitions).filter(id => !job.finished(id))
    }
  }

  val properties = jobIdToActiveJob.get(stage.jobId).map(_.properties).orNull

  runningStages += stage
  // SparkListenerStageSubmitted should be posted before testing whether tasks are
  // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
  // will be posted, which should always come after a corresponding SparkListenerStageSubmitted event.
  stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
  outputCommitCoordinator.stageStart(stage.id)
  listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

  // Broadcast the serialized task binary to the executors. Each task deserializes its own copy of
  // the RDD, which gives good isolation when a task may modify objects it references -- in Hadoop,
  // for example, the JobConf/Configuration object is not thread-safe.
  var taskBinary: Broadcast[Array[Byte]] = null
  try {
    // For a ShuffleMapTask, serialize and broadcast (rdd, shuffleDep)
    // For a ResultTask, serialize and broadcast (rdd, func)
    val taskBinaryBytes: Array[Byte] = stage match {
      case stage: ShuffleMapStage =>
        closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
      case stage: ResultStage =>
        closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func): AnyRef).array()
    }

    taskBinary = sc.broadcast(taskBinaryBytes)
  } catch {
    // In the case of a failure during serialization, abort the stage.
    case e: NotSerializableException =>
      abortStage(stage, "Task not serializable: " + e.toString)
      runningStages -= stage

      // Abort execution
      return
    case NonFatal(e) =>
      abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
      runningStages -= stage
      return
  }

  val tasks: Seq[Task[_]] = try {
    stage match {
      // A ShuffleMapStage creates ShuffleMapTasks
      case stage: ShuffleMapStage =>
        partitionsToCompute.map { id =>
          val locs = getPreferredLocs(stage.rdd, id)
          val part = stage.rdd.partitions(id)
          new ShuffleMapTask(stage.id, taskBinary, part, locs)
        }

      // A ResultStage creates ResultTasks
      case stage: ResultStage =>
        val job = stage.resultOfJob.get
        partitionsToCompute.map { id =>
          val p: Int = job.partitions(id)
          val part = stage.rdd.partitions(p)
          val locs = getPreferredLocs(stage.rdd, p)
          new ResultTask(stage.id, taskBinary, part, locs, id)
        }
    }
  } catch {
    case NonFatal(e) =>
      abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
      runningStages -= stage
      return
  }

  if (tasks.size > 0) {
    logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
    stage.pendingTasks ++= tasks
    logDebug("New pending tasks: " + stage.pendingTasks)
    // Wrap the tasks in a TaskSet and hand it over to the TaskScheduler
    taskScheduler.submitTasks(
      new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
    stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
  } else {
    // Because we posted SparkListenerStageSubmitted earlier, we should mark
    // the stage as completed here in case there are no tasks to run
    markStageAsFinished(stage, None)

    val debugString = stage match {
      case stage: ShuffleMapStage =>
        s"Stage ${stage} is actually done; " +
          s"(available: ${stage.isAvailable}," +
          s"available outputs: ${stage.numAvailableOutputs}," +
          s"partitions: ${stage.numPartitions})"
      case stage : ResultStage =>
        s"Stage ${stage} is actually done; (partitions: ${stage.numPartitions})"
    }
    logDebug(debugString)
  }
}
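
The taskBinary step above follows a general pattern: serialize the payload once on the driver, broadcast the raw bytes, and let every task deserialize its own private copy. A spark-shell style sketch of just that pattern (illustrative, with a plain string standing in for the (rdd, func) pair):

// Serialize once with the closure serializer, broadcast the bytes, deserialize per task (sketch)
import org.apache.spark.SparkEnv

val ser = SparkEnv.get.closureSerializer.newInstance()
val bytes: Array[Byte] = ser.serialize("stand-in payload": AnyRef).array()
val payloadBroadcast = sc.broadcast(bytes)   // each task would deserialize its own copy of the bytes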

DAGScheduler.getPreferredLocs determines the best location to run a task based on data locality; the actual work is done by DAGScheduler.getPreferredLocsInternal.

// DAGScheduler.getPreferredLocsInternal
private def getPreferredLocsInternal(
    rdd: RDD[_],
    partition: Int,
    visited: HashSet[(RDD[_], Int)]): Seq[TaskLocation] = {
  // Skip (rdd, partition) pairs already visited while walking up narrow-dependency ancestors
  if (!visited.add((rdd, partition))) {
    // Nil has already been returned for previously visited partitions.
    return Nil
  }
  // First, check for cached block locations
  val cached = getCacheLocs(rdd)(partition)
  if (cached.nonEmpty) {
    return cached
  }
  // Next, fall back to the RDD's own preferredLocations
  val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
  if (rddPrefs.nonEmpty) {
    return rddPrefs.map(TaskLocation(_))
  }
  // For narrow dependencies, use the location of the first ancestor partition that has a preference
  rdd.dependencies.foreach {
    case n: NarrowDependency[_] =>
      for (inPart <- n.getParents(partition)) {
        val locs = getPreferredLocsInternal(n.rdd, inPart, visited)
        if (locs != Nil) {
          return locs
        }
      }
    case _ =>
  }
  Nil
}

// DAGScheduler.getCacheLocs -- this method is thread-safe
def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
  // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
  if (!cacheLocs.contains(rdd.id)) {
    val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
    val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
      bms.map(bm => TaskLocation(bm.host, bm.executorId))
    }
    cacheLocs(rdd.id) = locs
  }
  cacheLocs(rdd.id)
}
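
Put differently, locality is resolved in a fixed priority order: cached block locations first, then the RDD's own preferredLocations, then the first narrow-dependency ancestor partition that has a preference. A simplified sketch of that priority chain (hypothetical helper, not the real DAGScheduler API):

// Sketch of the locality priority order implemented by getPreferredLocsInternal
def resolveLocality(cacheLocs: Seq[String],
                    rddPrefs: Seq[String],
                    narrowParentLocs: => Seq[String]): Seq[String] = {
  if (cacheLocs.nonEmpty) cacheLocs        // 1. executors already holding the cached block
  else if (rddPrefs.nonEmpty) rddPrefs     // 2. the RDD's own preferredLocations (e.g. HDFS block hosts)
  else narrowParentLocs                    // 3. recurse into narrow-dependency parents; may be empty (Nil)
}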

STEP 6: Wrap the tasks in a TaskSet and submit it to the TaskScheduler for execution.

taskScheduler.submitTasks(
    new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
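
For reference, the TaskSet handed over here is a thin container; in this Spark version it roughly bundles the fields below (sketch inferred from the constructor call above, so treat the exact shape as approximate). Note that stage.jobId is passed as the priority, which is what FIFO scheduling orders by.

// Approximate shape of TaskSet (sketch)
import java.util.Properties

private[spark] class TaskSet(
    val tasks: Array[Task[_]],
    val stageId: Int,
    val attempt: Int,
    val priority: Int,          // DAGScheduler passes stage.jobId here
    val properties: Properties) {
  val id: String = stageId + "." + attempt
}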

At this point, DAGScheduler's work is essentially done. I will walk through the TaskScheduler source code in the next article.
