[Spark 27] Source Code Walkthrough of the Spark Shuffle Write Path

This post uses WordCount, the simplest case that involves a shuffle, to walk through how Spark reads and writes shuffle data.

The WordCount code:

 

package spark.examples

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext

import org.apache.spark.SparkContext._

object SparkWordCount {
  def main(args: Array[String]) {
    System.setProperty("hadoop.home.dir", "E:\\devsoftware\\hadoop-2.5.2\\hadoop-2.5.2");
    val conf = new SparkConf()
    conf.setAppName("SparkWordCount")
    conf.setMaster("local")
    val sc = new SparkContext(conf)
    val rdd = sc.textFile("file:///D:/word.in")
    println(rdd.toDebugString)
    val rdd1 = rdd.flatMap(_.split(" "))
    println("rdd1:" + rdd1.toDebugString)
    val rdd2 = rdd1.map((_, 1))
    println("rdd2:" + rdd2.toDebugString)
    val rdd3 = rdd2.reduceByKey(_ + _);
    println("rdd3:" + rdd3.toDebugString)
    rdd3.saveAsTextFile("file:///D:/wordout" + System.currentTimeMillis());
    sc.stop
  }
}

 

 For the code above, Spark internally creates six RDDs, two stages, two tasks, and one job.

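The first four RDDs form one stage and the last two form the second. The first stage runs as a ShuffleMapTask, which resembles Hadoop's map phase: it splits the input into words and pairs each word with a count of 1. The second stage runs as a ResultTask, which resembles Hadoop's reduce phase: it fetches the shuffled map output, aggregates the counts, and writes the result to local disk (in this example).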
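The split into a map-like stage and a reduce-like stage can be sketched with plain Scala collections. This is only an illustration, not Spark code: the object and method names are hypothetical, and routing words by `hashCode` modulo the reducer count is an assumption made for the sketch.

```scala
object WordCountStagesSketch {
  // "Stage 0": split lines into words, emit (word, 1), and bucket the pairs by reducer.
  def mapSide(lines: Seq[String], numReducers: Int): Map[Int, Seq[(String, Int)]] =
    lines
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .groupBy { case (word, _) => Math.floorMod(word.hashCode, numReducers) }

  // "Stage 1": each reducer sums the counts of the words routed to it.
  def reduceSide(bucket: Seq[(String, Int)]): Map[String, Int] =
    bucket.groupBy(_._1).map { case (word, pairs) => (word, pairs.map(_._2).sum) }

  // Run both stages one after the other and merge the reducers' results.
  def run(lines: Seq[String], numReducers: Int): Map[String, Int] =
    mapSide(lines, numReducers).values.flatMap(reduceSide).toMap
}
```

Calling `WordCountStagesSketch.run(Seq("a b a", "b c"), 2)` counts each word once per occurrence, regardless of which reducer bucket it lands in.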

 

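Why does ShuffledRDD not read from the last RDD of Stage 0 (the MappedRDD) the way other RDDs read from their parents? When RDD B has a narrow dependency on RDD A, B's compute method pulls its input from A's iterator inside the same task. ShuffledRDD cannot do this. Different stages run in different tasks, and the tasks of successive stages run one after another, so by the time Stage 1 starts, the Stage 0 tasks have finished and the RDD objects they iterated over are gone. ShuffledRDD therefore has to read from outside the RDD chain, namely from the shuffle output that Stage 0 wrote. Its link to the parent is recorded as a ShuffleDependency rather than as a pipelined, narrow dependency.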

In other words, the first RDD of each stage never pulls its input from another RDD's iterator.

 
 

 

 

 

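ShuffledRDD is the RDD that begins the stage after a shuffle. The data it depends on is scattered across multiple mapper nodes, so ShuffledRDD has to fetch it, which makes ShuffledRDD the RDD that actually pulls shuffle data from the mappers. In other words, ShuffledRDD depends on data from many nodes: it has a wide dependency on its parent.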

 

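This raises two questions. Where is the output of the first stage stored, and where does the reduce side of the second stage read it from? And do the two stages overlap, that is, can Stage 1 start consuming data as soon as Stage 0 has produced part of it?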

 

 

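DAGScheduler splits the RDD graph into two stages. Stage 0 runs as ShuffleMapTasks (the map side, responsible for the shuffle write). Stage 1 runs as ResultTasks (which first read the shuffle data, then aggregate it, then write the final output).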

 

Main flow of ShuffleMapTask.runTask

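1. Deserialize the last RDD of Stage 0 (here a MappedRDD) together with the ShuffleDependency instance, using the statement shown below. The input is taskBinary.value, the binary form of the task, which the DAGScheduler on the driver produces by serializing the (rdd, dependency) pair when it submits the stage.

Note that taskBinary is a broadcast variable. ShuffleMapTask declares it as taskBinary: Broadcast[Array[Byte]].

a. rdd has a member f, the function this RDD carries. For the MappedRDD in WordCount, f is the (_, 1) function from val rdd2 = rdd1.map((_, 1)). The _ + _ function from reduceByKey is not part of this RDD; it travels with the ShuffleDependency as its aggregator.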

 

 

    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
b. The dependencies_ member of rdd holds this RDD's dependency graph, which is serialized along with it.
c. dep is a ShuffleDependency instance. Its class documentation reads:
/**
 * :: DeveloperApi ::
 * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
 * the RDD is transient since we don't need it on the executor side.
 *
 * @param _rdd the parent RDD
 * @param partitioner partitioner used to partition the shuffle output
 * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
 *                   the default serializer, as specified by `spark.serializer` config option, will
 *                   be used.
 */

 Because its rdd field is marked @transient, the rdd inside the deserialized dep is null.
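The effect of @transient can be reproduced with plain Java serialization. The sketch below is a stand-in, not Spark code: `Dep` is a hypothetical class in which `parent` plays the role of the dependency's rdd field, and `roundTrip` is a helper written for this example.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object TransientFieldSketch {
  // Hypothetical stand-in for ShuffleDependency: `parent` plays the role of the rdd field.
  class Dep(@transient val parent: String, val shuffleId: Int) extends Serializable

  // Serialize a value to bytes and read it back, as a task would after being shipped.
  def roundTrip[T <: AnyRef](value: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(value)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject().asInstanceOf[T] finally in.close()
  }
}
```

After `TransientFieldSketch.roundTrip(new TransientFieldSketch.Dep("parent-rdd", 7))`, the copy keeps its shuffleId but its transient field comes back as null, which matches what the debugger shows for dep.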

Open question: ShuffleMapTask carries a ShuffleDependency. Does ResultTask carry one as well? Put differently, which of the two stages is the shuffle stage?

2. SparkEnv.get.shuffleManager returns the SortShuffleManager instance. SortShuffleManager holds an IndexShuffleBlockManager instance, which in turn holds an org.apache.spark.storage.BlockManager instance.

Both the IndexShuffleBlockManager and the org.apache.spark.storage.BlockManager are reached through SparkEnv, so each can be treated as a single shared instance within the process.

 

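3. SortShuffleManager hands out a SortShuffleWriter instance, which contains:

  • an IndexShuffleBlockManager instance (the same one held by SortShuffleManager),
  • an org.apache.spark.storage.BlockManager instance,
  • a MapStatus instance, which is what writer.stop(success = true) returns at the end of runTask.

The writer is obtained with the statement below, which requires a shuffleHandle. dep.shuffleHandle is a val on ShuffleDependency; it holds the handle returned when the shuffle is registered with the shuffleManager.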

 

Obtaining the writer instance:

  writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) // partitionId identifies one partition of the current RDD, so each write operates on a single partition

 

 Definition of dep.shuffleHandle:

  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(shuffleId, _rdd.partitions.size, this)

  The shuffleManager.registerShuffle method:

  /**
   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
   */
  override def registerShuffle[K, V, C](
      shuffleId: Int,
      numMaps: Int,
      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
    new BaseShuffleHandle(shuffleId, numMaps, dependency)
  }

 

 

 

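4. Call SortShuffleWriter.write to send the results to the shuffle output. What is the input to write? It is the Iterator returned by the iterator method of the last RDD in Stage 0 (a MappedRDD), so execution moves into that RDD's iterator method. That call in turn invokes the iterator method of every RDD in the stage. The mechanism works as follows: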

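Suppose the lineage is A<-B<-C<-D.

D, C, and B each call their parent's iterator method to obtain data, apply the function they carry, and pass the result on to the next RDD.

Take C as an example.

D calls C's iterator method to get C's data. To serve that call, C calls B's iterator method, applies its own function to what B returns, and hands the transformed data to D. This is the essence of RDD pipelining.

A has no parent RDD. A is a HadoopRDD, and its input comes not from another RDD but from an InputSplit of a Hadoop file system. (InputSplit is an abstraction; in Spark, local files are also wrapped in a HadoopRDD.)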

For example, here is the compute method of FlatMappedRDD:

  override def compute(split: Partition, context: TaskContext) =
    firstParent[T].iterator(split, context).flatMap(f)
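The pipelining described above can be imitated without Spark. In the sketch below every class name is hypothetical; `MiniRDD` only mimics the idea that each node builds its iterator lazily from its parent's iterator, and `SourceRDD` stands in for HadoopRDD by reading from external data instead of a parent.

```scala
object IteratorPipelineSketch {
  // Hypothetical miniature of the RDD contract: each node computes lazily from its parent.
  trait MiniRDD[T] { def iterator(): Iterator[T] }

  // Plays the role of HadoopRDD: its input is external data, not another RDD.
  class SourceRDD(lines: Seq[String]) extends MiniRDD[String] {
    def iterator(): Iterator[String] = lines.iterator
  }

  // Asks the parent for its iterator, then applies flatMap on top of it.
  class FlatMappedMiniRDD[T, U](parent: MiniRDD[T], f: T => Iterable[U]) extends MiniRDD[U] {
    def iterator(): Iterator[U] = parent.iterator().flatMap(f)
  }

  // Asks the parent for its iterator, then applies map on top of it.
  class MappedMiniRDD[T, U](parent: MiniRDD[T], f: T => U) extends MiniRDD[U] {
    def iterator(): Iterator[U] = parent.iterator().map(f)
  }

  // The Stage 0 chain of WordCount: source -> flatMap(split) -> map((_, 1)).
  def wordPairs(lines: Seq[String]): MiniRDD[(String, Int)] = {
    val words = new FlatMappedMiniRDD[String, String](new SourceRDD(lines), _.split(" ").toSeq)
    new MappedMiniRDD[String, (String, Int)](words, (_, 1))
  }
}
```

Nothing is materialized between the nodes: a single call to `wordPairs(lines).iterator()` on the last node drives the whole chain one record at a time.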

 

 

5. The code of SortShuffleWriter.write:

 

override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    if (dep.mapSideCombine) { // mapSideCombine is a constructor parameter of ShuffleDependency
      if (!dep.aggregator.isDefined) {
        throw new IllegalStateException("Aggregator is empty for map-side combine")
      }
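      // ExternalSorter is the key class. Its constructor takes an aggregator, a partitioner, a keyOrdering and a serializer.
      // The aggregator carries the combine functions (createCombiner, mergeValue) that insertAll applies.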
      sorter = new ExternalSorter[K, V, C]( 
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)

      // With the parent RDD's records in hand, pass them to sorter.insertAll (examined below). The sorter obtains its blockManager directly from SparkEnv.
      sorter.insertAll(records)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      sorter = new ExternalSorter[K, V, V](
        None, Some(dep.partitioner), None, dep.serializer)
      sorter.insertAll(records)
    }

    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    
    // Write the index file for this (shuffleId, mapId), recording the length of each partition
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

    // Record which block manager holds this map output and how large each partition is; ResultTask relies on this to locate the blocks
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) 
  }
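The partitionLengths array returned by writePartitionedFile is what makes a single data file usable by many reducers. A minimal sketch, assuming the index stores cumulative offsets that start at 0 (the helper names `offsets` and `rangeFor` are hypothetical, not Spark APIs):

```scala
object IndexOffsetsSketch {
  // Turn per-partition byte lengths into cumulative offsets (numPartitions + 1 entries).
  def offsets(partitionLengths: Array[Long]): Array[Long] =
    partitionLengths.scanLeft(0L)(_ + _)

  // The byte range [start, end) that a reducer would read for its partition.
  def rangeFor(reduceId: Int, offs: Array[Long]): (Long, Long) =
    (offs(reduceId), offs(reduceId + 1))
}
```

With lengths 10, 0 and 25, reducer 2 reads bytes 10 to 35 of the data file, and reducer 1 gets an empty range.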

 

 

 

6. The ExternalSorter class

private[spark] class ExternalSorter[K, V, C](
    aggregator: Option[Aggregator[K, V, C]] = None, // carries the combine functions used in insertAll
    partitioner: Option[Partitioner] = None,
    ordering: Option[Ordering[K]] = None,
    serializer: Option[Serializer] = None)
  extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {

  private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) // number of partitions, taken from the partitioner
  private val shouldPartition = numPartitions > 1 // partition only when there is more than one partition

  private val blockManager = SparkEnv.get.blockManager  // obtained from SparkEnv
  private val diskBlockManager = blockManager.diskBlockManager // obtained from blockManager
  private val ser = Serializer.getSerializer(serializer)
  private val serInstance = ser.newInstance()

  private val conf = SparkEnv.get.conf
  private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true) // whether spilling to disk is enabled
  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 // buffer size per file, 32 KB by default
  private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true) // whether spark.file.transferTo is enabled; what does this setting control?

  // Size of object batches when reading/writing from serializers.
  //
  // Objects are written in batches, with each batch using its own serialization stream. This
  // cuts down on the size of reference-tracking maps constructed when deserializing a stream.
  //
  // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
  // grow internal data structures by growing + copying every time the number of objects doubles.
  private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)

 

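ExternalSorter is a very important class. For reduceByKey, every map-side task combines records before writing to disk: values that share a key within the task are merged. The combine function here is _ + _, that is, the values for the same key are added together.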
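A map-side combine of this kind can be sketched in a few lines. The code is an illustration only: `combine` is a hypothetical helper, it uses an ordinary mutable map rather than Spark's own collections, and it ignores partitioning and spilling.

```scala
import scala.collection.mutable

object MapSideCombineSketch {
  // Combine records that share a key before anything is written out.
  def combine[K, V, C](
      records: Iterator[(K, V)],
      createCombiner: V => C,
      mergeValue: (C, V) => C): Map[K, C] = {
    val acc = mutable.LinkedHashMap.empty[K, C]
    records.foreach { case (k, v) =>
      acc(k) = acc.get(k) match {
        case Some(c) => mergeValue(c, v) // key seen before: merge into the running value
        case None    => createCombiner(v) // first occurrence: start a combiner
      }
    }
    acc.toMap
  }
}
```

For WordCount the combiner starts as the value itself and merging is addition, so three records ("a", 1), ("b", 1), ("a", 1) shrink to two before the write.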

 

7. Stage 0's disk write happens in SortShuffleWriter.write:

 

  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    if (dep.mapSideCombine) {
      if (!dep.aggregator.isDefined) {
        throw new IllegalStateException("Aggregator is empty for map-side combine")
      }
      sorter = new ExternalSorter[K, V, C](
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
      // The important call is sorter.insertAll
      // records holds the output of the last RDD in Stage 0, computed by applying its function to its parent RDD's data
      sorter.insertAll(records) 
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      sorter = new ExternalSorter[K, V, V](
        None, Some(dep.partitioner), None, dep.serializer)
      sorter.insertAll(records)
    }

    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)

    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  }

 

8. The ExternalSorter.insertAll method

 

def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined // true here, because reduceByKey supplies an aggregator

    if (shouldCombine) { // this branch runs for WordCount
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue // mergeValue is the _ + _ function from val rdd3 = rdd2.reduceByKey(_ + _)

      // createCombiner is the (v: V) => v function passed by the following PairRDDFunctions method
      /*
       def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
         combineByKey[V]((v: V) => v, func, func, partitioner)
       }
      */
      val createCombiner = aggregator.get.createCombiner
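      var kv: Product2[K, V] = null // kv holds the (K, V) pair taken from records on each iteration
      
      // update is the key function. It takes two arguments: hadValue, whether the map already holds a value for the key, and oldValue, the existing value (null if there is none yet)
      val update = (hadValue: Boolean, oldValue: C) => {
        // If the key already exists, merge by calling mergeValue; otherwise call createCombiner
        // createCombiner is (v: V) => v, so it returns the value unchanged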
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
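        addElementsRead() // counts records read, incremented once per iteration
        kv = records.next() // read the current record
        map.changeValue((getPartition(kv._1), kv._1), update) // the key step: insert or merge under (partition, key)
        maybeSpillCollection(usingMap = true) // spill to disk if needed, depending on how much data is held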
      }
    } else if (bypassMergeSort) {
      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
      if (records.hasNext) {
        spillToPartitionFiles(records.map { kv =>
          ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        })
      }
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }

 

9. ExternalSorter.insertAll calls map.changeValue((getPartition(kv._1), kv._1), update)

First, the pieces involved:

  • map is a SizeTrackingAppendOnlyMap. The source comment on it reads:
  // Data structures to store in-memory objects before we spill. Depending on whether we have an
  // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
  // store them in an array buffer.
  • kv is the current record from records
  • update is the function created in ExternalSorter.insertAll
  • getPartition(kv._1) maps the key to a partition. Its implementation is:
  private def getPartition(key: K): Int = {
    if (shouldPartition) partitioner.get.getPartition(key) else 0 // the partitioner here is a HashPartitioner; without partitioning, return 0, meaning there is only one partition
  }
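How a hash-based partitioner might map a key to a partition index can be sketched as follows. This is a simplified stand-in, not the Spark class: it assumes the partition is the key's hashCode taken modulo the partition count and adjusted to be non-negative, with null keys sent to partition 0.

```scala
object HashPartitionSketch {
  // Like %, but the result is never negative for a positive modulus.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val raw = x % mod
    if (raw < 0) raw + mod else raw
  }

  // Mirrors the shape of getPartition above: only partition when there is more than one.
  def getPartition(key: Any, numPartitions: Int): Int =
    if (numPartitions > 1) {
      if (key == null) 0 else nonNegativeMod(key.hashCode, numPartitions)
    } else 0
}
```

The non-negative adjustment matters because hashCode can be negative, and a negative array index would be invalid when the result is used to pick a partition writer.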

 

 

10. The implementation of SizeTrackingAppendOnlyMap.changeValue

 

  override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    val newValue = super.changeValue(key, updateFunc)
    super.afterUpdate()
    newValue
  }

SizeTrackingAppendOnlyMap.changeValue above delegates to changeValue in its parent class, AppendOnlyMap:

 

def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    assert(!destroyed, destructionMessage)
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) { // a null key cannot be stored in the table, so its value is kept in a dedicated field
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue) // nullValue is a var field defined in AppendOnlyMap
      haveNullValue = true
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask // rehash the key to find its slot in the SizeTrackingAppendOnlyMap
    var i = 1
    while (true) {
      val curKey = data(2 * pos) // data is the backing array of AppendOnlyMap; it is twice the capacity because data(2 * pos) holds the key and data(2 * pos + 1) holds its value
      if (k.eq(curKey) || k.equals(curKey)) { // the key is already in the map, so combine
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) // apply _ + _ to the cached value; updateFunc is the update function created in ExternalSorter.insertAll
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] // write the new value back to data(2 * pos + 1)
        return newValue
      } else if (curKey.eq(null)) { // the slot at data(2 * pos) is empty
        val newValue = updateFunc(false, null.asInstanceOf[V]) // hadValue is false, so update calls createCombiner on the record's value
        data(2 * pos) = k // key
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] // value
        incrementSize() // a new key/value pair was added, so check whether the table must grow
        return newValue
      } else { // the slot holds a different key (a collision)
        val delta = i // probe step; it grows by 1 on each collision
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }
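The same open-addressing idea fits in a small standalone class. The sketch below is a simplification written for this post, not the Spark implementation: `TinyAppendOnlyMap` is a hypothetical name, the table has a fixed power-of-two capacity and never grows, null keys are rejected, and the key's hashCode is used without rehashing.

```scala
object AppendOnlyMapSketch {
  // Simplified: fixed power-of-two capacity, no growth, no null keys, no rehash.
  class TinyAppendOnlyMap[K <: AnyRef, V](capacity: Int = 64) {
    require(capacity > 0 && (capacity & (capacity - 1)) == 0, "capacity must be a power of 2")
    private val mask = capacity - 1
    // Keys live at even indices, their values at the following odd index.
    private val data = new Array[AnyRef](2 * capacity)
    private var curSize = 0

    def size: Int = curSize

    def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
      require(key != null, "null keys are not handled in this sketch")
      var pos = key.hashCode & mask
      var step = 1
      while (true) {
        val curKey = data(2 * pos)
        if (curKey == null) {
          // Empty slot: the key is new, so ask updateFunc for its initial value.
          require(curSize < capacity - 1, "sketch does not grow; table is full")
          val v = updateFunc(false, null.asInstanceOf[V])
          data(2 * pos) = key
          data(2 * pos + 1) = v.asInstanceOf[AnyRef]
          curSize += 1
          return v
        } else if (curKey.eq(key) || curKey.equals(key)) {
          // Key already present: merge into the stored value.
          val v = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
          data(2 * pos + 1) = v.asInstanceOf[AnyRef]
          return v
        } else {
          // Collision with a different key: probe with a step that grows by 1.
          pos = (pos + step) & mask
          step += 1
        }
      }
      throw new IllegalStateException("unreachable")
    }

    // Snapshot of the contents, for inspection.
    def toMap: Map[K, V] =
      (0 until capacity).collect {
        case i if data(2 * i) != null =>
          data(2 * i).asInstanceOf[K] -> data(2 * i + 1).asInstanceOf[V]
      }.toMap
  }
}
```

Feeding the words a, b, a, c, a through `changeValue` with an update function that returns 1 for a new key and old + 1 otherwise leaves three entries, with a counted three times.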

 

 

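 11. ExternalSorter.insertAll calls maybeSpillCollection(usingMap = true)

According to its documentation, this method spills the current in-memory collection to disk if necessary.

The usingMap parameter says whether the in-memory collection in use is the map or the buffer. The map is a SizeTrackingAppendOnlyMap and the buffer is a SizeTrackingPairBuffer. Both are in-memory collections defined in ExternalSorter.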

 

 

  /**
   * Spill the current in-memory collection to disk if needed.
   *
   * @param usingMap whether we're using a map or buffer as our current in-memory collection
   */
  private def maybeSpillCollection(usingMap: Boolean): Unit = {
    if (!spillingEnabled) { // enabled by default; disabling it risks OOM
      return
    }

    if (usingMap) { // after a spill the map is replaced with a fresh one; does the spilled data get combined again at the disk level?
      if (maybeSpill(map, map.estimateSize())) {
        map = new SizeTrackingAppendOnlyMap[(Int, K), C]
      }
    } else {
      if (maybeSpill(buffer, buffer.estimateSize())) {
        buffer = new SizeTrackingPairBuffer[(Int, K), C]
      }
    }
  }

 

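 11.1 The maybeSpill method

ExternalSorter calls maybeSpill, but the method is defined in the Spillable trait that ExternalSorter mixes in, so the call lands in Spillable.maybeSpill.

The method documentation says it tries to acquire more memory before spilling. As the code below shows, when enough memory is granted the threshold rises above current usage and no spill takes place.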

 

  /**
   * Spills the current in-memory collection to disk if needed. Attempts to acquire more
   * memory before spilling.
   *
   * @param collection collection to spill to disk
   * @param currentMemory estimated size of the collection in bytes
   * @return true if `collection` was spilled to disk; false otherwise
   */
  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    // Decision logic:
    // 1. If the outer condition is false, do not spill.
    // 2. If it is true, request more memory; spill only if the threshold is still no larger than current usage.

    // elementsRead is the number of elements read so far. A spill is considered only when it is a multiple of 32 (elementsRead % 32 == 0).
    // trackMemoryThreshold is fixed at 1000:
    //   // Threshold for `elementsRead` before we start tracking this collection's memory usage
    //   private[this] val trackMemoryThreshold = 1000

    // myMemoryThreshold starts at initialMemoryThreshold, 5 MB by default:
    //   // Initial threshold for the size of a collection before we start tracking its memory usage
    //   // Exposed for testing
    //   private[this] val initialMemoryThreshold: Long =
    //     SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)

    // currentMemory is an estimate of the memory in use, computed by map.estimateSize()

    // amountToRequest: twice the current usage minus the current threshold
   if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
        currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
      // Add the granted memory to the threshold, raising the amount this collection may hold
      myMemoryThreshold += granted
      // If the threshold is still no larger than current usage, a spill is required
      if (myMemoryThreshold <= currentMemory) {
        // We were granted too little memory to grow further (either tryToAcquire returned 0,
        // or we already had more memory than myMemoryThreshold); spill the current collection
        _spillCount += 1
        logSpillage(currentMemory)
        // perform the spill
        spill(collection)

        _elementsRead = 0 // reset the read counter
        // Keep track of spills, and release memory
        _memoryBytesSpilled += currentMemory // running total of bytes spilled, declared as: private[this] var _memoryBytesSpilled = 0L
        releaseMemoryForThisThread() // the data is on disk now, so release the acquired memory and reset the threshold to its initial value
        return true
      }
    }
    false // the condition above did not hold, so no spill is needed
  }
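The grow-or-spill decision can be exercised in isolation. The sketch below is a simplified model under stated assumptions: `MemoryPool` is a hypothetical stand-in for the shared shuffle memory pool, `SpillTracker` keeps only the threshold arithmetic, and the elementsRead sampling (the 1000-element floor and the check every 32 elements) is left out.

```scala
object SpillDecisionSketch {
  // Hypothetical shared pool standing in for the shuffle memory manager.
  class MemoryPool(var free: Long) {
    // Grants as much of the request as is still free, possibly 0.
    def tryToAcquire(requested: Long): Long = {
      val granted = math.max(0L, math.min(requested, free))
      free -= granted
      granted
    }
    def release(bytes: Long): Unit = free += bytes
  }

  class SpillTracker(pool: MemoryPool, initialThreshold: Long) {
    private var threshold = initialThreshold
    def currentThreshold: Long = threshold

    // Returns true when the caller should spill its in-memory collection.
    def shouldSpill(currentMemory: Long): Boolean = {
      if (currentMemory < threshold) return false
      // Try to grow to double the current usage before giving up.
      threshold += pool.tryToAcquire(2 * currentMemory - threshold)
      if (threshold <= currentMemory) {
        // Not enough was granted: spill, hand back what was acquired, reset the threshold.
        pool.release(threshold - initialThreshold)
        threshold = initialThreshold
        true
      } else false
    }
  }
}
```

With a pool of 10 units and an initial threshold of 5, usage can climb past 5 and 12 without spilling because the pool keeps granting memory; only when the pool is exhausted does `shouldSpill` return true and give the borrowed 10 units back.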

 

11.2 logSpillage: logging the spill

 

  /**
   * Prints a standard log message detailing spillage.
   *
   * @param size number of bytes spilled
   */
  @inline private def logSpillage(size: Long) {
    val threadId = Thread.currentThread().getId
    logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)"
      .format(threadId, org.apache.spark.util.Utils.bytesToString(size),
        _spillCount, if (_spillCount > 1) "s" else ""))
  }

 

11.3 spill(collection)

 

  /**
   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
   */
  override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
    if (bypassMergeSort) { // this flag decides how to spill
      spillToPartitionFiles(collection)
    } else {
      spillToMergeableFile(collection)
    }
  }

 

spillToPartitionFiles:

  private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
    assert(bypassMergeSort)

    // Create our file writers if we haven't done so yet
    if (partitionWriters == null) {
      curWriteMetrics = new ShuffleWriteMetrics()
      partitionWriters = Array.fill(numPartitions) {
        // Because these files may be read during shuffle, their compression must be controlled by
        // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
        // createTempShuffleBlock here; see SPARK-3426 for more context.
        val (blockId, file) = diskBlockManager.createTempShuffleBlock()
        blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
      }
    }

    // No need to sort stuff, just write each element out
    while (iterator.hasNext) {
      val elem = iterator.next()
      val partitionId = elem._1._1
      val key = elem._1._2
      val value = elem._2
      partitionWriters(partitionId).write((key, value))
    }
  }

 

spillToMergeableFile

/**
   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
   * We add this file into spilledFiles to find it later.
   *
   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
   * See spillToPartitionedFiles() for that code path.
   *
   * @param collection whichever collection we're using (map or buffer)
   */
  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
    assert(!bypassMergeSort)

    // Because these files may be read during shuffle, their compression must be controlled by
    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
    // createTempShuffleBlock here; see SPARK-3426 for more context.
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
    var objectsWritten = 0   // Objects written since the last flush

    // List of batch sizes (bytes) in the order they are written to disk
    val batchSizes = new ArrayBuffer[Long]

    // How many elements we have in each partition
    val elementsPerPartition = new Array[Long](numPartitions)

    // Flush the disk writer's contents to disk, and update relevant variables.
    // The writer is closed at the end of this process, and cannot be reused.
    def flush() = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
      objectsWritten = 0
    }

    var success = false
    try {
      val it = collection.destructiveSortedIterator(partitionKeyComparator)
      while (it.hasNext) {
        val elem = it.next()
        val partitionId = elem._1._1
        val key = elem._1._2
        val value = elem._2
        writer.write(key)
        writer.write(value)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1

        if (objectsWritten == serializerBatchSize) {
          flush()
          curWriteMetrics = new ShuffleWriteMetrics()
          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
        }
      }
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          file.delete()
        }
      }
    }

    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
  }
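What a mergeable spill file looks like logically can be sketched without any I/O. The code below is an illustration under assumptions: the helper names are hypothetical, and records are ordered by partition id and then by the key's natural ordering, which stands in for whatever comparator partitionKeyComparator actually applies.

```scala
object SpillOrderSketch {
  // Order records by partition id first, then by key, as a sorted spill would.
  def sortedForSpill[K: Ordering, C](records: Seq[((Int, K), C)]): Seq[((Int, K), C)] =
    records.sortBy { case ((partition, key), _) => (partition, key) }

  // Count how many records fall into each partition, like elementsPerPartition above.
  def elementsPerPartition[K, C](records: Seq[((Int, K), C)], numPartitions: Int): Array[Long] = {
    val counts = new Array[Long](numPartitions)
    records.foreach { case ((partition, _), _) => counts(partition) += 1 }
    counts
  }
}
```

Because all records of a partition end up contiguous, the per-partition counts are enough to find where each partition starts when several spill files are merged later.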

 

 

 

 

 11.4 Releasing memory

  /**
   * Release our memory back to the shuffle pool so that other threads can grab it.
   */
  private def releaseMemoryForThisThread(): Unit = {
    // The amount we requested does not include the initial memory tracking threshold
    shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) // release only what was acquired beyond the initial threshold
    myMemoryThreshold = initialMemoryThreshold
  }

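 How are the map output and the reduce input connected? That is, how does the reduce side know where to fetch the map output from, whether that output sits on disk or in memory?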

 

 Through MapOutputTracker? Is MapOutputTracker a global map? Server statuses are looked up by shuffleId and reduceId; how are these two values stored?

 

    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)

 

Class documentation of MapOutputTracker:

 

/**
 * Class that keeps track of the location of the map output of
 * a stage. This is abstract because different versions of MapOutputTracker
 * (driver and worker) use different HashMap to store its metadata.
 */
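One way to picture the lookup by shuffleId and reduceId is a map from shuffle id to an array of per-map-task statuses. The sketch below is a guess at the shape of the bookkeeping, not the Spark implementation: `Status`, `Tracker` and their methods are hypothetical, and locations are plain strings.

```scala
import scala.collection.mutable

object MapOutputTrackerSketch {
  // Hypothetical simplified status: where a map task's output lives and how large
  // each reduce partition's slice of it is.
  final case class Status(location: String, partitionLengths: Array[Long])

  class Tracker {
    private val statuses = mutable.Map.empty[Int, Array[Status]]

    // Reserve one slot per map task of the shuffle.
    def registerShuffle(shuffleId: Int, numMaps: Int): Unit =
      statuses(shuffleId) = new Array[Status](numMaps)

    // Called when a map task finishes and reports its output.
    def registerMapOutput(shuffleId: Int, mapId: Int, status: Status): Unit =
      statuses(shuffleId)(mapId) = status

    // For one reducer: the (location, size) of its slice in every map output.
    def getServerStatuses(shuffleId: Int, reduceId: Int): Seq[(String, Long)] =
      statuses(shuffleId).toSeq.map(s => (s.location, s.partitionLengths(reduceId)))
  }
}
```

In this model the reducer needs only its own reduceId and the shuffleId it reads from; everything else comes from what the map tasks reported when they finished.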

 

 

 

 Documentation of MapOutputTracker.getServerStatuses:

 

  /**
   * Called from executors to get the server URIs and output sizes of the map outputs of
   * a given shuffle.
   */

The complete ShuffleMapTask.runTask method:

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val ser = SparkEnv.get.closureSerializer.newInstance()

    // rdd is the MappedRDD that ends Stage 0; the function it carries comes from val rdd2 = rdd1.map((_, 1)), which turns each word into (word, 1)
    // dep is the ShuffleDependency
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      // deserialize taskBinary.value into the (rdd, dep) tuple
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      // the default SortShuffleManager
      val manager = SparkEnv.get.shuffleManager
      // obtain the SortShuffleWriter; the writer holds the IndexShuffleBlockManager instance
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)

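      // partition is a HadoopPartition, which holds the InputSplit and the partition index idx (the position of this split among the input splits?)
      // rdd.iterator calls rdd's compute method to obtain the data as an iterator,
      // which in turn calls into the iterators of rdd's parent RDDs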
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])

      return writer.stop(success = true).get
    } catch {
      case e: Exception =>
        try {
          if (writer != null) {
            writer.stop(success = false)
          }
        } catch {
          case e: Exception =>
            log.debug("Could not stop writer", e)
        }
        throw e
    }
  }

[Figure 2: rdd and dep]
 

 

 

[Figure 3: the string form of the taskBinary.value bytes]
 

 The getWriter method of SortShuffleManager

 

  /** Get a writer for a given partition. Called on executors by map tasks. */
  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
      : ShuffleWriter[K, V] = {
    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
    shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
    new SortShuffleWriter(
      // shuffleBlockManager is a method call that returns the IndexShuffleBlockManager
      shuffleBlockManager, baseShuffleHandle, mapId, context)
  }

 

 The rdd.iterator method

 

  /**
   * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
   * This should ''not'' be called by users directly, but is available for implementors of custom
   * subclasses of RDD.
   */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }
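The branch in iterator, reading from the cache when a storage level is set and computing otherwise, can be modelled in a few lines. This is a toy model, not Spark code: `CachingNode` is a hypothetical class, the cache is an in-memory Vector, and a Boolean flag stands in for the storage level check.

```scala
object CacheOrComputeSketch {
  // Hypothetical node: `cached` stands in for storageLevel != StorageLevel.NONE.
  class CachingNode[T](compute: () => Iterator[T], cached: Boolean) {
    private var store: Option[Vector[T]] = None
    // Counts how often the underlying computation actually ran.
    var computeCalls = 0

    def iterator(): Iterator[T] =
      if (cached) {
        // Compute once, keep the result, and serve later calls from the stored copy.
        if (store.isEmpty) {
          computeCalls += 1
          store = Some(compute().toVector)
        }
        store.get.iterator
      } else {
        // No caching requested: recompute on every call.
        computeCalls += 1
        compute()
      }
  }
}
```

Calling `iterator()` twice on a cached node runs the computation once, while an uncached node runs it on every call; the WordCount job above never sets a storage level, so it always takes the compute branch.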

 
