Spark Internals Source Code Analysis, Part 10: Task Execution

[Figure 1: Task execution flow diagram]

Below we trace the flow shown in the figure above through the source code.
The entry point is org.apache.spark.executor.Executor.TaskRunner#run.
In the previous article, the last step was submitting the newly created TaskRunner to the executor's thread pool;
here we continue with what happens next.
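
As a quick recap, that submission looks roughly like the following, condensed from the Spark 1.x Executor#launchTask (shown here only for context):

def launchTask(
    context: ExecutorBackend,
    taskId: Long,
    attemptNumber: Int,
    taskName: String,
    serializedTask: ByteBuffer): Unit = {
  // Wrap the serialized task in a TaskRunner (a Runnable) ...
  val tr = new TaskRunner(context, taskId, attemptNumber, taskName, serializedTask)
  // ... register it so it can be found (and killed) later ...
  runningTasks.put(taskId, tr)
  // ... and hand it to the executor's thread pool; the run() method below is what then executes
  threadPool.execute(tr)
}

TaskRunner#run then proceeds as follows: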

    override def run() {
      val deserializeStartTime = System.currentTimeMillis()
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = env.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      startGCTime = gcTime

      try {
        // Deserialize the serialized task data
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        // Fetch the required files, resources and jars over the network
        updateDependencies(taskFiles, taskJars)
        // Perform the actual deserialization to reconstruct the Task object
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        attemptedTask = Some(task)
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        // Record the task start time
        taskStart = System.currentTimeMillis()
        // The key step: invoke the task's run() method
        // For a ShuffleMapTask, the value returned here is a MapStatus, which encapsulates where the data
        // computed by the ShuffleMapTask was written (its output location)
        // A downstream task will later contact the MapOutputTracker to obtain the output locations of the
        // previous stage's ShuffleMapTasks and pull that data over the network; the same flow applies to ResultTask
        val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
        // Record the task finish time
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }
        // This serializes and wraps the result (the MapStatus for a ShuffleMapTask) so it can later be sent to the driver over the network
        val resultSer = env.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        for (m <- task.metrics) {
          m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
          m.setExecutorRunTime(taskFinish - taskStart)
          m.setJvmGCTime(gcTime - startGCTime)
          m.setResultSerializationTime(afterSerialization - beforeSerialization)
        }

        val accumUpdates = Accumulators.values

        val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
        val serializedDirectResult = ser.serialize(directResult)
        val resultSize = serializedDirectResult.limit

        // directSend = sending directly back to the driver
        val serializedResult = {
          if (maxResultSize > 0 && resultSize > maxResultSize) {
            logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
              s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
              s"dropping it.")
            ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
          } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            logInfo(
              s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
            ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
          } else {
            logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
            serializedDirectResult
          }
        }
        // This calls statusUpdate() on the CoarseGrainedExecutorBackend that hosts this Executor (see below)
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)

      } catch {
        case ffe: FetchFailedException => {
          val reason = ffe.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
        }

        case _: TaskKilledException | _: InterruptedException if task.killed => {
          logInfo(s"Executor killed $taskName (TID $taskId)")
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
        }

        case cDE: CommitDeniedException => {
          val reason = cDE.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
        }

        case t: Throwable => {
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError(s"Exception in $taskName (TID $taskId)", t)

          val serviceTime = System.currentTimeMillis() - taskStart
          val metrics = attemptedTask.flatMap(t => t.metrics)
          for (m <- metrics) {
            m.setExecutorRunTime(serviceTime)
            m.setJvmGCTime(gcTime - startGCTime)
          }
          val reason = new ExceptionFailure(t, metrics)
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            SparkUncaughtExceptionHandler.uncaughtException(t)
          }
        }
      } finally {
        // Release memory used by this thread for shuffles
        env.shuffleMemoryManager.releaseMemoryForThisThread()
        // Release memory used by this thread for unrolling blocks
        env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
        // Release memory used by this thread for accumulators
        Accumulators.clear()
        runningTasks.remove(taskId)
      }
    }
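
The three-way branch near the end decides how the task result travels back to the driver. Here is a minimal, illustrative sketch of that decision (the names below are for illustration only; maxResultSize corresponds to spark.driver.maxResultSize and akkaFrameSize to the Akka frame size minus a reserved overhead):

sealed trait ResultRoute
case object Dropped extends ResultRoute          // too large: the driver only learns the size
case object ViaBlockManager extends ResultRoute  // indirect: stored as a block, the driver fetches it
case object DirectToDriver extends ResultRoute   // small enough to ride inside the status update itself

def chooseRoute(resultSize: Long, maxResultSize: Long, akkaFrameSize: Long, reserved: Long): ResultRoute =
  if (maxResultSize > 0 && resultSize > maxResultSize) Dropped
  else if (resultSize >= akkaFrameSize - reserved) ViaBlockManager
  else DirectToDriver

// Example: a 20 MB result with a 1 GB maxResultSize and a 10 MB Akka frame
// is stored through the BlockManager and reported to the driver as an IndirectTaskResult.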

org.apache.spark.scheduler.Task#run

final def run(taskAttemptId: Long, attemptNumber: Int): T = {
    // Create a TaskContext, the execution context of the task, which records global information such as
    // how many times the task has been attempted, which stage it belongs to, and which RDD partition it processes
    context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
      taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
    TaskContextHelper.setTaskContext(context)
    context.taskMetrics.setHostname(Utils.localHostName())
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    try {
      // Call the abstract runTask() method
      // Task has only two subclasses, ShuffleMapTask and ResultTask, so this dispatches to one of their runTask() implementations
      runTask(context)
    } finally {
      context.markTaskCompleted()
      TaskContextHelper.unset()
    }
  }
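
As a side note, the TaskContextImpl created here is what user code can read back inside an operator via TaskContext.get() (a minimal usage sketch, assuming a Spark 1.3-era API and an existing RDD named rdd):

import org.apache.spark.TaskContext

// Tag every record with the identity of the task that processed it.
// stageId / partitionId / attemptNumber are exactly the values passed to TaskContextImpl above.
val tagged = rdd.mapPartitions { iter =>
  val tc = TaskContext.get()
  iter.map(x => (tc.stageId, tc.partitionId, tc.attemptNumber, x))
}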

org.apache.spark.scheduler.ShuffleMapTask: a ShuffleMapTask splits the elements of an RDD partition into multiple buckets, based on the partitioner specified in the ShuffleDependency (HashPartitioner by default); ShuffleMapTask's runTask() method returns a MapStatus.
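
To make the bucket assignment concrete, here is the essence of hash-based partitioning (a minimal sketch that mirrors, but does not copy, org.apache.spark.HashPartitioner):

// bucket = non-negative (key.hashCode % numBuckets), so every key lands in a stable bucket
def bucketFor(key: Any, numBuckets: Int): Int = {
  val rawMod = key.hashCode % numBuckets
  if (rawMod < 0) rawMod + numBuckets else rawMod
}

// Example: with 3 reduce-side partitions, the records of one map-side partition
// are scattered into 3 buckets, one per downstream reducer.
val buckets = Seq(("a", 1), ("b", 2), ("c", 3)).groupBy { case (k, _) => bucketFor(k, 3) }

With that in mind, ShuffleMapTask#runTask looks like this: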

override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    // Deserialize the RDD (and the ShuffleDependency) that this task needs to process
    // The RDD is obtained through a broadcast variable:
    // the tasks of a stage run in parallel (or concurrently) on multiple executors, usually not in the same place,
    // but all tasks of a stage process the same RDD, so each task gets the RDD it needs via the broadcast variable
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      // Get the ShuffleManager
      // and obtain a ShuffleWriter from it
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      // The most important part is here (rdd.iterator)
      // First, call the RDD's iterator() method, passing in the partition this task is responsible for
      // The core logic lives in iterator(): it applies the operators/functions we defined to that partition of the RDD
      // The returned records are partitioned by the ShuffleWriter (via the HashPartitioner) and written to their corresponding buckets
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      // Finally, return the MapStatus
      // The MapStatus encapsulates where the data computed by this ShuffleMapTask is stored, which is essentially BlockManager-related information
      // The BlockManager is Spark's low-level component for managing data in memory and on disk
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

org.apache.spark.rdd.RDD#iterator -> org.apache.spark.rdd.RDD#computeOrReadCheckpoint ->
org.apache.spark.rdd.MapPartitionsRDD#compute

// compute() applies the operators/functions defined for this RDD to a given partition
// f can be thought of as the function/operator we defined ourselves, wrapped by Spark with some additional logic
// At this point we are executing the user-defined computation against the RDD's partition and returning the data of the new RDD's partition
override def compute(split: Partition, context: TaskContext) =
    f(context, split.index, firstParent[T].iterator(split, context))
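
The call chain above can be condensed into the following sketch of RDD#iterator and RDD#computeOrReadCheckpoint (abridged from the Spark 1.x RDD source; cache-statistics and cleanup details omitted):

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    // persist()/cache() was called on this RDD: try the BlockManager-backed cache first,
    // and compute the partition only on a cache miss
    SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}

private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] = {
  // If the RDD has been checkpointed, read from the checkpoint RDD (its parent);
  // otherwise fall through to compute(), e.g. MapPartitionsRDD#compute above
  if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
}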

org.apache.spark.executor.CoarseGrainedExecutorBackend#statusUpdate

  // This sends a StatusUpdate message to the SparkDeploySchedulerBackend on the driver
  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
    driver ! StatusUpdate(executorId, taskId, state, data)
  }

org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.StatusUpdate
The parent class of SparkDeploySchedulerBackend is CoarseGrainedSchedulerBackend, which is where the message is handled:
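
For reference, the StatusUpdate message carried over the wire is roughly the following shape (simplified from CoarseGrainedClusterMessages; note that data is a SerializableBuffer wrapping the ByteBuffer, which is why the handler below reads data.value):

case class StatusUpdate(
    executorId: String,
    taskId: Long,
    state: TaskState,
    data: SerializableBuffer)
  extends CoarseGrainedClusterMessage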

// Handle the status update for a task that has finished executing
 case StatusUpdate(executorId, taskId, state, data) =>
        // Call TaskSchedulerImpl's statusUpdate() method
        scheduler.statusUpdate(taskId, state, data.value)
        if (TaskState.isFinished(state)) {
          executorDataMap.get(executorId) match {
            case Some(executorInfo) =>
              executorInfo.freeCores += scheduler.CPUS_PER_TASK
              makeOffers(executorId)
            case None =>
              // Ignoring the update since we don't know about the executor.
              logWarning(s"Ignored task status update ($taskId state $state) " +
                "from unknown executor $sender with ID $executorId")
          }
        }

org.apache.spark.scheduler.TaskSchedulerImpl#statusUpdate

 def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
    var failedExecutor: Option[String] = None
    synchronized {
      try {
        // If the task is lost,
        if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
          // We lost this entire executor, so remember that it's gone
          // remove the executor and record it as failed
          val execId = taskIdToExecutorId(tid)
          if (activeExecutorIds.contains(execId)) {
            removeExecutor(execId)
            failedExecutor = Some(execId)
          }
        }
        taskIdToTaskSetId.get(tid) match {
          // Look up the corresponding task set
          case Some(taskSetId) =>
            // If the task has finished, remove it from the in-memory caches
            if (TaskState.isFinished(state)) {
              taskIdToTaskSetId.remove(tid)
              taskIdToExecutorId.remove(tid)
            }
            // If the task finished normally, handle it accordingly (failures are handled in the branch below)
            activeTaskSets.get(taskSetId).foreach { taskSet =>
              if (state == TaskState.FINISHED) {
                taskSet.removeRunningTask(tid)
                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
                taskSet.removeRunningTask(tid)
                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
              }
            }
          case None =>
            logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
               "likely the result of receiving duplicate task finished status updates)")
              .format(state, tid))
        }
      } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
      }
    }
    // Update the DAGScheduler without holding a lock on this, since that can deadlock
    if (failedExecutor.isDefined) {
      dagScheduler.executorLost(failedExecutor.get)
      backend.reviveOffers()
    }
  }

Next, let's analyze org.apache.spark.scheduler.ResultTask#runTask

override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    // Basic deserialization of the RDD and the user-defined function
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    // Run the user-defined function over the RDD's iterator for this partition and return its result
    func(context, rdd.iterator(partition, context))
  }
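
To connect this with user code: the func deserialized here comes from the action that triggered the job. A hypothetical illustration (the actual wrapping is done by SparkContext.runJob and the DAGScheduler; sc below is assumed to be an existing SparkContext):

import org.apache.spark.TaskContext

val data = sc.parallelize(1 to 100, 4).map(_ * 2)

// collect() is built on sc.runJob(rdd, (iter: Iterator[Int]) => iter.toArray),
// so the per-partition func executed in runTask above behaves roughly like:
val perPartitionFunc: (TaskContext, Iterator[Int]) => Array[Int] =
  (ctx, iter) => iter.toArray

val result: Array[Int] = data.collect()  // runs one ResultTask per partition of `data`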
