Spark Source Code Walkthrough (Part 2): Job Submission

import org.apache.spark.{SparkConf, SparkContext}

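object SparkWordCount {
  def main(args: Array[String]): Unit = {
    // Exit early when no arguments are supplied
    if (args.length == 0) {
      System.exit(1)
    }

    val conf = new SparkConf().setAppName("SparkWordCount")
    val sc = new SparkContext(conf)

    // Transformations: these only build up the RDD lineage
    val file = sc.textFile("xxx")
    val counts = file.flatMap(line => line.split(" "))
                     .map(word => (word, 1))
                     .reduceByKey(_ + _)
    // Action: this is what actually triggers a Job
    counts.saveAsTextFile("xxx")

    sc.stop()
  }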
}

Using the code above as an example, we can trace step by step how Spark submits a Job.
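
Before following the submission path, it may help to pin down what the example computes. The sketch below mimics the same flatMap / map / reduceByKey pipeline with plain Scala collections. It does not use Spark at all, and `LocalWordCount` and `count` are hypothetical names introduced only for this illustration; `groupBy` followed by a sum stands in for `reduceByKey`.

```scala
// Plain-Scala analogue of the word count above; no Spark involved.
object LocalWordCount {
  // Hypothetical helper: counts words in a sequence of in-memory lines.
  def count(lines: Seq[String]): Map[String, Int] =
    lines
      .flatMap(line => line.split(" "))  // one line -> many words
      .map(word => (word, 1))            // each word -> (word, 1)
      .groupBy(_._1)                     // group the pairs by word
      .map { case (word, pairs) => (word, pairs.map(_._2).sum) } // sum the 1s per word
}
```

For example, `LocalWordCount.count(Seq("a b a", "b c"))` yields a map with a -> 2, b -> 2, c -> 1. In Spark the same logic runs per partition across executors, which is why a Job has to be submitted at all.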

SparkContext

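The code above first creates a SparkContext object, uses it to create an RDD (file), and then uses that RDD to compute and store the result. SparkContext is therefore the entry point of a Spark application. In fact, SparkContext is responsible for interacting with the entire Spark cluster, and it can create RDDs, accumulators, and broadcast variables. The official documentation illustrates how SparkContext interacts with the other components:
[Figure: SparkContext interacting with the other cluster components, from the official Spark documentation]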

The official documentation also gives a brief description of the figure:

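  1. Spark applications run on a cluster as independent sets of processes, coordinated by the SparkContext object in the main program (the Driver Program).
  2. To run an application on a cluster:
    • SparkContext connects to one of several types of Cluster Manager (YARN, Mesos, standalone), which allocates resources to applications.
    • Once connected, Spark acquires executors on the Worker nodes; these run computations and store data for the application.
    • Next, it sends the application code (defined by the JAR or Python files passed to SparkContext) to the executors. Finally, SparkContext sends tasks to the executors to run.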

This description shows how central SparkContext is. The official documentation also notes several useful points about the figure:

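  1. Each application gets its own executor processes, which stay up for the whole lifetime of the application and run Tasks in multiple threads. The benefit is that applications are isolated from one another, both on the scheduling side (each driver schedules its own tasks) and on the executor side (tasks from different applications run in different JVMs). However, it also means that data cannot be shared across Spark applications (SparkContext instances) without writing it to an external storage system.
  2. Spark is agnostic to the underlying cluster manager. As long as it can acquire executor processes and these processes can communicate with each other, it is relatively easy to run Spark even on a cluster manager that also supports other applications (such as Mesos or YARN).
  3. The driver program must listen for and accept incoming connections from its executors throughout its lifetime. The driver program must therefore be network addressable from the worker nodes.
  4. Because the driver schedules tasks on the cluster, it should run close to the worker nodes, preferably on the same local area network. If requests must be sent to the cluster remotely, it is better to open an RPC to the driver and have it submit operations from nearby than to run the driver far away from the worker nodes.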

Spark Job Submission Flow

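RDD operations are divided into transformations and actions. Transformations are evaluated lazily; computation starts only when an action is encountered. In the code at the beginning of this article, saveAsTextFile is an action, and its code is: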

/**
   * Save this RDD as a text file, using string representations of elements.
   */
  def saveAsTextFile(path: String): Unit = withScope {
    val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
    val textClassTag = implicitly[ClassTag[Text]]
    val r = this.mapPartitions { iter =>
      val text = new Text()
      iter.map { x =>
        text.set(x.toString)
        (NullWritable.get(), text)
      }
    }
    RDD.rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
      .saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
  }
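
The distinction between lazy transformations and an action such as saveAsTextFile can be illustrated with a toy model. The sketch below is not Spark's RDD implementation; `LazyDataset` and `LazyDemo` are hypothetical names. It assumes only that a transformation wraps the upstream computation in a new function, and that an action is what finally invokes it.

```scala
// Toy model of lazy transformations; NOT Spark's RDD implementation.
final class LazyDataset[A](compute: () => Iterator[A]) {
  // Transformation: wraps the upstream computation, nothing runs yet.
  def map[B](f: A => B): LazyDataset[B] =
    new LazyDataset[B](() => compute().map(f))

  // Action: forces the whole chain to be evaluated.
  def collect(): List[A] = compute().toList
}

object LazyDemo {
  // Returns (calls to f before the action, calls after the action, result).
  def run(): (Int, Int, List[Int]) = {
    var calls = 0
    val source = new LazyDataset[Int](() => Iterator(1, 2, 3))
    val doubled = source.map { x => calls += 1; x * 2 }
    val before = calls               // still 0: map has not executed f
    val result = doubled.collect()   // the action triggers evaluation
    (before, calls, result)
  }
}
```

In this model the mapping function is not called until `collect()` runs, which mirrors why the walkthrough starts from the action rather than from flatMap or map.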

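Look directly at the saveAsHadoopFile call at the end of the method and follow it down layer by layer: saveAsHadoopFile -> saveAsHadoopDataset. The self.context.runJob call at the end of saveAsHadoopDataset is where the Job starts executing; context here is the SparkContext object.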

def saveAsHadoopDataset(conf: JobConf): Unit = self.withScope {
  // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
  val hadoopConf = conf
  val outputFormatInstance = hadoopConf.getOutputFormat
  val keyClass = hadoopConf.getOutputKeyClass
  val valueClass = hadoopConf.getOutputValueClass
  if (outputFormatInstance == null) {
    throw new SparkException("Output format class not set")
  }
  if (keyClass == null) {
    throw new SparkException("Output key class not set")
  }
  if (valueClass == null) {
    throw new SparkException("Output value class not set")
  }
  SparkHadoopUtil.get.addCredentials(hadoopConf)

  logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
    valueClass.getSimpleName + ")")

  if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) {
    // FileOutputFormat ignores the filesystem parameter
    val ignoredFs = FileSystem.get(hadoopConf)
    hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
  }

  val writer = new SparkHadoopWriter(hadoopConf)
  writer.preSetup()

  val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
    // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
    // around by taking a mod. We expect that no task will be attempted 2 billion times.
    val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt

    val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context)

    writer.setup(context.stageId, context.partitionId, taskAttemptId)
    writer.open()
    var recordsWritten = 0L

    Utils.tryWithSafeFinallyAndFailureCallbacks {
      while (iter.hasNext) {
        val record = iter.next()
        writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])

        // Update bytes written metric every few records
        SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten)
        recordsWritten += 1
      }
    }(finallyBlock = writer.close())
    writer.commit()
    outputMetrics.setBytesWritten(callback())
    outputMetrics.setRecordsWritten(recordsWritten)
  }

  self.context.runJob(self, writeToFile)
  writer.commitJob()
}
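
The shape of the call `self.context.runJob(self, writeToFile)` is worth isolating: a function of type (TaskContext, Iterator[T]) => U is applied once per partition, and each result is handed to a handler together with the partition index. The sketch below shows only that shape. `MiniRunJob` and `MiniTaskContext` are hypothetical stand-ins, and everything runs sequentially on one thread, whereas Spark ships the function to executors and runs tasks in parallel.

```scala
object MiniRunJob {
  // Hypothetical stand-in for Spark's TaskContext, carrying only a partition id.
  final case class MiniTaskContext(partitionId: Int)

  // Applies func to every partition and passes (partition index, result) to resultHandler.
  def runJob[T, U](
      partitions: IndexedSeq[Seq[T]],
      func: (MiniTaskContext, Iterator[T]) => U,
      resultHandler: (Int, U) => Unit): Unit = {
    partitions.zipWithIndex.foreach { case (data, index) =>
      resultHandler(index, func(MiniTaskContext(index), data.iterator))
    }
  }

  // Counts the records in each partition, similar in spirit to recordsWritten above.
  def demo(): Map[Int, Int] = {
    val results = scala.collection.mutable.Map.empty[Int, Int]
    val countRecords = (ctx: MiniTaskContext, iter: Iterator[String]) => iter.size
    runJob[String, Int](
      Vector(Seq("a", "b"), Seq("c")),
      countRecords,
      (index: Int, count: Int) => results.update(index, count))
    results.toMap
  }
}
```

Here `demo()` returns one count per partition index. In saveAsHadoopDataset the per-partition function writes records instead of counting them and returns Unit.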

Following runJob down layer by layer, we find that it calls DAGScheduler's runJob, whose source is:

def runJob[T, U](
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: CallSite,
    resultHandler: (Int, U) => Unit,
    properties: Properties): Unit = {
  val start = System.nanoTime

  // Submit the job
  val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
  // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
  // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
  // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
  // safe to pass in null here. For more detail, see SPARK-13747.
  val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
  waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
  waiter.completionFuture.value.get match {
    case scala.util.Success(_) =>
      logInfo("Job %d finished: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
    case scala.util.Failure(exception) =>
      logInfo("Job %d failed: %s, took %f s".format
        (waiter.jobId, callSite.shortForm, (System.nanoTime - start) / 1e9))
      // SPARK-8644: Include user stack trace in exceptions coming from DAGScheduler.
      val callerStackTrace = Thread.currentThread().getStackTrace.tail
      exception.setStackTrace(exception.getStackTrace ++ callerStackTrace)
      throw exception
  }
}

The submitJob method called in the source above is implemented as follows:

def submitJob[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      partitions: Seq[Int],
      callSite: CallSite,
      resultHandler: (Int, U) => Unit,
      properties: Properties): JobWaiter[U] = {
    // Check to make sure we are not launching a task on a partition that does not exist.
    val maxPartitions = rdd.partitions.length
    partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
      throw new IllegalArgumentException(
        "Attempting to access a non-existent partition: " + p + ". " +
          "Total number of partitions: " + maxPartitions)
    }

    val jobId = nextJobId.getAndIncrement()
    if (partitions.size == 0) {
      // Return immediately if the job is running 0 tasks
      return new JobWaiter[U](this, jobId, 0, resultHandler)
    }

    assert(partitions.size > 0)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
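    // JobWaiter is an object that waits for a DAGScheduler job to complete.
    // As tasks finish, it passes their results to the given handler function.
    val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
    // The actual type of eventProcessLoop is DAGSchedulerEventProcessLoop. post puts
    // JobSubmitted into eventQueue, where eventThread processes it in the background.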
    eventProcessLoop.post(JobSubmitted(
      jobId, rdd, func2, partitions.toArray, callSite, waiter,
      SerializationUtils.clone(properties)))
    waiter
  }
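
The role of the waiter returned by submitJob can be sketched independently of Spark. The class below is a simplified illustration, not Spark's JobWaiter: `MiniJobWaiter` and its method names are assumptions made for this example. It forwards each task result to the handler and completes a future once the expected number of tasks has finished, which is the future a caller like runJob blocks on.

```scala
import java.util.concurrent.atomic.AtomicInteger
import scala.concurrent.{Future, Promise}

// Simplified sketch of the idea behind a job waiter; not Spark's JobWaiter class.
final class MiniJobWaiter[U](totalTasks: Int, resultHandler: (Int, U) => Unit) {
  private val finishedTasks = new AtomicInteger(0)

  // A job with zero tasks is complete immediately.
  private val promise: Promise[Unit] =
    if (totalTasks == 0) Promise.successful(()) else Promise[Unit]()

  def completionFuture: Future[Unit] = promise.future

  // Called once per finished task: forward the result, then check for completion.
  def taskSucceeded(index: Int, result: U): Unit = {
    resultHandler(index, result)
    if (finishedTasks.incrementAndGet() == totalTasks) {
      promise.trySuccess(())
      ()
    }
  }

  def jobFailed(exception: Exception): Unit = {
    promise.tryFailure(exception)
    ()
  }
}

object MiniJobWaiterDemo {
  // Returns (completed after one task?, completed after two tasks?, handled results).
  def run(): (Boolean, Boolean, List[(Int, String)]) = {
    val handled = scala.collection.mutable.ListBuffer.empty[(Int, String)]
    val waiter = new MiniJobWaiter[String](2, (i: Int, r: String) => { handled += ((i, r)); () })
    waiter.taskSucceeded(0, "a")
    val afterOne = waiter.completionFuture.isCompleted
    waiter.taskSucceeded(1, "b")
    val afterTwo = waiter.completionFuture.isCompleted
    (afterOne, afterTwo, handled.toList)
  }
}
```

The zero-task branch corresponds to the early return in submitJob when partitions.size == 0: there is nothing to wait for, so the future is already complete.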

Here the actual type of eventProcessLoop is DAGSchedulerEventProcessLoop, which extends EventLoop[DAGSchedulerEvent]. The source of EventLoop is:

private[spark] abstract class EventLoop[E](name: String) extends Logging {

  // Event queue
  private val eventQueue: BlockingQueue[E] = new LinkedBlockingDeque[E]()

  private val stopped = new AtomicBoolean(false)

  private val eventThread = new Thread(name) {
    setDaemon(true)

    override def run(): Unit = {
      try {
        while (!stopped.get) {
          // Take an event from the event queue
          val event = eventQueue.take()
          try {
            // Process the event; this calls DAGSchedulerEventProcessLoop's onReceive
            onReceive(event)
          } catch {
            case NonFatal(e) =>
              try {
                onError(e)
              } catch {
                case NonFatal(e) => logError("Unexpected error in " + name, e)
              }
          }
        }
      } catch {
        case ie: InterruptedException => // exit even if eventQueue is not empty
        case NonFatal(e) => logError("Unexpected error in " + name, e)
      }
    }

  }

  def start(): Unit = {
    if (stopped.get) {
      throw new IllegalStateException(name + " has already been stopped")
    }
    // Call onStart before starting the event thread to make sure it happens before onReceive
    onStart()
    eventThread.start()
  }
  // Remaining source omitted
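
The producer/consumer pattern in EventLoop can be tried out in isolation. The following is a minimal sketch in the spirit of the listing above, not Spark's class: `MiniEventLoop` is a hypothetical name, error handling (onError) and onStart are left out, and the `stop` and `post` methods are written here as assumptions because that part of the source is omitted above.

```scala
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, LinkedBlockingDeque, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean

// Minimal event loop sketch; error handling from the original is omitted.
abstract class MiniEventLoop[E](name: String) {
  private val eventQueue = new LinkedBlockingDeque[E]()
  private val stopped = new AtomicBoolean(false)

  private val eventThread = new Thread(name) {
    setDaemon(true)
    override def run(): Unit = {
      try {
        while (!stopped.get) {
          onReceive(eventQueue.take()) // blocks until an event is available
        }
      } catch {
        case _: InterruptedException => // stop() interrupts the blocking take()
      }
    }
  }

  def start(): Unit = eventThread.start()

  def stop(): Unit = {
    if (stopped.compareAndSet(false, true)) eventThread.interrupt()
  }

  // Producers only enqueue; the event is processed later on the event thread.
  def post(event: E): Unit = eventQueue.put(event)

  protected def onReceive(event: E): Unit
}

object MiniEventLoopDemo {
  def run(): List[String] = {
    val received = new ConcurrentLinkedQueue[String]()
    val latch = new CountDownLatch(2)
    val loop = new MiniEventLoop[String]("mini-event-loop") {
      override protected def onReceive(event: String): Unit = {
        received.add(event)
        latch.countDown()
      }
    }
    loop.start()
    loop.post("JobSubmitted")
    loop.post("JobCancelled")
    latch.await(5, TimeUnit.SECONDS) // wait until both events were handled
    loop.stop()
    received.toArray(new Array[String](0)).toList
  }
}
```

Because a single thread drains a FIFO queue, events are handled in the order they were posted, and the thread that calls `post` returns immediately instead of doing the scheduling work itself.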

The onReceive invoked here is DAGSchedulerEventProcessLoop's onReceive, whose source is:

/**
 * The main event loop of the DAG scheduler.
 */
override def onReceive(event: DAGSchedulerEvent): Unit = {
  val timerContext = timer.time()
  try {
    doOnReceive(event)
  } finally {
    timerContext.stop()
  }
}

Now look at the source of doOnReceive:

private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    // JobSubmitted is handled here
    case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) =>
      dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties)

    case MapStageSubmitted(jobId, dependency, callSite, listener, properties) =>
      dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties)

    case StageCancelled(stageId, reason) =>
      dagScheduler.handleStageCancellation(stageId, reason)

    case JobCancelled(jobId, reason) =>
      dagScheduler.handleJobCancellation(jobId, reason)

    case JobGroupCancelled(groupId) =>
      dagScheduler.handleJobGroupCancelled(groupId)

    case AllJobsCancelled =>
      dagScheduler.doCancelAllJobs()

    case ExecutorAdded(execId, host) =>
      dagScheduler.handleExecutorAdded(execId, host)

    case ExecutorLost(execId, reason) =>
      val filesLost = reason match {
        case SlaveLost(_, true) => true
        case _ => false
      }
      dagScheduler.handleExecutorLost(execId, filesLost)

    case BeginEvent(task, taskInfo) =>
      dagScheduler.handleBeginEvent(task, taskInfo)

    case GettingResultEvent(taskInfo) =>
      dagScheduler.handleGetTaskResult(taskInfo)

    case completion: CompletionEvent =>
      dagScheduler.handleTaskCompletion(completion)

    case TaskSetFailed(taskSet, reason, exception) =>
      dagScheduler.handleTaskSetFailed(taskSet, reason, exception)

    case ResubmitFailedStages =>
      dagScheduler.resubmitFailedStages()
  }

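From this we can see that the job is ultimately submitted by calling DAGScheduler's handleJobSubmitted method. After the job is submitted, the next steps are dividing it into Stages and submitting tasks.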
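
The dispatch style used by doOnReceive, a pattern match over an event hierarchy, can be sketched with a few made-up event types. The names below (`MiniSchedulerEvent`, `MiniJobSubmitted`, and so on) are hypothetical and far smaller than Spark's DAGSchedulerEvent family; the handlers simply return a description instead of doing any scheduling.

```scala
// Hypothetical event types for illustration; Spark's DAGSchedulerEvent hierarchy is larger.
sealed trait MiniSchedulerEvent
final case class MiniJobSubmitted(jobId: Int, numPartitions: Int) extends MiniSchedulerEvent
final case class MiniJobCancelled(jobId: Int, reason: Option[String]) extends MiniSchedulerEvent
case object MiniAllJobsCancelled extends MiniSchedulerEvent

object MiniDispatcher {
  // Each case extracts the event's fields and routes them to one handler.
  def handle(event: MiniSchedulerEvent): String = event match {
    case MiniJobSubmitted(jobId, numPartitions) =>
      s"submit job $jobId with $numPartitions partitions"
    case MiniJobCancelled(jobId, reason) =>
      val why = reason.getOrElse("no reason given")
      s"cancel job $jobId: $why"
    case MiniAllJobsCancelled =>
      "cancel all jobs"
  }
}
```

Declaring the trait as sealed lets the compiler warn when a match misses one of the event types, which is helpful when a single method has to route every kind of scheduler event.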
