spark shuffle过程分析

    shuffle是作业执行过程中的一个重要阶段,对作业性能有很大影响,不管是对hadoop还是spark,shuffle都是一个核心环节,spark的shuffle和hadoop的shuffle的原理大致相同,shuffle发生在ShuffleMapTask中,在一个task处理partition数据时,需要对外输出作为下个stage的数据源,这个输出可能是不落盘的,但如果数据量很大,导致内存放不下,这时候就需要shuffle了,shuffle包含spill、和merge两个阶段,最终会输出为一个shuffle文件(spark1.3.1)和一个包含个partition索引的index文件。如果包含map端的aggregate和order则会对partition和KV数据进行排序,负责会按partition id进行排序,这样在reduce阶段,通过索引和该shuffle文件就能找到reduce需要的数据。
上一篇提到task的启动在runTask函数中,我们从这里开始看整个shuffle的过程
  override def runTask(context: TaskContext): MapStatus = {
    // 反序列化获得rdd和其依赖
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)


    metrics = Some(context.taskMetrics)
    var writer: ShuffleWriter[Any, Any] = null
    try {
      val manager = SparkEnv.get.shuffleManager
      //获得SortShuffleWriter实例
      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
      //中间数据的操作再这里写出,有先写内存,不够时写磁盘
      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
      return writer.stop(success = true).get
    } catch {
			.....
    }
  }	
写出时用到了ExternalSorter,根据是否启用map端合并创建不用类型的实例,如果指定了aggregator和keyOrdering则会按顺序输出
  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
    if (dep.mapSideCombine) {
      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
      sorter = new ExternalSorter[K, V, C](
        dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
      sorter.insertAll(records)
    } else {
      // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
      // care whether the keys get sorted in each partition; that will be done on the reduce side
      // if the operation being run is sortByKey.
      sorter = new ExternalSorter[K, V, V](
        None, Some(dep.partitioner), None, dep.serializer)
      sorter.insertAll(records)
    }


    // Don't bother including the time to open the merged output file in the shuffle write time,
    // because it just opens a single file, so is typically too fast to measure accurately
    // (see SPARK-3570).
    val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
    val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)//合并spill文件
    val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
    shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)//写索引文件


    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
  }
下面是遍历一个partition中的数据并做输出
  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
    // TODO: stop combining if we find that the reduction factor isn't high
    val shouldCombine = aggregator.isDefined
		//如果执行combine则会对map数据进行merge操作
    if (shouldCombine) {
      // Combine values in-memory first using our AppendOnlyMap
      val mergeValue = aggregator.get.mergeValue
      val createCombiner = aggregator.get.createCombiner
      var kv: Product2[K, V] = null
      //merge函数
      val update = (hadValue: Boolean, oldValue: C) => {
        if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
      }
      //遍历partition数据
      while (records.hasNext) {
        addElementsRead()
        kv = records.next()
        map.changeValue((getPartition(kv._1), kv._1), update)
        //对外输出,可能产生spill,至于何时spill,看下面分析
        maybeSpillCollection(usingMap = true)
      }
    } else if (bypassMergeSort) {
      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
      if (records.hasNext) {
        spillToPartitionFiles(records.map { kv =>
          ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        })
      }
    } else {
      // Stick values into our buffer
      while (records.hasNext) {
        addElementsRead()
        val kv = records.next()
        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
        maybeSpillCollection(usingMap = false)
      }
    }
  }
这里包含了spill的触发条件,最关键条件就是内存不够,当当前内粗不足时会首先申请内存,如果申请后仍然不够则需要spill,那么内存够的情况下,数据放哪里呢?继续看后面分析
  protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
    if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
        currentMemory >= myMemoryThreshold) {
      // Claim up to double our current memory from the shuffle memory pool
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
      myMemoryThreshold += granted
      if (myMemoryThreshold <= currentMemory) {
        // We were granted too little memory to grow further (either tryToAcquire returned 0,
        // or we already had more memory than myMemoryThreshold); spill the current collection
        _spillCount += 1
        logSpillage(currentMemory)


        spill(collection)


        _elementsRead = 0
        // Keep track of spills, and release memory
        _memoryBytesSpilled += currentMemory
        releaseMemoryForThisThread()
        return true
      }
    }
    false
  }
如果内存够用,则把数据放入一个SizeTrackingAppendOnlyMap数据结构中,这是一个只能append的kv map集合,由spark本身实现,注意并不是java的collection,有兴趣的朋友可以翻下代码看看,内部存储是一个Array
    返回头看看spill写文件的过程,上面讲到了spill的触发条件,那么条件满足后文件怎么写的,写到哪里呢,看下面分析
  /**
   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
   */
  override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
    if (bypassMergeSort) {
      spillToPartitionFiles(collection)
    } else {
      spillToMergeableFile(collection)
    }
  }
下面操作会真正的写出数据,并且是按key有序的输出,输出形式是 key1,value1,key2,value2.....,真正写文件的时候还是比较简单的,直接用BlockObjectWriter把一个个对象写到文件
  /**
   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
   * We add this file into spilledFiles to find it later.
   *
   * Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
   * See spillToPartitionedFiles() for that code path.
   *
   * @param collection whichever collection we're using (map or buffer)
   */
  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
    assert(!bypassMergeSort)


    // Because these files may be read during shuffle, their compression must be controlled by
    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
    // createTempShuffleBlock here; see SPARK-3426 for more context.
    val (blockId, file) = diskBlockManager.createTempShuffleBlock()
    curWriteMetrics = new ShuffleWriteMetrics()
    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
    var objectsWritten = 0   // Objects written since the last flush


    // List of batch sizes (bytes) in the order they are written to disk
    val batchSizes = new ArrayBuffer[Long]


    // How many elements we have in each partition
    val elementsPerPartition = new Array[Long](numPartitions)


    // Flush the disk writer's contents to disk, and update relevant variables.
    // The writer is closed at the end of this process, and cannot be reused.
    def flush() = {
      val w = writer
      writer = null
      w.commitAndClose()
      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
      objectsWritten = 0
    }


    var success = false
    try {
      val it = collection.destructiveSortedIterator(partitionKeyComparator)
      while (it.hasNext) {
        val elem = it.next()
        val partitionId = elem._1._1
        val key = elem._1._2
        val value = elem._2
        writer.write(key)
        writer.write(value)
        elementsPerPartition(partitionId) += 1
        objectsWritten += 1


        if (objectsWritten == serializerBatchSize) {
          flush()
          curWriteMetrics = new ShuffleWriteMetrics()
          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
        }
      }
      if (objectsWritten > 0) {
        flush()
      } else if (writer != null) {
        val w = writer
        writer = null
        w.revertPartialWritesAndClose()
      }
      success = true
    } finally {
      if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
        // close our stuff and let the exception be thrown further
        if (writer != null) {
          writer.revertPartialWritesAndClose()
        }
        if (file.exists()) {
          file.delete()
        }
      }
    }
    //把spillwenjian加入集合中,以便后期合并
    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
  }
最后spill文件会被合并成一个shuffle文件,spill文件的名字类似于temp_shuffle_&^%%$%之类的名字,合并为一个shuffle文件后,其他文件会被删除,上面就是整个shuffle过程,中间还有很多细节,感兴趣的朋友可以深入分析下。

你可能感兴趣的:(spark)