spark shuffle过程分析

    shuffle是作业执行过程中的一个重要阶段,对作业性能有很大影响。不管是对hadoop还是spark,shuffle都是一个核心环节,spark的shuffle和hadoop的shuffle原理大致相同。shuffle发生在ShuffleMapTask中:一个task在处理partition数据时,需要对外输出,作为下个stage的数据源。这个输出可能是不落盘的,但如果数据量很大,导致内存放不下,这时候就需要shuffle了。shuffle包含spill和merge两个阶段,最终会输出为一个shuffle文件(spark1.3.1)和一个包含各个partition索引的index文件。如果包含map端的aggregate和order,则会对partition和KV数据进行排序,否则只会按partition id进行排序,这样在reduce阶段,通过索引和该shuffle文件就能找到reduce需要的数据。
  override def runTask(context: TaskContext): MapStatus = {
    // 反序列化获得rdd和其依赖
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      sorter = new ExternalSorter[K, V, C](
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      sorter = new ExternalSorter[K, V, V](
        None, Some(dep.partitioner), None, dep.serializer)

    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)//合并spill文件
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)//写索引文件

    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined
    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      while (records.hasNext) {
        kv =
        map.changeValue((getPartition(kv._1), kv._1), update)
        maybeSpillCollection(usingMap = true)
    } else if (bypassMergeSort) {
      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
      if (records.hasNext) {
        spillToPartitionFiles( { kv =>
          ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        val kv =
        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
        currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
      myMemoryThreshold += granted
      if (myMemoryThreshold <= currentMemory) {
        // We were granted too little memory to grow further (either tryToAcquire returned 0,
        // or we already had more memory than myMemoryThreshold); spill the current collection
        _spillCount += 1


        _elementsRead = 0
        // Keep track of spills, and release memory
        _memoryBytesSpilled += currentMemory
        return true
如果内存够用,则把数据放入一个SizeTrackingAppendOnlyMap数据结构中。这是一个只能append的kv map集合,由spark自己实现,注意它并不是java的collection,内部存储是一个Array,有兴趣的朋友可以翻下代码看看。
   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
  override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
    if (bypassMergeSort) {
    } else {
下面的操作会真正地写出数据,并且是按key有序地输出,输出形式是 key1,value1,key2,value2,...。真正写文件的时候还是比较简单的,直接用BlockObjectWriter把一个个对象写到文件。
   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
   * We add this file into spilledFiles to find it later.
   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
   * See spillToPartitionedFiles() for that code path.
   * @param collection whichever collection we're using (map or buffer)
  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {

    // Because these files may be read during shuffle, their compression must be controlled by
    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
    // createTempShuffleBlock here; see SPARK-3426 for more context.
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
    var objectsWritten = 0   // Objects written since the last flush

    // List of batch sizes (bytes) in the order they are written to disk
    val batchSizes = new ArrayBuffer[Long]

    // How many elements we have in each partition
    val elementsPerPartition = new Array[Long](numPartitions)

    // Flush the disk writer's contents to disk, and update relevant variables.
    // The writer is closed at the end of this process, and cannot be reused.
    def flush() = {
      val w = writer
      writer = null
      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
      objectsWritten = 0

    var success = false
    try {
      val it = collection.destructiveSortedIterator(partitionKeyComparator)
      while (it.hasNext) {
        val elem =
        val partitionId = elem._1._1
        val key = elem._1._2
        val value = elem._2
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1

        if (objectsWritten == serializerBatchSize) {
          curWriteMetrics = new ShuffleWriteMetrics()
          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
      if (objectsWritten > 0) {
      } else if (writer != null) {
        val w = writer
        writer = null
      success = true
    } finally {
      if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        if (writer != null) {
        if (file.exists()) {
    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
