Spark source code analysis series - wordcount source analysis

Preface

This article uses the source code behind the Spark wordcount example to analyze how a job runs inside Spark.

The Spark programming model

In Spark, an RDD is represented as an object, and transformations are applied to it by calling methods on that object. Once an RDD has been defined through a series of transformations, an action can be invoked to trigger its computation; an action either returns a result to the application or saves data to a storage system. In Spark, an RDD is only computed when an action is encountered (i.e. lazy evaluation).

sc.textFile("input").flatMap(_.split(" ")).map((_,1)).reduceByKey(_+_).collect

Operator: from the perspective of cognitive psychology, solving a problem means taking the problem's initial state and, through a series of transformation operations (operators), turning it into the solved state.

Complete Spark wordcount example code


import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}

object WordCountScala {


  def main(args: Array[String]): Unit = {

    val conf = new SparkConf()
    conf.setAppName("wordcount")
    conf.setMaster("local")  

    val sc = new SparkContext(conf)
    //word count
    //DATASET
//    val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt",16)
    //hello world

//    fileRDD.flatMap(  _.split(" ") ).map((_,1)).reduceByKey(  _+_   ).foreach(println)

    val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt")
    //hello world
    val words: RDD[String] = fileRDD.flatMap((x:String)=>{x.split(" ")})
    //hello
    //world
    val pairWord: RDD[(String, Int)] = words.map((x:String)=>{new Tuple2(x,1)})
    //(hello,1)
    //(hello,1)
    //(world,1)
    val res: RDD[(String, Int)] = pairWord.reduceByKey(  (x:Int,y:Int)=>{x+y}   )
    //X:oldValue  Y:value
    //(hello,2)  -> (2,1)
    //(world,1)   -> (1,1)
    //(msb,2)   -> (2,1)

    val reverseRDD: RDD[(Int, Int)] = res.map((x)=>{  (x._2,1)  })
    val resOver: RDD[(Int, Int)] = reverseRDD.reduceByKey(_+_)

    resOver.foreach(println)
    res.foreach(println)

    Thread.sleep(Long.MaxValue)

  }

}

Overall flow of the wordcount source analysis

(figure: wordcount.png)

textFile() source

In Spark there are three ways to create an RDD: from a collection, from external storage, or from another RDD.

In this example, textFile is used to read the data through the Hadoop input API.

The analysis starts from sc.textFile():

 val fileRDD: RDD[String] = sc.textFile("bigdata-spark/data/testdata.txt")

SparkContext.textFile() is defined as follows:

  /**
   * Read a text file from HDFS, a local file system (available on all nodes), or any
   * Hadoop-supported file system URI, and return it as an RDD of Strings.
   * @param path path to the text file on a supported file system
   * @param minPartitions suggested minimum number of partitions for the resulting RDD
   * @return RDD of lines of the text file
   */
  def textFile(
      path: String,
      minPartitions: Int = defaultMinPartitions): RDD[String] = withScope {
    assertNotStopped()
     
      //input format is TextInputFormat; key type is LongWritable, value type is Text
      //minimum number of partitions defaults to defaultMinPartitions
    hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
      minPartitions).map(pair => pair._2.toString).setName(path)
  }

textFile takes two parameters:

The first parameter is path, the path of the file to read.

The second parameter is minPartitions, the minimum number of partitions. Users can specify minPartitions themselves, but the requested value is not always what they get, because Spark favors the higher degree of parallelism. For example, if the user asks for minPartitions = 10 but the split computation produces 15, the larger value (15) wins; if the user asks for 20 and the computation produces 15, then 20 is used.
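
A hedged example of how the hint behaves (reusing the test file from the example; getNumPartitions reports how many partitions were actually produced):

val rdd = sc.textFile("bigdata-spark/data/testdata.txt", 10) // suggest at least 10 partitions
println(rdd.getNumPartitions)                                // typically prints 10 or more, depending on file size and block size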

hadoopFile() source

Inside textFile we can see the call to hadoopFile:

/** Get an RDD for a Hadoop file with an arbitrary InputFormat
   *
   * @note Because Hadoop's RecordReader class re-uses the same Writable object for each
   * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
   * operation will create many references to the same object.
   * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
   * copy them using a `map` function.
   * @param path directory to the input data files, the path can be comma separated paths
   * as a list of inputs
   * @param inputFormatClass storage format of the data to be read
   * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter
   * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter
   * @param minPartitions suggested minimum number of partitions for the resulting RDD
   * @return RDD of tuples of key and corresponding value
   */
  def hadoopFile[K, V](
      path: String,
      inputFormatClass: Class[_ <: InputFormat[K, V]],
      keyClass: Class[K],
      valueClass: Class[V],
      minPartitions: Int = defaultMinPartitions): RDD[(K, V)] = withScope {
    assertNotStopped()

    // This is a hack to enforce loading hdfs-site.xml.
    // See SPARK-11227 for details.
    FileSystem.getLocal(hadoopConfiguration)

    // A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
    val confBroadcast = broadcast(new SerializableConfiguration(hadoopConfiguration))
    val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
    new HadoopRDD(
      this,
      confBroadcast,
      Some(setInputPathsFunc),
      inputFormatClass,
      keyClass,
      valueClass,
      minPartitions).setName(path)
  }

HadoopRDD source

This method returns a HadoopRDD, parameterized by the input format of the file being read:

class HadoopRDD[K, V](
    sc: SparkContext,
    broadcastedConf: Broadcast[SerializableConfiguration],
    initLocalJobConfFuncOpt: Option[JobConf => Unit],
    inputFormatClass: Class[_ <: InputFormat[K, V]],
    keyClass: Class[K],
    valueClass: Class[V],
    minPartitions: Int)
  extends RDD[(K, V)](sc, Nil) with Logging {
      def this ...
      def getJobConf ...
      def getInputFormat ...
      def getPartitions ...
      def compute ...
      ...
  }

As we can see, HadoopRDD is a subclass of RDD, and it passes the SparkContext together with Nil as arguments to the parent RDD constructor:

abstract class RDD[T: ClassTag](
    @transient private var _sc: SparkContext,
    @transient private var deps: Seq[Dependency[_]]
  ) extends Serializable with Logging {

The second parameter, deps, represents the dependencies. Because HadoopRDD is the head of the Spark lineage, this dependency points to Nil.

An RDD's dependencies work like a singly linked list: the dependency of each downstream RDD contains a reference to the RDD before it.
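
The lineage can be inspected directly from the driver; for example, printing the debug string of the res RDD from the wordcount example shows the chain (output abridged and illustrative, exact ids and formatting depend on the Spark version):

println(res.toDebugString)
// (1) ShuffledRDD[4] at reduceByKey ...
//  +-(1) MapPartitionsRDD[3] at map ...
//     |  MapPartitionsRDD[2] at flatMap ...
//     |  ... HadoopRDD[0] at textFile ...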

In addition, HadoopRDD contains two important methods: getPartitions() and compute().

getPartitions() source

Let's look at getPartitions() first:

 override def getPartitions: Array[Partition] = {
    val jobConf = getJobConf()
    // add the credentials here as this can be called before SparkContext initialized
    SparkHadoopUtil.get.addCredentials(jobConf)
    val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions)
    val inputSplits = if (ignoreEmptySplits) {
      allInputSplits.filter(_.getLength > 0)
    } else {
      allInputSplits
    }
    val array = new Array[Partition](inputSplits.size)
    for (i <- 0 until inputSplits.size) {
      array(i) = new HadoopPartition(id, i, inputSplits(i))
    }
    array
  }

Here we see FileInputFormat's getSplits() method. Anyone who has looked at MapReduce will recognize it, because MapReduce has the same getSplits() method, and stepping into it we find that the implementation logic is almost identical to MapReduce's.

getSplits() source

/** Splits files returned by {@link #listStatus(JobConf)} when
   * they're too big.*/ 
  public InputSplit[] getSplits(JobConf job, int numSplits)
    throws IOException {
    Stopwatch sw = new Stopwatch().start();
    //get all FileStatus objects
    FileStatus[] files = listStatus(job);
    
    // Save the number of input files for metrics/loadgen
    job.setLong(NUM_INPUT_FILES, files.length);
    long totalSize = 0;                           // compute total size
    for (FileStatus file: files) {                // check we have valid files
      if (file.isDirectory()) {
        throw new IOException("Not a file: "+ file.getPath());
      }
      totalSize += file.getLen();
    }
    //compute the target split size (goalSize) and the minimum split size (minSize)
    long goalSize = totalSize / (numSplits == 0 ? 1 : numSplits);
    long minSize = Math.max(job.getLong(org.apache.hadoop.mapreduce.lib.input.
      FileInputFormat.SPLIT_MINSIZE, 1), minSplitSize);

    // generate splits
    ArrayList<FileSplit> splits = new ArrayList<FileSplit>(numSplits);
    NetworkTopology clusterMap = new NetworkTopology();
    for (FileStatus file: files) {
      Path path = file.getPath();
      long length = file.getLen();
      if (length != 0) {
        FileSystem fs = path.getFileSystem(job);
        BlockLocation[] blkLocations;
        if (file instanceof LocatedFileStatus) {
          blkLocations = ((LocatedFileStatus) file).getBlockLocations();
        } else {
          blkLocations = fs.getFileBlockLocations(file, 0, length);
        }
        //check whether the file is splittable
        if (isSplitable(fs, path)) {
          long blockSize = file.getBlockSize();
            
          //if the file is splittable, cut it into splits; compute the split size
          long splitSize = computeSplitSize(goalSize, minSize, blockSize);

          long bytesRemaining = length;
          while (((double) bytesRemaining)/splitSize > SPLIT_SLOP) {
            String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,
                length-bytesRemaining, splitSize, clusterMap);
            splits.add(makeSplit(path, length-bytesRemaining, splitSize,
                splitHosts[0], splitHosts[1]));
            bytesRemaining -= splitSize;
          }

          if (bytesRemaining != 0) {
            String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations, length
                - bytesRemaining, bytesRemaining, clusterMap);
            splits.add(makeSplit(path, length - bytesRemaining, bytesRemaining,
                splitHosts[0], splitHosts[1]));
          }
        } else {
          String[][] splitHosts = getSplitHostsAndCachedHosts(blkLocations,0,length,clusterMap);
          splits.add(makeSplit(path, 0, length, splitHosts[0], splitHosts[1]));
        }
      } else { 
        //Create empty hosts array for zero length files
        splits.add(makeSplit(path, 0, length, new String[0]));
      }
    }
    sw.stop();
    if (LOG.isDebugEnabled()) {
      LOG.debug("Total # of splits generated by getSplits: " + splits.size()
          + ", TimeTaken: " + sw.elapsedMillis());
    }
    return splits.toArray(new FileSplit[splits.size()]);
  }

getSplits builds the splits in a few steps:

  1. listStatus retrieves all FileStatus objects.
  2. Compute the target split size goalSize and the minimum size minSize.
  3. Check whether each file is splittable.
  4. If it is splittable, cut it into splits; the split size is Math.max(minSize, Math.min(goalSize, blockSize)) (see the sketch after this list).
  5. The method finally returns the InputSplits; getPartitions then builds a HadoopPartition from each returned InputSplit.
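
Step 4 can be restated in a few lines of Scala (an illustrative sketch of the sizing rule, not the Hadoop source itself):

def computeSplitSize(goalSize: Long, minSize: Long, blockSize: Long): Long =
  math.max(minSize, math.min(goalSize, blockSize))

// e.g. a 1 GB file with 128 MB blocks and minPartitions = 2:
// goalSize = 1 GB / 2 = 512 MB, so splitSize = max(minSize, min(512 MB, 128 MB)) = 128 MB, giving 8 splits
println(computeSplitSize(512L << 20, 1L, 128L << 20)) // 134217728
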
compute() source

Now let's look at the compute method.

Note that compute is only triggered once an action has been executed.

 override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    //create an iterator that reads records from this split
     val iter = new NextIterator[(K, V)] {

      private val split = theSplit.asInstanceOf[HadoopPartition]
      logInfo("Input split: " + split.inputSplit)
      private val jobConf = getJobConf()

      private val inputMetrics = context.taskMetrics().inputMetrics
      private val existingBytesRead = inputMetrics.bytesRead

      // Sets InputFileBlockHolder for the file block's information
      split.inputSplit.value match {
        case fs: FileSplit =>
          InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength)
        case _ =>
          InputFileBlockHolder.unset()
      }

      // Find a function that will return the FileSystem bytes read by this thread. Do this before
      // creating RecordReader, because RecordReader's constructor might read some bytes
      private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match {
        case _: FileSplit | _: CombineFileSplit =>
          Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback())
        case _ => None
      }

      // We get our input bytes from thread-local Hadoop FileSystem statistics.
      // If we do a coalesce, however, we are likely to compute multiple partitions in the same
      // task and in the same thread, in which case we need to avoid override values written by
      // previous partitions (SPARK-13071).
      private def updateBytesRead(): Unit = {
        getBytesReadCallback.foreach { getBytesRead =>
          inputMetrics.setBytesRead(existingBytesRead + getBytesRead())
        }
      }

      private var reader: RecordReader[K, V] = null
      private val inputFormat = getInputFormat(jobConf)
      HadoopRDD.addLocalConfiguration(
        new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime),
        context.stageId, theSplit.index, context.attemptNumber, jobConf)

      reader =
        try {
          inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
        } catch {
          case e: IOException if ignoreCorruptFiles =>
            logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
            finished = true
            null
        }
      // Register an on-task-completion callback to close the input stream.
      context.addTaskCompletionListener { context =>
        // Update the bytes read before closing is to make sure lingering bytesRead statistics in
        // this thread get correctly added.
        updateBytesRead()
        closeIfNeeded()
      }

      private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
      private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()

      override def getNext(): (K, V) = {
        try {
          finished = !reader.next(key, value)
        } catch {
          case e: IOException if ignoreCorruptFiles =>
            logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e)
            finished = true
        }
        if (!finished) {
          inputMetrics.incRecordsRead(1)
        }
        if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
          updateBytesRead()
        }
        (key, value)
      }

      override def close(): Unit = {
        if (reader != null) {
          InputFileBlockHolder.unset()
          try {
            reader.close()
          } catch {
            case e: Exception =>
              if (!ShutdownHookManager.inShutdown()) {
                logWarning("Exception in RecordReader.close()", e)
              }
          } finally {
            reader = null
          }
          if (getBytesReadCallback.isDefined) {
            updateBytesRead()
          } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
                     split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
            // If we can't get the bytes read from the FS stats, fall back to the split size,
            // which may be inaccurate.
            try {
              inputMetrics.incBytesRead(split.inputSplit.value.getLength)
            } catch {
              case e: java.io.IOException =>
                logWarning("Unable to get input size to set InputMetrics for task", e)
            }
          }
        }
      }
    }
    new InterruptibleIterator[(K, V)](context, iter)
  }

Here we can see that compute() creates a NextIterator[(K, V)]. Let's step into this iterator.

override def hasNext: Boolean = {
    if (!finished) {
      if (!gotNext) {
        nextValue = getNext()
        if (finished) {
          closeIfNeeded()
        }
        gotNext = true
      }
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) {
      throw new NoSuchElementException("End of stream")
    }
    gotNext = false
    nextValue
  }

The iterator provides hasNext() and next(), and these are what actually read records from the file; in other words, compute() fetches the data lazily through this iterator. It is worth noting that, up to this point, getPartitions() and compute() have not actually been executed.

Through textFile() we obtain the first RDD, fileRDD. The next operation is flatMap().

flatMap() source

  /**
   *  Return a new RDD by first applying a function to all elements of this
   *  RDD, and then flattening the results.
   */
  def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = withScope {
    val cleanF = sc.clean(f)
    new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF))
  }

sc.clean(f) cleans the closure of the function f so that it can be serialized and shipped out inside the flatMap.

flatMap() defines a new RDD, a MapPartitionsRDD; when the MapPartitionsRDD is constructed, the anonymous function f is passed in, and ultimately flatMap(cleanF) is executed through an iterator. Let's step into MapPartitionsRDD.
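
Before stepping in, here is what iter.flatMap(cleanF) amounts to at runtime, shown with a plain Scala Iterator standing in for the parent partition's iterator (an illustration, not Spark code):

val lines = Iterator("hello world", "hello spark") // stands in for the parent partition's iterator
val flat  = lines.flatMap(_.split(" "))            // lazy: nothing has been split yet
println(flat.toList)                               // List(hello, world, hello, spark) -- computed only when consumed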

MapPartitionsRDD source

private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false,
    isOrderSensitive: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }

Note that prev here is the previous RDD, i.e. fileRDD: in flatMap, this (fileRDD) was passed in as prev, and MapPartitionsRDD passes it on to the RDD constructor, so MapPartitionsRDD holds a reference to fileRDD. The function f has also been passed in. Let's look at this RDD constructor:

  /** Construct an RDD with just a one-to-one dependency on one parent */
  def this(@transient oneParent: RDD[_]) =
    this(oneParent.context, List(new OneToOneDependency(oneParent)))

MapPartitionsRDD also has a compute() method; let's look at it.

compute() source

override def compute(split: Partition, context: TaskContext): Iterator[U] =
  f(context, split.index, firstParent[T].iterator(split, context))

The function f that was passed in is invoked when compute is called.

The first argument is the TaskContext of the running task.

The second argument is the index of the partition (split) being computed.

The third argument is the iterator() of the previous RDD. But HadoopRDD does not define an iterator() method itself, so we look for it in the parent class RDD.

RDD iterator() source

  /**
   * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
   * This should ''not'' be called by users directly, but is available for implementors of custom
   * subclasses of RDD.
   */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      getOrCompute(split, context)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }

Looking at the logic here: if the RDD has not been cached or persisted, it calls its own compute() method. So MapPartitionsRDD calls fileRDD's iterator(), fileRDD's iterator() calls its own compute(), and that compute() returns an iterator that pulls the data.
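
The pipelining described here can be sketched in a few lines (illustrative classes, not Spark's):

trait MiniRDD[T] { def iterator(): Iterator[T] }

// plays the role of HadoopRDD: its "compute" simply produces records from the source
class Source(data: Seq[String]) extends MiniRDD[String] {
  def iterator(): Iterator[String] = data.iterator
}

// plays the role of MapPartitionsRDD: wraps the parent's iterator with a function
class Mapped[T, U](parent: MiniRDD[T], f: Iterator[T] => Iterator[U]) extends MiniRDD[U] {
  def iterator(): Iterator[U] = f(parent.iterator())
}

val src    = new Source(Seq("hello world", "hello spark"))
val tokens = new Mapped[String, String](src, _.flatMap(_.split(" ")))
println(tokens.iterator().toList) // one pass over the source data drives the whole chain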

That completes the analysis of flatMap; now let's step into map().

map() source

  /**
   * Return a new RDD by applying a function to all elements of this RDD.
   */
  def map[U: ClassTag](f: T => U): RDD[U] = withScope {
    val cleanF = sc.clean(f)
    new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
  }

Just like flatMap(), a new MapPartitionsRDD is defined here; the execution logic of map() is the same as that of flatMap().

Next comes reduceByKey(). reduceByKey triggers a shuffle: the RDDs before it only care about the single record they are currently processing and ignore all others, whereas reduceByKey has to process the whole group of records that share the same key, so records with the same key must be sent to the same reducer. Let's step into reduceByKey().

reduceByKey() source

  /**
   * Merge the values for each key using an associative and commutative reduce function. This will
   * also perform the merging locally on each mapper before sending results to a reducer, similarly
   * to a "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
   * parallelism level.
   */
  def reduceByKey(func: (V, V) => V): RDD[(K, V)] = self.withScope {
    reduceByKey(defaultPartitioner(self), func)
  }

This method calls itself again, this time with two arguments. The first is defaultPartitioner(self), and this defaultPartitioner is a HashPartitioner (again, just like MapReduce). Let's step into reduceByKey(defaultPartitioner(self), func).

 /**
   * Merge the values for each key using an associative and commutative reduce function. This will
   * also perform the merging locally on each mapper before sending results to a reducer, similarly
   * to a "combiner" in MapReduce.
   */
  def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = self.withScope {
    combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner)
  }

Here it calls combineByKeyWithClassTag[V]((v: V) => v, func, func, partitioner), passing four arguments: the first is an anonymous function, the second and third are both the user-defined function, and the fourth is the partitioner.
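
Since the partitioner chosen by defaultPartitioner here is a HashPartitioner, its core behavior is worth spelling out (a sketch; the real class also special-cases null keys):

def hashPartition(key: Any, numPartitions: Int): Int = {
  val raw = key.hashCode % numPartitions
  if (raw < 0) raw + numPartitions else raw // keep the result non-negative
}
println(hashPartition("hello", 4)) // the same key always lands in the same partition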

combineByKeyWithClassTag() source

  def combineByKeyWithClassTag[C](
      createCombiner: V => C,
      mergeValue: (C, V) => C,
      mergeCombiners: (C, C) => C,
      partitioner: Partitioner,
      mapSideCombine: Boolean = true,
      serializer: Serializer = null)(implicit ct: ClassTag[C]): RDD[(K, C)] = self.withScope {
    require(mergeCombiners != null, "mergeCombiners must be defined") // required as of Spark 0.9.0
    if (keyClass.isArray) {
      if (mapSideCombine) {
        throw new SparkException("Cannot use map-side combining with array keys.")
      }
      if (partitioner.isInstanceOf[HashPartitioner]) {
        throw new SparkException("HashPartitioner cannot partition array keys.")
      }
    }
    val aggregator = new Aggregator[K, V, C](
      self.context.clean(createCombiner),
      self.context.clean(mergeValue),
      self.context.clean(mergeCombiners))
    if (self.partitioner == Some(partitioner)) {
      self.mapPartitions(iter => {
        val context = TaskContext.get()
        new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
      }, preservesPartitioning = true)
    } else {
      new ShuffledRDD[K, V, C](self, partitioner)
        .setSerializer(serializer)
        .setAggregator(aggregator)
        .setMapSideCombine(mapSideCombine)
    }
  }

combineByKey() goes through three stages.

The first stage handles the first record of a key, the second stage handles the second and all subsequent records of that key, and the third stage merges the multiple small spill files written out earlier.

That is why combineByKey() has to be given three functions.
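
For the wordcount's reduceByKey(_ + _), the three functions passed to combineByKeyWithClassTag boil down to the following (a restatement of the call above with func = _ + _):

val createCombiner = (v: Int) => v                 // stage 1: the first value seen for a key becomes the combiner
val mergeValue     = (c: Int, v: Int) => c + v     // stage 2: fold each later value of the same key into the combiner
val mergeCombiners = (c1: Int, c2: Int) => c1 + c2 // stage 3: merge combiners coming from different map outputs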

The method wraps the three functions into an Aggregator, and at the end it news up a ShuffledRDD and calls setSerializer(), setAggregator(), and setMapSideCombine() on it: the first sets the serializer, the second sets the aggregation (i.e. the three functions passed in), and the third enables map-side combining.

Let's look at ShuffledRDD first.

ShuffledRDD source

class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
    @transient var prev: RDD[_ <: Product2[K, V]],
    part: Partitioner)
  extends RDD[(K, C)](prev.context, Nil) {

Note that prev here carries the @transient annotation, which means the field is skipped during serialization. As a result, when the ShuffledRDD is serialized it does not drag the RDDs before it along, and downstream RDDs no longer hold references to the RDDs on the other side of the shuffle; the ShuffledRDD therefore pulls data from the output of the previous stage instead of recomputing from the source through the iterator chain.
Also look at the arguments passed when the ShuffledRDD is constructed: the deps argument is Nil, i.e. the dependency list handed to the parent constructor is empty and no longer points at the previous RDD.
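
This cut in the lineage is easy to observe from the driver (a hedged example reusing the words and res RDDs from the wordcount code):

println(words.dependencies.head) // a OneToOneDependency: narrow dependency, the parent is reachable directly
println(res.dependencies.head)   // a ShuffleDependency: wide dependency across the shuffle boundary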

Instead, ShuffledRDD has a getDependencies() method that supplies its dependency.

getDependencies() source

override def getDependencies: Seq[Dependency[_]] = {
    val serializer = userSpecifiedSerializer.getOrElse {
      val serializerManager = SparkEnv.get.serializerManager
      if (mapSideCombine) {
        serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
      } else {
        serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
      }
    }
    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
  }

This method ultimately returns a List containing a new ShuffleDependency. The first argument is the previous RDD, the second is the partitioner, the third is the serializer, the fourth indicates whether keys are ordered, the fifth is the aggregator (which holds the three aggregation functions from combineByKey()), and the sixth indicates whether to combine on the map side. Together these parameters make up a ShuffleDependency.

ShuffledRDD also has a compute() method.

compute() source

  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

This compute() is likewise called by RDD's iterator() method. So what happens when it is invoked? It takes the ShuffleDependency from its dependencies, obtains the ShuffleManager from SparkEnv, and gets back a reader whose read() method returns an Iterator. Notice that it does not call the parent's iterator: the stage before the shuffle is an independent computation that writes its results out to files, and the ShuffledRDD simply pulls the output of that previous stage from those files instead of running it again.

To make Spark's execution easier to observe, another map() is added on the reduce-task side; this map works exactly like the earlier one, except that it fetches its data from the ShuffledRDD.

foreach(println) is executed as the action operator; let's step into foreach().

  /**
   * Applies a function f to all elements of this RDD.
   */
  def foreach(f: T => Unit): Unit = withScope {
    val cleanF = sc.clean(f)
    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
  }

foreach() in turn calls SparkContext.runJob(), which submits the job for execution.
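
Every action goes down this same path: each call ends up in sc.runJob and submits a job of its own (a small illustration reusing res from the example):

res.foreach(println)    // job 1: println runs on the executors
val arr = res.collect() // job 2: the results are shipped back to the driver
val n   = res.count()   // job 3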
