Spark 1.1.0 Source Code Reading: Executor

1. launchTask on the executor

  def launchTask(
      context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
    val tr = new TaskRunner(context, taskId, taskName, serializedTask)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }
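launchTask itself does very little: it wraps the serialized task in a TaskRunner, records it in runningTasks (so later kill/status requests can find it), and hands it to the executor's task-launch thread pool. Below is a minimal, self-contained sketch of that submit-and-track pattern; the names LaunchPattern, Work, pool and running are illustrative only, not Spark's.

import java.util.concurrent.{ConcurrentHashMap, Executors}

object LaunchPattern {
  // Cached thread pool, analogous to the executor's "task launch worker" pool.
  private val pool = Executors.newCachedThreadPool()
  // Work currently running, keyed by task id, analogous to runningTasks.
  private val running = new ConcurrentHashMap[Long, Runnable]()

  class Work(id: Long) extends Runnable {
    override def run(): Unit =
      try println(s"running task $id")
      finally running.remove(id) // drop the entry once the work finishes
  }

  def launch(id: Long): Unit = {
    val work = new Work(id)
    running.put(id, work)  // track it so it can be looked up for kill/status
    pool.execute(work)     // the actual run() happens on a pool thread
  }

  def main(args: Array[String]): Unit = {
    (1L to 3L).foreach(launch)
    pool.shutdown()
  }
}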

2. TaskRunner.run on the executor

  class TaskRunner(
      execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
    extends Runnable {

    @volatile private var killed = false
    @volatile var task: Task[Any] = _
    @volatile var attemptedTask: Option[Task[Any]] = None

    def kill(interruptThread: Boolean) {
      logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
      killed = true
      if (task != null) {
        task.kill(interruptThread)
      }
    }

    override def run() {
      val startTime = System.currentTimeMillis()
      SparkEnv.set(env)
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = SparkEnv.get.closureSerializer.newInstance()
      logInfo(s"Running $taskName (TID $taskId)")
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var taskStart: Long = 0
      def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
      val startGCTime = gcTime

      try {
        SparkEnv.set(env)
        Accumulators.clear()
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) // deserialize taskFiles, taskJars, and taskBytes
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) // deserialize the Task object

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        attemptedTask = Some(task)
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        val value = task.run(taskId.toInt)
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }
        // ... (the rest of run(), which serializes the result and reports it back to the
        // driver via execBackend.statusUpdate, is not shown in this excerpt)
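Note how run() measures GC time: gcTime sums the cumulative collection time of all garbage collectors in the JVM, and the snapshot taken before the task (startGCTime) is later subtracted from a second snapshot, so only GC that occurred while the task ran is attributed to its metrics. A small self-contained sketch of the same measurement, with an illustrative allocation loop standing in for the task:

import java.lang.management.ManagementFactory
import scala.collection.JavaConversions._

object GcTiming {
  // Cumulative GC time (ms) across all collectors, as in TaskRunner's gcTime.
  def gcTime: Long =
    ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum

  def main(args: Array[String]): Unit = {
    val startGCTime = gcTime
    // Illustrative workload that produces garbage to collect.
    var junk: Array[Byte] = null
    (1 to 2000).foreach(_ => junk = new Array[Byte](1 << 20))
    println(s"GC time during the workload: ${gcTime - startGCTime} ms")
  }
}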

3. Task.run

private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

  final def run(attemptId: Long): T = {
    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
    context.taskMetrics.hostname = Utils.localHostName()
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    runTask(context)
  }
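run() records the current thread in taskThread so that kill(interruptThread = true) can interrupt it; otherwise cancellation is cooperative, relying on flags on the Task and its TaskContext that task code checks at safe points. A self-contained sketch of that cooperation, with illustrative names (not Spark's):

object CooperativeKill {
  @volatile private var killed = false
  @volatile private var taskThread: Thread = _

  def kill(interruptThread: Boolean): Unit = {
    killed = true
    if (interruptThread && taskThread != null) taskThread.interrupt()
  }

  def run(): Unit = {
    taskThread = Thread.currentThread()   // recorded so kill() can interrupt it
    var i = 0L
    while (!killed && i < 1000000000L) {  // poll the kill flag at safe points
      i += 1
    }
    println(if (killed) "stopped because of kill request" else "ran to completion")
  }

  def main(args: Array[String]): Unit = {
    val t = new Thread(new Runnable {
      override def run(): Unit = CooperativeKill.run()
    })
    t.start()
    Thread.sleep(10)
    kill(interruptThread = false)
    t.join()
  }
}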

4. Task is an abstract class; each concrete subclass (ResultTask and ShuffleMapTask) provides its own runTask implementation.

a. ResultTask

  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(partition, context))
    } finally {
      context.markTaskCompleted()
    }
  }
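The func deserialized here is the per-partition function the driver passed to SparkContext.runJob for this job; for collect(), for instance, it essentially drains the partition iterator into an array that is sent back to the driver. A small self-contained illustration of the (context, iterator) => result call shape, with TaskContext replaced by a plain placeholder since only the shape matters here:

object ResultFuncShape {
  // Placeholder standing in for TaskContext; illustrative only.
  final case class FakeContext(stageId: Int, partitionId: Int)

  def main(args: Array[String]): Unit = {
    // What collect() boils down to per partition: drain the iterator into an array.
    val collectFunc: (FakeContext, Iterator[Int]) => Array[Int] = (_, iter) => iter.toArray

    val partitionData = Iterator(1, 2, 3, 4)        // stand-in for rdd.iterator(partition, context)
    val result = collectFunc(FakeContext(0, 0), partitionData)
    println(result.mkString(","))                   // this value is what the ResultTask returns
  }
}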

b. ShuffleMapTask

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        if (writer != null) {
          writer.stop(success = false)
        }
        throw e
    } finally {
      context.markTaskCompleted()
    }
  }
The ShuffleWriter obtained above (HashShuffleWriter under the default hash-based shuffle in 1.1.0) applies any map-side combine and then writes each record to the bucket chosen by the partitioner:

  /** Write a bunch of records to this task's output */
  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    val iter = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        dep.aggregator.get.combineValuesByKey(records, context)
      } else {
        records
      }
    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
      throw new IllegalStateException("Aggregator is empty for map-side combine")
    } else {
      records
    }

    for (elem <- iter) {
      val bucketId = dep.partitioner.getPartition(elem._1)
      shuffle.writers(bucketId).write(elem)
    }
  }
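Each record's key is mapped to a bucket (one per reduce partition) by dep.partitioner.getPartition; with the default HashPartitioner this is a non-negative modulus of the key's hashCode, and null keys go to bucket 0. A minimal sketch of that bucket choice:

object BucketChoice {
  // Non-negative modulus, so negative hashCodes still map into [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  // Same behaviour HashPartitioner.getPartition relies on.
  def getPartition(key: Any, numPartitions: Int): Int =
    if (key == null) 0 else nonNegativeMod(key.hashCode, numPartitions)

  def main(args: Array[String]): Unit = {
    val numBuckets = 4
    Seq("a", "bb", "ccc", null).foreach { k =>
      println(s"key $k -> bucket ${getPartition(k, numBuckets)}")
    }
  }
}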

 

The per-bucket writers used above come from ShuffleBlockManager.forMapTask, which opens one disk writer per reduce bucket, either into a consolidated file group or into a separate file per (shuffle, map, bucket):

  /**
   * Get a ShuffleWriterGroup for the given map task, which will register it as complete
   * when the writers are closed successfully
   */
  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
      writeMetrics: ShuffleWriteMetrics) = {
    new ShuffleWriterGroup {
      shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
      private val shuffleState = shuffleStates(shuffleId)
      private var fileGroup: ShuffleFileGroup = null

      val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
        fileGroup = getUnusedFileGroup()
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
            writeMetrics)
        }
      } else {
        Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
          val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
          val blockFile = blockManager.diskBlockManager.getFile(blockId)
          // Because of previous failures, the shuffle file may already exist on this machine.
          // If so, remove it.
          if (blockFile.exists) {
            if (blockFile.delete()) {
              logInfo(s"Removed existing shuffle file $blockFile")
            } else {
              logWarning(s"Failed to remove existing shuffle file $blockFile")
            }
          }
          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
        }
      }
      // ... (the rest of the ShuffleWriterGroup definition is not shown in this excerpt)
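When consolidateShuffleFiles is off, every (shuffleId, mapId, bucketId) combination gets its own file, named after its ShuffleBlockId; with consolidation, map tasks running on the same executor reuse a shared per-bucket file group instead, which greatly reduces the number of small files. A tiny sketch of the unconsolidated naming scheme (the format below matches ShuffleBlockId.name):

object ShuffleFileNames {
  // Mirrors ShuffleBlockId.name: "shuffle_<shuffleId>_<mapId>_<reduceId>"
  def shuffleBlockName(shuffleId: Int, mapId: Int, reduceId: Int): String =
    s"shuffle_${shuffleId}_${mapId}_$reduceId"

  def main(args: Array[String]): Unit = {
    // One map task (mapId = 3) of shuffle 0 writing to 4 buckets creates 4 block files:
    (0 until 4).foreach(bucketId => println(shuffleBlockName(0, 3, bucketId)))
  }
}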

 
