9.Shuffle读写源码分析

先直接上原理图吧 !
ShuffleMapTask在计算数据之后会为每一个ResultTask创建一份bucket缓存 , 以及对应的ShuffleBlockFIle磁盘文件进行储存 , 在计算完之后会将计算过的相应信息放入MapStatus , 最后发送给Driver中的DAGScheduler的MapOutputTracker , 每个ResultTask会用BlockStoreShuffleFetcher去MapOutputTracker中的MapStatus获取需要拉取的数据 , 然后通过底层的BlockManager将数据拉取过来 , 拉取过来的数据就会组成一个内部的RDD , 叫ShuffleRDD , 存入缓存 , 缓存不够存入磁盘 , 最后ResultMap对数据进行聚合生成MapPartitionRDD , 也就是我们所写程序中action操作后结果RDD


优化后的shuffle分析原理图:


9.Shuffle读写源码分析_第1张图片


优化后的shuffle原理就是根据cpu的数量在ShuffleMap写入数据到磁盘文件时只会创建于cpu相应的文件数据 , 后面在运行新的ShuffleMapTask的时候也只会向同样的文件中写入数据 , 同时会记录下一些索引来记录哪个ShuffleMapTask计算的数据在ShuffleBlockFile中的位置 , 多个ShuffleMapTask写入的数据就叫做一个segment , 也就是说原来的100个ShuffleMapTask对应的100个ResulTask时会创建100*100个磁盘文件 , 而现在只需要cpu数量乘以ResultMap的数量之积文件数 , 减少了大量的磁盘文件读写 , 这种优化shuffle的方式只需在创建SparkContext的时候设置一个参数即可

在上一章节关于Task的源码分析最后的关于writer的代码中:
writer . write ( rdd . iterator ( partition , context ). asInstanceOf [ Iterator [ _ <: Product2 [ Any , Any ]]])

其实这个writer默认的情况下就是HaspShuffleWriter , 调用writer的方法源码如下:
       
       
       
       
  1. /** Write a bunch of records to this task's output */
  2. /**
  3. * 将每个ShuffleMapTask计算出来的新的RDD的partition数据写入本地磁盘
  4. */
  5. override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
  6. // 首先判断,是否需要在map端进行本地聚合
  7. // 比如reduceByKey这样的算子操作的话它的dep.aggregator.isDegined就是true , 包括def.mapSideCombine也是true
  8. val iter = if (dep.aggregator.isDefined) {
  9. if (dep.mapSideCombine) {
  10. // 这里就会执行本地聚合,比如(Hi,1)(Hi,1)那么此时就会聚合成(Hi,2)
  11. dep.aggregator.get.combineValuesByKey(records, context)
  12. } else {
  13. records
  14. }
  15. } else {
  16. require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
  17. records
  18. }
  19. // 如果进行本地聚合那么就会遍历数据 , 对每个数据调用partition默认是HashPartition , 生成bucketId
  20. // 也就决定了每一份数据要写入哪个bucket
  21. for (elem <- iter) {
  22. val bucketId = dep.partitioner.getPartition(elem._1)
  23. // 获取到了bucketId之后就会调用ShuffleBlockManager.formapTask()方法来生成bucketId对应的writer,然后用writer将数据写入bucket
  24. shuffle.writers(bucketId).write(elem)
  25. }
  26. }

这里的shuffle是HushShuffleWriter的一个成员变量 , 通过shuffleBlockManager对象的forMapTask方法获取每个bucketId对应的writer , forMapTask方法源码如下:
       
       
       
       
  1. /**
  2. * 给每个map task获取一个ShuffleWriterGroup
  3. */
  4. def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
  5. writeMetrics: ShuffleWriteMetrics) = {
  6. new ShuffleWriterGroup {
  7. shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
  8. private val shuffleState = shuffleStates(shuffleId)
  9. private var fileGroup: ShuffleFileGroup = null
  10. // 重点: 对应上我们之前所说的shuffle有两种模式 , 一种是普通的,一种是优化后的
  11. // 如果开启了consolication机制,也即使consolicationShuffleFiles为true的话那么实际上不会给每个bucket都获取一个独立的文件
  12. // 而是为了这个bucket获取一个ShuffleGroup的writer
  13. val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
  14. fileGroup = getUnusedFileGroup()
  15. Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
  16. // 首先用shuffleId, mapId,bucketId生成一个一个唯一的ShuffleBlockId
  17. // 然后用bucketId来调用shuffleFileGroup的apply()函数为bucket获取一个ShuffleFileGroup
  18. val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
  19. // 然后用BlockManager的getDisWriter()方法针对ShuffleFileGroup获取一个Writer
  20. // 这样的话如果开启了consolidation机制那么对于每一个bucket都会获取一个针对ShuffleFileGroup的writer , 而不是一个独立的ShuffleBlockFile的writer
  21. // 这样就实现了所谓的多个ShuffleMapTask的输出数据合并
  22. blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
  23. writeMetrics)
  24. }
  25. } else {
  26. // 如果没有开启consolation机制
  27. Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
  28. // 同样生成一个ShuffleBlockId
  29. val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
  30. // 然后调用BlockManager的DiskBlockManager , 获取一个代表了要写入本地磁盘文件的BlockFile
  31. val blockFile = blockManager.diskBlockManager.getFile(blockId)
  32. // Because of previous failures, the shuffle file may already exist on this machine.
  33. // If so, remove it.
  34. // 而且会判断这个blockFile要是存在的话还得删除它
  35. if (blockFile.exists) {
  36. if (blockFile.delete()) {
  37. logInfo(s"Removed existing shuffle file $blockFile")
  38. } else {
  39. logWarning(s"Failed to remove existing shuffle file $blockFile")
  40. }
  41. }
  42. // 然后调用BlockManager的getDiskWriterff针对那个blockFile生成writer
  43. blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
  44. }
  45. // 所以使用过这种普通的我shuffle操作的话对于每一个ShuffleMapTask输出的bucket都会在本地获取一个但粗的shuffleBlockFile
  46. }

上面代码的注释已经很详细啦 , 就是根据是否设置consolication机制来判断是否给每一个bucket数据创建一个独立的文件 , 若设置了consolication机制的话那么就会给这个bucket数据生成一个shuffeBlockId
然后根据bucket原有的id获取到一个ShuffleFileGroup . 而最后就会针对每一个bucket都会获取这个关于ShuffleFileGroup的Writer进行数据的写 , 而不是为每一个bucket都创建一个独立的shufflerBlockFile的writer

上面是关于一个stage中最后shuffle的写操作 , 接下来就是下一个stage读取上一个stage shuffle数据的读操作:
先来看下ShuffleRDD中的compute方法 , 源码如下:
        
        
        
        
  1. /**
  2. * Shuffle读数据的入口
  3. */
  4. override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  5. // ResultTask或者ShuffleMapTask在执行ShuffleRDD时肯定会调用ShuffleRDD的compute方法,来计算当前这个RDD的partition的数据
  6. // 这个就是之前的Task源码分析时结合TaskRunner所分析的
  7. // 在这里会调用ShuffleManager的getReader()方法,获取一个HashShuffleReader , 然后调用它的read()方法拉取该ResultTask,ShuffleMapTask需要聚合的数据
  8. val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  9. SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
  10. .read()
  11. .asInstanceOf[Iterator[(K, C)]]
  12. }

其实就是获取了一个与HashShuffleWriter相对应的HashShuffleReader来读取bucket中的数据而已 , 我们来看看HashShuffleWriter中读取数据的方法read(),源码如下:
       
       
       
       
  1. override def read(): Iterator[Product2[K, C]] = {
  2. val ser = Serializer.getSerializer(dep.serializer)
  3. // 这里就跟图解上面的串起来了
  4. // ResultTask在拉取数据时其实会调用BlockStoreShuffleFetcher来从DAGScheduler的MapOutputTrackermaster中获取自己想要的数据的信息
  5. // 底层再通过BlockManager从对应的位置拉取需要的数据
  6. val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
  7. val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
  8. if (dep.mapSideCombine) {
  9. new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
  10. } else {
  11. new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
  12. }
  13. } else {
  14. require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
  15. // Convert the Product2s to pairs since this is what downstream RDDs currently expect
  16. iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
  17. }
原理就是先拿到需要拉取数据的原信息 , 通过DAGScheduler的MapOutputTracker来获取 , 然后通过BlockManager来进行网络数据的拉取 , 这之间的操作都是上面的BlockStoreShuffleFetcher的fetch()方法实现的 , 源码如下:
        
        
        
        
  1. def fetch[T](
  2. shuffleId: Int,
  3. reduceId: Int,
  4. context: TaskContext,
  5. serializer: Serializer)
  6. : Iterator[T] =
  7. {
  8. logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
  9. val blockManager = SparkEnv.get.blockManager
  10. val startTime = System.currentTimeMillis
  11. // 重点 : 首先拿到一个全局的MapOutputTrackerMaster的引用 , 然后调用其getServerStatuses方法 , 传入的两个参数要注意
  12. // shuffleId可以代表当前这个stage的上一个stage , shuffle是分为两个stage的 , shuffle write发生在上一个stage中,shuffle read发生在当前的stage
  13. // 因此shuffleId 可以限制到上一个stage的所有ShuffleMapTask输出的mapStatus
  14. // 而reduceId就是所谓的buckedId来限制每个MapStatus中获取当前这个ResultTask需要获取的每个ShuffleMapTask的输出文件的信息
  15. // 这里的getServerStatuses会走远程网络通信的 , 因为要获取Driver上的DAGScheduler的MapOutputTrackerMaster
  16. val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
  17. logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
  18. shuffleId, reduceId, System.currentTimeMillis - startTime))
  19. // 下面的代码就是对刚刚拉取到的信息status进行一些数据结构上的转换操作 , 比如弄成map格式的数据
  20. val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
  21. for (((address, size), index) <- statuses.zipWithIndex) {
  22. splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
  23. }
  24. val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
  25. case (address, splits) =>
  26. (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
  27. }
  28. def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
  29. val blockId = blockPair._1
  30. val blockOption = blockPair._2
  31. blockOption match {
  32. case Success(block) => {
  33. block.asInstanceOf[Iterator[T]]
  34. }
  35. case Failure(e) => {
  36. blockId match {
  37. case ShuffleBlockId(shufId, mapId, _) =>
  38. val address = statuses(mapId.toInt)._1
  39. throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
  40. case _ =>
  41. throw new SparkException(
  42. "Failed to get block " + blockId + ", which is not a shuffle block", e)
  43. }
  44. }
  45. }
  46. }
  47. // 重点 : ShuffleBlockFetcherIterator构造以后在其内部就直接根据拉取到的硬盘上的具体位置信息
  48. // 通过BlockManager去远程的ShuffleMapTask所在节点的BlockManager去拉取数据
  49. val blockFetcherItr = new ShuffleBlockFetcherIterator(
  50. context,
  51. SparkEnv.get.blockManager.shuffleClient,
  52. blockManager,
  53. blocksByAddress,
  54. serializer,
  55. SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024)
  56. val itr = blockFetcherItr.flatMap(unpackBlock)
  57. // 最后将拉取到的数据进行一些转化和封装返回
  58. val completionIter = CompletionIterator[T, Iterator[T]](itr, {
  59. context.taskMetrics.updateShuffleReadMetrics()
  60. })
  61. new InterruptibleIterator[T](context, completionIter) {
  62. val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
  63. override def next(): T = {
  64. readMetrics.incRecordsRead(1)
  65. delegate.next()
  66. }
  67. }
  68. }
重点是MapOutputTrackerMaster的getServerStatuses方法中的shuffleId和reduceId , shuffleId代表的是上一个stage中shuffle产生的所有MapStatus数据 ,而reduceId其实就是bucketId , 代表的是当前这个stage中MapResultTask获取的数据文件信息
我们在进入MapOutputTrackerMaster的getServerStatuses方法继续深入 , 源码如下:
        
        
        
        
  1. /**
  2. * Called from executors to get the server URIs and output sizes of the map outputs of
  3. * a given shuffle.
  4. */
  5. def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
  6. val statuses = mapStatuses.get(shuffleId).orNull
  7. if (statuses == null) {
  8. logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
  9. var fetchedStatuses: Array[MapStatus] = null
  10. // 做了线程同步
  11. fetching.synchronized {
  12. // Someone else is fetching it; wait for them to be done
  13. // 不断去拉取shuffleId对应的数据 , 只要还没拉倒就死循环等待
  14. while (fetching.contains(shuffleId)) {
  15. try {
  16. fetching.wait()
  17. } catch {
  18. case e: InterruptedException =>
  19. }
  20. }
其实就是不断拉取shuffleId对应的数据而已
最后就是拉取ResultTaskMap的数据了 , 在ShuffleBlockFetchIterator类中的initialize()方法中 , 源码如下:
        
        
        
        
  1. /**
  2. * 将这个方法作为入口 , 开始拉取ResultTask对应的多份数据
  3. */
  4. private[this] def initialize(): Unit = {
  5. // Add a task completion callback (called in both success case and failure case) to cleanup.
  6. context.addTaskCompletionListener(_ => cleanup())
  7. // Split local and remote blocks.
  8. // 切分本地的和远程的block
  9. val remoteRequests = splitLocalRemoteBlocks()
  10. // Add the remote requests into our queue in a random order
  11. // 切分完之后进行shuffle随机排序操作
  12. fetchRequests ++= Utils.randomize(remoteRequests)
  13. // Send out initial requests for blocks, up to our maxBytesInFlight
  14. // 循环往复 , 只要发现还有数据没有拉取完就发送请求到远程去拉取数据
  15. // 这其中有一个参数就是max.bytes.in.flight这么一个参数,这个参数就决定了最多能拉取到多少数据到本地就要开始我们自定义的reduce算子的处理
  16. while (fetchRequests.nonEmpty &&
  17. (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
  18. sendRequest(fetchRequests.dequeue())
  19. }
  20. val numFetches = remoteRequests.size - fetchRequests.size
  21. logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
  22. // Get Local Blocks
  23. // 拉取完了远程数据之后获取本地的数据
  24. fetchLocalBlocks()
  25. logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  26. }

以上所有就是Shuffle操作的所有详情咯 !

你可能感兴趣的:(Java,spark)