Spark源码剖析——Action操作、runJob流程

文章目录

  • Spark源码剖析——Action操作、runJob流程
    • 当前环境与版本
    • 前言
    • 供分析的代码
    • collect 源码分析
    • DAGScheduler中的处理
    • TaskScheduler中的处理
    • CoarseGrainedSchedulerBackend、DriverEndpoint中的处理
    • Executor中的处理

Spark源码剖析——Action操作、runJob流程

当前环境与版本

环境 版本
JDK java version “1.8.0_231” (HotSpot)
Scala Scala-2.11.12
Spark spark-2.4.4

前言

  • 在前面SparkSubmit提交流程中,我们已经讨论了一个Spark应用的提交流程、申请Application、启动Driver、启动Excutor,最后到Driver反射调用用户编写的类的main方法。
  • 接着,在SparkContext实例化中,我们大致的看了用户代码中的SparkContext实例化的过程。
  • 在本篇中,主要讨论在SparkContext实例化后,接着对用户代码的处理,一个Action操作是如何提交任务的(runJob)。
  • 除此之外的,也可以看看一个简易的实现 实现链式、惰性特点的容器
  • 我做了一幅触发Action的流程示意图(collect),如下

Spark源码剖析——Action操作、runJob流程_第1张图片

供分析的代码

  • 下面我们来看一份简单的示例代码(词频统计)
    object MySparkApp {
    
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("MySparkApp")
        val spark = SparkSession.builder()
          .config(conf)
          .getOrCreate()
    
        // 数据源(csv,逗号分隔)
        val srcRDD = spark.sparkContext.textFile("/test.csv")
    
        // 词频统计
        val resultRDD = srcRDD.flatMap(_.split(","))
          .map((_, 1))
          .reduceByKey(_ + _)
    
        // 触发Action
        val resultArr = resultRDD.collect()
    
        // 省略其他操作
        // ...
    
        spark.stop()
      }
    
    }
    
  • 整个代码较为简单,构建SparkSession,读取原始数据,进行词频统计(Shuffle),最后再利用collect汇总数据到Driver端。
  • 我们主要需要看的是利用collect触发Action的代码。

collect 源码分析

  • org.apache.spark.rdd.RDDorg.apache.spark.SparkContext
  • Ctrl + 鼠标左键点击collect来到RDD的源码中(org.apache.spark.rdd.RDD),代码如下。
      def collect(): Array[T] = withScope {
        val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
        Array.concat(results: _*)
      }
    
  • 此处比较简单,调用sc.runJob(...),然后返回一个Array。继续多点击几次runJob,我们来到SparkContext的2077~2084行,代码如下。
      def runJob[T, U: ClassTag](
          rdd: RDD[T],
          func: (TaskContext, Iterator[T]) => U,
          partitions: Seq[Int]): Array[U] = {
        val results = new Array[U](partitions.size)
        runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res)
        results
      }
    
  • 需要注意的是,此处将Array利用函数(index, res) => results(index) = res传入了runJob方法,后面会对其进行赋值,这样就可以返回结果数组results了。再继续点击runJob方法,来到SparkContext的2047~2064行,代码如下。
      def runJob[T, U: ClassTag](
          rdd: RDD[T],
          func: (TaskContext, Iterator[T]) => U,
          partitions: Seq[Int],
          resultHandler: (Int, U) => Unit): Unit = {
        if (stopped.get()) {
          throw new IllegalStateException("SparkContext has been shutdown")
        }
        val callSite = getCallSite
        // cleanedFunc用于确认闭包可序列化,防止func中存在不可序列化的情况
        val cleanedFunc = clean(func)
        logInfo("Starting job: " + callSite.shortForm)
        if (conf.getBoolean("spark.logLineage", false)) {
          logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString)
        }
        // 将任务交由DAGScheduler处理
        dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get)
        // 使用命令行提交时显示的进度条,结束
        progressBar.foreach(_.finishAll())
        rdd.doCheckpoint()
      }
    
  • 此部分代码中,最重要的是调用dagScheduler.runJob(...),将任务提交到了DAGScheduler。需要注意的是resultHandler,这个函数会一直往下传。

DAGScheduler中的处理

  • org.apache.spark.scheduler.DAGScheduler
  • 接着我来看DAGScheduler中的runJob方法,代码如下。
      def runJob[T, U](
          rdd: RDD[T],
          func: (TaskContext, Iterator[T]) => U,
          partitions: Seq[Int],
          callSite: CallSite,
          resultHandler: (Int, U) => Unit,
          properties: Properties): Unit = {
        val start = System.nanoTime
        // 关键
        val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
        // 阻塞等待任务完成
        ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
        waiter.completionFuture.value.get match {
          case scala.util.Success(_) => // 成功
            logInfo("Job %d finished: %s, took %f s".format
              (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
          case scala.util.Failure(exception) => // 失败
            logInfo("Job %d failed: %s, took %f s".format
              (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
            // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
            val callerStackTrace = Thread.currentThread().getStackTrace.tail
            exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
            throw exception
        }
      }
    
  • 此处利用submitJob构建了一个JobWaiter,然后进行了阻塞等待。再来到submitJob(...)的代码中,如下。
      def submitJob[T, U](
          rdd: RDD[T],
          func: (TaskContext, Iterator[T]) => U,
          partitions: Seq[Int],
          callSite: CallSite,
          resultHandler: (Int, U) => Unit,
          properties: Properties): JobWaiter[U] = {
        // 检查分区,确保不在不存在的分区上启动任务
        val maxPartitions = rdd.partitions.length
        partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
          throw new IllegalArgumentException(
            "Attempting to access a non-existent partition: " + p + ". " +
              "Total number of partitions: " + maxPartitions)
        }
    	
    	// 任务id
        val jobId = nextJobId.getAndIncrement()
        if (partitions.size == 0) {
          // 分区为0,说明不需要运行任务
          return new JobWaiter[U](this, jobId, 0, resultHandler)
        }
    
        assert(partitions.size > 0)
        // 重点,封装 JobSubmitted,并提交到队列
        val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
        val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
        eventProcessLoop.post(JobSubmitted(
          jobId, rdd, func2, partitions.toArray, callSite, waiter,
          SerializationUtils.clone(properties)))
        waiter
      }
    
  • 此部分代码的关键在最后几行,封装了JobSubmitted,并利用eventProcessLoop将其提交到了队列。
  • 有兴趣的朋友可以看看这个EventLoop(此处的实现类是DAGSchedulerEventProcessLoop),其内部有一个LinkedBlockingDeque队列,启动后,会有一个守护线程不断轮询队列,取出元素,并调用onReceive进行处理。在此处,则是最终调用到了DAGScheduler的doOnReceive方法,匹配到JobSubmitted
      private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
        case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
          dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)
    
        case ... 省略 ...
    
  • 另外,你也可以按照之前在Master、Worker启动流程中提到的快速查看技巧,快速定位到是什么地方接收到了JobSubmitted
  • 接着,我们来看DAGScheduler的dagScheduler.handleJobSubmitted(...)方法,代码如下。
      private[scheduler] def handleJobSubmitted(jobId: Int,
          finalRDD: RDD[_],
          func: (TaskContext, Iterator[_]) => _,
          partitions: Array[Int],
          callSite: CallSite,
          listener: JobListener,
          properties: Properties) {
        var finalStage: ResultStage = null
        try {
          // 解析划分Stage,根据ShuffleDependency
          // 此处返回的是最后一个Stage,即ResultStage
          finalStage = createResultStage(finalRDD, func, partitions, jobId, callSite)
        } catch {
          // 异常处理,省略...
        }
        // Job submitted, clear internal data.
        barrierJobIdToNumTasksCheckFailures.remove(jobId)
    
        // 封装Job相关信息
        val job = new ActiveJob(jobId, finalStage, callSite, listener, properties)
        clearCacheLocs()
        // 省略部分代码...
    
        val jobSubmissionTime = clock.getTimeMillis()
        jobIdToActiveJob(jobId) = job
        activeJobs += job
        finalStage.setActiveJob(job)
        val stageIds = jobIdToStageIds(jobId).toArray
        val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
        // 提交消息到listenerBus,方便UI界面查看到任务提交
        listenerBus.post(
          SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
        // 最后,提交Stage
        submitStage(finalStage)
      }
    
  • 此处,有两个关键点
    • createResultStage(...) 根据宽窄依赖划分了Stage(ShuffleMapStage、ResultStage),后续再来讲该部分代码,你也可以自己看看(递归较麻烦)
    • 调用submitStage(...)提交Stage
  • 再来看submitStage(...)方法
      private def submitStage(stage: Stage) {
        val jobId = activeJobForStage(stage)
        if (jobId.isDefined) {
          logDebug("submitStage(" + stage + ")")
          // Stage:不是正在等待的、不是正在运行的、不是失败的
          if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
            // 往上找Stage,因为最开始提交进来的是最后一个Stage
            val missing = getMissingParentStages(stage).sortBy(_.id)
            logDebug("missing: " + missing)
            if (missing.isEmpty) {
              logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
              // 上面没有Stage了,那就正式提交Stage了
              submitMissingTasks(stage, jobId.get)
            } else {
              for (parent <- missing) {
             	// 递归
                submitStage(parent)
              }
              waitingStages += stage
            }
          }
        } else {
          abortStage(stage, "No active job for stage " + stage.id, None)
        }
      }
    
  • 因为最开始解析Stage后,返回的是最后一个Stage,因此需要递归往上找到最前面的Stage,再提交Stage。
  • 我们来看看是DAGScheduler如何提交任务的,方法 submitMissingTasks(...)较长(1083~1232行),我们来看其中的关键点。
  • 调用getPreferredLocs,计算出Task的最佳位置(1105~1124行)
        // 利用getPreferredLocs获取最优的处理位置 
        val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
          stage match {
            case s: ShuffleMapStage => // 如果是ShuffleMapStage
              partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
            case s: ResultStage => // 如果是ResultStage
              partitionsToCompute.map { id =>
                val p = s.partitions(id)
                (id, getPreferredLocs(stage.rdd, p))
              }.toMap
          }
        } catch {
           // 省略代码
        }
    
        stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
    
  • 广播(1140~1176行)
        var taskBinary: Broadcast[Array[Byte]] = null
        var partitions: Array[Partition] = null
        try {
          var taskBinaryBytes: Array[Byte] = null
          
          RDDCheckpointData.synchronized {
            taskBinaryBytes = stage match {
              case stage: ShuffleMapStage => // rdd, shuffleDep
                JavaUtils.bufferToArray(
                  closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
              case stage: ResultStage => // rdd, func
                JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
            }
    
            partitions = stage.rdd.partitions
          }
    
          taskBinary = sc.broadcast(taskBinaryBytes)
        } catch {
          // 省略代码
        }
    
  • 序列化Task(1178~1208行)
        val tasks: Seq[Task[_]] = try {
          val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array()
          stage match {
            case stage: ShuffleMapStage => // 生成ShuffleMapTask,附带之前的广播taskBinary
              stage.pendingPartitions.clear()
              partitionsToCompute.map { id =>
                val locs = taskIdToLocations(id)
                val part = partitions(id)
                stage.pendingPartitions += id
                new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
                  taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
                  Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier())
              }
    
            case stage: ResultStage => // 生成ResultTask,附带之前的广播taskBinary
              partitionsToCompute.map { id =>
                val p: Int = stage.partitions(id)
                val part = partitions(p)
                val locs = taskIdToLocations(id)
                new ResultTask(stage.id, stage.latestInfo.attemptNumber,
                  taskBinary, part, locs, id, properties, serializedTaskMetrics,
                  Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
                  stage.rdd.isBarrier())
              }
          }
        } catch {
          // 省略代码
        }
    
  • 封装TaskSet,利用TaskScheduler提交任务(1210~1231行)
         // 大于0,说明有任务
        if (tasks.size > 0) {
          logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
            s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
            // 利用TaskScheduler提交任务
          taskScheduler.submitTasks(new TaskSet(
            tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties))
        } else {
           // 对于没有Task的处理
          // 省略部分代码
        }
    

TaskScheduler中的处理

  • org.apache.spark.scheduler.TaskSchedulerImpl
  • 我们来看TaskScheduler中的处理,submitTasks(...)方法如下。
      override def submitTasks(taskSet: TaskSet) {
        val tasks = taskSet.tasks
        logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
        this.synchronized {
          // 将TaskSet转换为TaskSetManager
          val manager = createTaskSetManager(taskSet, maxTaskFailures)
          val stage = taskSet.stageId
          val stageTaskSets =
            taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
    
          stageTaskSets.foreach { case (_, ts) =>
            ts.isZombie = true
          }
          stageTaskSets(taskSet.stageAttemptId) = manager
          // 将TaskSetManager提交到任务调度的Pool中,包括FIFO、Fair两种
          schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
    
    	  // 启动定时器,检查任务是否已经运行了
          if (!isLocal && !hasReceivedTask) {
            starvationTimer.scheduleAtFixedRate(new TimerTask() {
              override def run() {
                if (!hasLaunchedTask) {
                  // 这段是平时提交任务后比较常见的日志(如果集群资源不够的话)
                  logWarning("Initial job has not accepted any resources; " +
                    "check your cluster UI to ensure that workers are registered " +
                    "and have sufficient resources")
                } else {
                  this.cancel()
                }
              }
            }, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
          }
          hasReceivedTask = true
        }
        // 重要,利用SchedulerBackend发消息给Driver类
        backend.reviveOffers()
      }
    
  • 此部分代码主要做了两件事
    • 将TaskSet转为TaskSetManager,并提交至了任务调度的Pool中
    • 利用SchedulerBackend发消息给Driver类,使其处理Pool中的任务
  • SchedulerBackend有多个实现类,后面我们用CoarseGrainedSchedulerBackend做示例

CoarseGrainedSchedulerBackend、DriverEndpoint中的处理

  • org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.DriverEndpoint
  • CoarseGrainedSchedulerBackend调用reviveOffers()后,会利用driverEndpoint的Ref向Driver发送一个消息ReviveOffers
  • 如果你还不熟悉RpcEndpoint、RpcEnv的话,可以利用在Master、Worker启动流程中提到的快速查看技巧,可以快速定位到DriverEndpoint(其实就在CoarseGrainedSchedulerBackend中,是个内部类)的receive方法,代码如下。
        override def receive: PartialFunction[Any, Unit] = {
          case 省略代码...
    
          case ReviveOffers =>
            makeOffers()
          case 省略代码...
        }
    
  • 接着,再看makeOffers()方法
        private def makeOffers() {
          // 确保待启动Task的Executor没问题
          val taskDescs = withLock {
            val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
            val workOffers = activeExecutors.map {
              case (id, executorData) =>
                new WorkerOffer(id, executorData.executorHost, executorData.freeCores,
                  Some(executorData.executorAddress.hostPort))
            }.toIndexedSeq
            scheduler.resourceOffers(workOffers)
          }
          // OK,没问题,那么启动Task
          if (!taskDescs.isEmpty) {
            launchTasks(taskDescs)
          }
        }
    
  • 这部分代码应该没什么问题,我们接着往下看launchTasks(...)方法
        private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
          // 循环处理Task
          for (task <- tasks.flatten) {
            val serializedTask = TaskDescription.encode(task)
            if (serializedTask.limit() >= maxRpcMessageSize) {
              // 如果超过了最大的消息限制,就发出提示
              // 省略代码
            }
            else {
              // 更新executor信息
              val executorData = executorDataMap(task.executorId)
              executorData.freeCores -= scheduler.CPUS_PER_TASK
    
              logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
                s"${executorData.executorHost}.")
    		  // 将序列化的Task封装为LaunchTask
    		  // 向Executor发送启动任务的消息
              executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
            }
          }
        }
    

Executor中的处理

  • org.apache.spark.executor.CoarseGrainedExecutorBackend
  • 接着再利用前面多次提到的快速查看技巧,可以定位到Executor处的CoarseGrainedExecutorBackend的receive方法。
      override def receive: PartialFunction[Any, Unit] = {
        case 省略代码...
    
        case LaunchTask(data) =>
          if (executor == null) {
            exitExecutor(1, "Received LaunchTask command but executor was null")
          } else {
            val taskDesc = TaskDescription.decode(data.value)
            logInfo("Got assigned task " + taskDesc.taskId)
            executor.launchTask(this, taskDesc)
          }
    
         case 省略代码...
    
  • 这样,最终就会调用Executor的launchTask方法处理Task了。
      def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
        // 封装一个TaskRunner
        val tr = new TaskRunner(context, taskDescription)
        runningTasks.put(taskDescription.taskId, tr)
        // 提交到线程池中
        threadPool.execute(tr)
      }
    
  • 想继续的朋友,可以再看TaskRunner的run方法 ^_^

你可能感兴趣的:(BigData,#,Spark,Scala)