spark shuffle过程源码解析

前言

为了更好的理解spark的shuffle过程,通过走读源码,彻底理解shuffle过程中的执行过程以及与排序相关的内容。

本文所使用的spark版本为:2.4.4

1、shuffle之BypassMergeSortShuffleWriter

基本原理:

1、下游reduce有多少个分区partition,上游map就建立多少个fileWriter[reduceNumer],每一个下游分区的数据写入到一个独立的文件中。当所有的分区文件写完之后,将多个分区的数据合并到一个文件中,代码如下:

while (records.hasNext()) {
      final Product2 record = records.next();
      final K key = record._1();
      //作者注:将数据写到对应分区的文件中去。
      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
    }

    for (int i = 0; i < numPartitions; i++) {
      final DiskBlockObjectWriter writer = partitionWriters[i];
      partitionWriterSegments[i] = writer.commitAndGet();
      writer.close();
    }

    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
    File tmp = Utils.tempFileWith(output);
    try {
      //作者注:合并所有分区的小文件为一个大文件,保证同一个分区的数据连续存在
      partitionLengths = writePartitionedFile(tmp);
      //作者注:构建索引文件
      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
      }
    }

2、由于每一个独立的分区文件的数据都是属于同一个reduce的,在进行文件合并的时候,不需要进行排序,只需要按照文件顺序合并到一个文件中即可,并建立对应的分区数据索引文件。

3、使用BypassMergeSortShuffleWriter的条件是:

        (1)、下游分区的个数不能超过参数spark.shuffle.sort.bypassMergeThreshold的值(默认是200)

        (2)、非map端预聚合算子(reduceByKey)

        具体判断代码如下:

def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
    // We cannot bypass sorting if we need to do map-side aggregation.
    if (dep.mapSideCombine) {
      false
    } else {
      val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
      dep.partitioner.numPartitions <= bypassMergeThreshold
    }
  }

2、shuffle之SortShuffleWriter

执行该writer的条件:

(1)、下游分区partitions的个数超过spark.shuffle.sort.bypassMergeThreshold参数设置的值(默认200)

(2)、跳过一个叫做UnsafeShuffleWriter(详情见3)的writer

执行过程描述:

1、如果map端是预聚合算子(例如reduceByKey)

(1)、使用一个map:PartitionedAppendOnlyMap对象进行数据的存储和预聚合,代码如下:

        备注:可以看到在map.changeValue时,更新的key并不是数据的key,而是在数据key的基础上,加上了该key的分区id((getPartition(kv._1), kv._1)),这样做的目的是为了下面将数据溢出到磁盘时,按分区id进行排序,以保证同一个分区的数据能连续存放到一起。

//作者注:判断是否是一个预聚合算子
if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      //作者注:获取预聚合算子的执行函数
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        //作者注:使用一个map:PartitionedAppendOnlyMap类型进行数据的存储和预聚合更新
        map.changeValue((getPartition(kv._1), kv._1), update)
        //作者注:执行溢出到磁盘操作
        maybeSpillCollection(usingMap = true)
      }
    }

(2)、执行数据溢出到磁盘操作:maybeSpillCollection,代码如下:

 private def maybeSpillCollection(usingMap: Boolean): Unit = {
    var estimatedSize = 0L
    //作者注:判断是否是预聚合算子
    if (usingMap) {
      //作者注:预聚合算子,则把map对象里面的数据写入到磁盘
      estimatedSize = map.estimateSize()
      if (maybeSpill(map, estimatedSize)) {
        map = new PartitionedAppendOnlyMap[K, C]
      }
    } else {
      //作者注:不是预聚合算子,则把buffer对象里面的数据写入到磁盘
      estimatedSize = buffer.estimateSize()
      if (maybeSpill(buffer, estimatedSize)) {
        buffer = new PartitionedPairBuffer[K, C]
      }
    }

       然后执行maybeSpill函数,根据溢出条件判断是否溢出到磁盘,代码如下:

protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = acquireMemory(amountToRequest)
      myMemoryThreshold += granted
      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
      // or we already had more memory than myMemoryThreshold), spill the current collection
      shouldSpill = currentMemory >= myMemoryThreshold
    }

    //作者注:是否溢出磁盘,有两个判断条件
    //1、shouldSplill:判断内存空间的是否充足
    //2、_elementsRead > numElementsForceSpillThreshold:判断当前的写的数据条数是否超过阈值numElementsForceSpillThreshold(默认Integer.MAX_VALUE)
    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
    // Actually spill
    if (shouldSpill) {
      _spillCount += 1
      logSpillage(currentMemory)
      spill(collection)
      _elementsRead = 0
      _memoryBytesSpilled += currentMemory
      releaseMemory()
    }
    shouldSpill
  }

        如果满足条件,这执行spill(collection)进行数据溢出,代码如下:

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
    val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
    val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
    spills += spillFile
  }

        看一下这行代码:

val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)

        这一行代码的作用是对数据进行排序,如何排序呢?经过作者的debug过程发现,这个排序过程其实并不是对数据的key进行排序,而是对分区id进行排序,这样保证同一个分区的数据能连续在一起,为后续的溢出文件合并的归并排序提供基础。

        至此,数据的磁盘溢出操作完成。下一步就是如何将溢出的数据进行合并。

(3)、溢出磁盘数据文件合并成一个大文件,并建立一个分区的索引文件,具体的代码执行过程如下(SortShuffleWriter):里面具体的执行过程再次不再赘述。

try {
      val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
      //作者注:将溢出的磁盘文件和当前缓存的文件进行归并合并,保证同一分区的数据连续存在
      val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
      //作者注:构建索引文件
      shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
    } finally {
      if (tmp.exists() && !tmp.delete()) {
        logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
      }
    }

2、如果map端不是预聚合算子(例如groupByKey)

上面介绍了预聚合算子的shufflerwriter的执行过程,而非预聚合算子的shufflewriter的执行过程基本和预聚合算子是一样的,唯一不同的一点就是,存储数据的结构不是map:PartitionedAppendOnlyMap,而是buffer:PartitionedPairBuffer,代码如下:

if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
      }
    } else {
      // Stick values into our buffer
      //作者注:非预聚合算子的数据存储到buffer中
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }

至于后续的数据溢出、数据排序、溢出数据文件的合并等过程,和预聚合算子的执行过程一模一样,调用的一样的执行过程,在此就不再赘述。

3、shuffle之UnsafeShuffleWriter

这个UnsafeShuffleWriter的具体执行过程作者没有深入追究,因为从名字上也能看出,Unsafe使用的是堆外内存进行数据的存储以及相关的操作,基本的原理是将数据对象进行序列化后存储到堆外内存,然后使用二进制的方式进行数据的排序工作,这样能提升运算性能。

在实际的执行过程中,是优先使用这种方式进行shuffle过程的write的,具体的执行条件如下:

def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
    val shufId = dependency.shuffleId
    val numPartitions = dependency.partitioner.numPartitions
    //作者注:序列化器支持relocation.
    //作者注:目前spark提供的有两个序列化器:JavaSerializer和KryoSerializer
    //其中KryoSerializer支持relocation,而JavaSerializer不支持relocation
    if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
        s"${dependency.serializer.getClass.getName}, does not support object relocation")
      false
    } else if (dependency.mapSideCombine) { //作者注:非map端预聚合算子
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because we need to do " +
        s"map-side aggregation")
      false
    } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
      //作者注:下游分区个数小于MAXIMUM_PARTITION_ID = (1 << 24) - 1
      log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
      false
    } else {
      log.debug(s"Can use serialized shuffle for shuffle $shufId")
      true
    }
  }

具体的执行逻辑作者没有深入追踪,因为追踪到后面可能全是二进制的数据,无法直观查看数据信息,读者如果有兴趣可以自行debug追踪。

你可能感兴趣的:(spark,spark,大数据,scala)