Spark技术内幕:Shuffle Read的整体流程

回忆一下,每个Stage的上边界,要么需要从外部存储读取数据,要么需要读取上一个Stage的输出;而下边界,要么是需要写入本地文件系统(需要Shuffle),以供childStage读取,要么是最后一个Stage,需要输出结果。这里的Stage,在运行时的时候就是可以以pipeline的方式运行的一组Task,除了最后一个Stage对应的是ResultTask,其余的Stage对应的都是ShuffleMap Task。

而除了需要从外部存储读取数据和RDD已经做过cache或者checkpoint的Task,一般Task的开始都是从ShuffledRDD的ShuffleRead开始的。本节将详细讲解Shuffle Read的过程。

先看一下ShuffleRead的整体架构图。

Spark技术内幕:Shuffle Read的整体流程_第1张图片

org.apache.spark.rdd.ShuffledRDD#compute 开始,通过调用org.apache.spark.shuffle.ShuffleManager的getReader方法,获取到org.apache.spark.shuffle.ShuffleReader,然后调用其read()方法进行读取。在Spark1.2.0中,不管是Hash BasedShuffle或者是Sort BasedShuffle,内置的Shuffle Reader都是 org.apache.spark.shuffle.hash.HashShuffleReader。核心实现:

[java] view plain copy
  1.  override def read(): Iterator[Product2[K, C]] = {  
  2. val ser =Serializer.getSerializer(dep.serializer)  
  3. // 获取结果  
  4.    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId,startPartition, context, ser)  
  5.    // 处理结果  
  6.    val aggregatedIter: Iterator[Product2[K, C]] = if(dep.aggregator.isDefined) {//需要聚合  
  7.      if (dep.mapSideCombine) {//需要map side的聚合  
  8.        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(  
  9.                             iter, context))  
  10.      } else {//只需要reducer端的聚合  
  11.        new InterruptibleIterator(context,dep.aggregator.get.combineValuesByKey(  
  12.                             iter, context))  
  13.       
  14. }  
  15.     }else { // 无需聚合操作  
  16.        iter.asInstanceOf[Iterator[Product2[K,C]]].map(pair => (pair._1, pair._2))  
  17.     }  
  18.    
  19.    // Sort the output if there is a sort ordering defined.  
  20.    dep.keyOrdering match {//判断是否需要排序  
  21.      case Some(keyOrd: Ordering[K]) => //对于需要排序的情况  
  22.        // 使用ExternalSorter进行排序,注意如果spark.shuffle.spill是false,那么数据是  
  23.        // 不会spill到硬盘的  
  24.        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd),  
  25.                                          serializer= Some(ser))  
  26.        sorter.insertAll(aggregatedIter)  
  27.        context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled  
  28.        context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled  
  29.        sorter.iterator  
  30.      case None => //无需排序  
  31.        aggregatedIter  
  32.     }  
  33.   }  

org.apache.spark.shuffle.hash.BlockStoreShuffleFetcher#fetch会获得数据,它首先会通过

org.apache.spark.MapOutputTracker#getServerStatuses来获得数据的meta信息,这个过程有可能需要向org.apache.spark.MapOutputTrackerMasterActor发送读请求,这个读请求是在org.apache.spark.MapOutputTracker#askTracker发出的。在获得了数据的meta信息后,它会将这些数据存入Seq[(BlockManagerId,Seq[(BlockId, Long)])]中,然后调用org.apache.spark.storage.ShuffleBlockFetcherIterator最终发起请求。ShuffleBlockFetcherIterator根据数据的本地性原则进行数据获取。如果数据在本地,那么会调用org.apache.spark.storage.BlockManager#getBlockData进行本地数据块的读取。而getBlockData对于shuffle类型的数据,会调用ShuffleManager的ShuffleBlockManager的getBlockData。

如果数据在其他的Executor上,那么如果用户使用的spark.shuffle.blockTransferService是netty,那么就会通过org.apache.spark.network.netty.NettyBlockTransferService#fetchBlocks获取;如果使用的是nio,那么就会通过org.apache.spark.network.nio.NioBlockTransferService#fetchBlocks获取。


数据读取策略的划分

org.apache.spark.storage.ShuffleBlockFetcherIterator会通过splitLocalRemoteBlocks划分数据的读取策略:如果在本地有,那么可以直接从BlockManager中获取数据;如果需要从其他的节点上获取,那么需要走网络。由于Shuffle的数据量可能会很大,因此这里的网络读有以下的策略:

1)       每次最多启动5个线程去最多5个节点上读取数据

2)       每次请求的数据大小不会超过spark.reducer.maxMbInFlight(默认值为48MB)/5

这样做的原因有几个:

1)  避免占用目标机器的过多带宽,在千兆网卡为主流的今天,带宽还是比较重要的。如果机器使用的万兆网卡,那么可以通过设置spark.reducer.maxMbInFlight来充分利用带宽。

2)  请求数据可以平行化,这样请求数据的时间可以大大减少。请求数据的总时间就是请求中耗时最长的。这样可以缓解一个节点出现网络拥塞时的影响。

主要的实现:

[java] view plain copy
  1. private[this] def splitLocalRemoteBlocks():ArrayBuffer[FetchRequest] = {  
  2.    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  
  3.    val remoteRequests = new ArrayBuffer[FetchRequest]  
  4.    for ((address, blockInfos) <- blocksByAddress) {  
  5.      if (address.executorId == blockManager.blockManagerId.executorId) {  
  6.        // Block在本地,需要过滤大小为0的block。  
  7.        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)  
  8.        numBlocksToFetch += localBlocks.size  
  9.      } else { //需要远程获取的Block  
  10.        val iterator = blockInfos.iterator  
  11.         var curRequestSize = 0L  
  12.        var curBlocks = new ArrayBuffer[(BlockId, Long)]  
  13.        while (iterator.hasNext) {  
  14.           //blockId 是org.apache.spark.storage.ShuffleBlockId,  
  15.           // 格式:"shuffle_" +shuffleId + "_" + mapId + "_" + reduceId  
  16.          val (blockId, size) = iterator.next()  
  17.          // Skip empty blocks  
  18.          if (size > 0) {  
  19.            curBlocks += ((blockId, size))  
  20.            remoteBlocks += blockId  
  21.            numBlocksToFetch += 1  
  22.            curRequestSize += size  
  23.           }  
  24.    
  25.          if (curRequestSize >= targetRequestSize) {  
  26.            // 当前总的size已经可以批量放入一次request中  
  27.            remoteRequests += new FetchRequest(address, curBlocks)  
  28.            curBlocks = new ArrayBuffer[(BlockId, Long)]  
  29.            curRequestSize = 0  
  30.          }  
  31.        }  
  32.        // 剩余的请求组成一次request  
  33.        if (curBlocks.nonEmpty) {  
  34.          remoteRequests += new FetchRequest(address, curBlocks)  
  35.        }  
  36.      }  
  37.     }  
  38.    remoteRequests  
  39.   }  

本地读取

fetchLocalBlocks() 负责本地Block的获取。在splitLocalRemoteBlocks中,已经将本地的Block列表存入了localBlocks:private[this] val localBlocks = newArrayBuffer[BlockId]()

具体过程如下:

[java] view plain copy
  1. val iter = localBlocks.iterator  
  2.  while (iter.hasNext) {  
  3.    val blockId = iter.next()  
  4.    try {  
  5.      val buf = blockManager.getBlockData(blockId)  
  6.      shuffleMetrics.localBlocksFetched += 1  
  7.      buf.retain()  
  8.      results.put(new SuccessFetchResult(blockId, 0, buf))  
  9.    } catch {  
  10.    }  
  11.   }  

而blockManager.getBlockData(blockId)的实现是:

[java] view plain copy
  1. override def getBlockData(blockId:BlockId): ManagedBuffer = {  
  2.    if (blockId.isShuffle) {  
  3.     shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])  
  4. }  
这就调用了ShuffleBlockManager的getBlockData方法。在Shuffle Pluggable框架中我们介绍了实现一个Shuffle Service之一就是要实现ShuffleBlockManager。

以Hash BasedShuffle为例,它的ShuffleBlockManager是org.apache.spark.shuffle.FileShuffleBlockManager。FileShuffleBlockManager有两种情况,一种是File consolidate的,这种的话需要根据Map ID和 Reduce ID首先获得FileGroup的一个文件,然后根据在文件中的offset和size来获取需要的数据;如果是没有File consolidate,那么直接根据Shuffle Block ID直接读取整个文件就可以。

[java] view plain copy
  1. override def getBlockData(blockId:ShuffleBlockId): ManagedBuffer = {  
  2.    if (consolidateShuffleFiles) {  
  3.      val shuffleState = shuffleStates(blockId.shuffleId)  
  4.      val iter = shuffleState.allFileGroups.iterator  
  5. while(iter.hasNext) {  
  6.   // 根据Map ID和Reduce ID获取File Segment的信息  
  7.        val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId,blockId.reduceId)  
  8.        if (segmentOpt.isDefined) {  
  9.          val segment = segmentOpt.get  
  10.          // 根据File Segment的信息,从FileGroup中找到相应的File和Block在  
  11.           // 文件中的offset和size  
  12.          return new FileSegmentManagedBuffer(  
  13.            transportConf, segment.file, segment.offset, segment.length)  
  14.        }  
  15.      }  
  16.      throw new IllegalStateException("Failed to find shuffle block:" + blockId)  
  17.     }else {  
  18.      val file = blockManager.diskBlockManager.getFile(blockId) //直接获取文件句柄  
  19.      new FileSegmentManagedBuffer(transportConf, file, 0, file.length)  
  20.     }  
  21.   }  

对于Sort BasedShuffle,它需要通过索引文件来获得数据块在数据文件中的具体位置信息,从而读取这个数据。

具体实现在org.apache.spark.shuffle.IndexShuffleBlockManager#getBlockData中。 

[java] view plain copy
  1. override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {  
  2.    // 根据ShuffleID和MapID从org.apache.spark.storage.DiskBlockManager 获取索引文件  
  3.    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)  
  4.    val in = new DataInputStream(new FileInputStream(indexFile))  
  5.    try {  
  6.      ByteStreams.skipFully(in, blockId.reduceId * 8//跳到本次Block的数据区  
  7.      val offset = in.readLong() //数据文件中的开始位置  
  8.      val nextOffset = in.readLong() //数据文件中的结束位置  
  9.      new FileSegmentManagedBuffer(  
  10.        transportConf,  
  11.        getDataFile(blockId.shuffleId, blockId.mapId),  
  12.        offset,  
  13.        nextOffset - offset)  
  14.     }finally {  
  15.      in.close()  
  16.     }  
  17.   } 

你可能感兴趣的:(spark)