Spark BlockStoreShuffleReader

SortShuffleManager provides only one ShuffleReader implementation: BlockStoreShuffleReader.

The first step is to find where the data to read is located. When a ShuffleMapTask finishes, the DAGScheduler's handleTaskCompletion method is invoked, and the MapOutputTracker inside it records the corresponding (shuffleId: Int, mapId: Int, status: MapStatus). Later, the MapOutputTrackerWorker on the executor side can retrieve this information from the driver via getMapSizesByExecutorId, which after processing returns Seq[(BlockManagerId, Seq[(BlockId, Long)])]. The BlockManagerId identifies the location of a BlockManager, and Seq[(BlockId, Long)] lists that BlockManager's blocks within the requested partition range; the block sizes are recorded so that memory usage can be controlled.
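
To make the shape of that result concrete, here is a minimal self-contained sketch. The case classes below are simplified stand-ins for Spark's BlockManagerId and ShuffleBlockId, and the executor ids, hosts, ports and sizes are made-up example values: a read of reduce partition 2 of shuffle 0 whose map outputs live on two executors might look like this.

  object MapOutputShapeExample {
    // Simplified stand-ins for Spark's BlockManagerId and ShuffleBlockId.
    case class BlockManagerId(executorId: String, host: String, port: Int)
    // Spark names shuffle blocks "shuffle_<shuffleId>_<mapId>_<reduceId>".
    case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
      def name: String = s"shuffle_${shuffleId}_${mapId}_$reduceId"
    }

    // Map outputs of shuffle 0 for reduce partition 2, written by map tasks 0 and 1,
    // grouped by the BlockManager that holds them, together with their sizes in bytes.
    val blocksByAddress: Seq[(BlockManagerId, Seq[(ShuffleBlockId, Long)])] = Seq(
      BlockManagerId("1", "host-a", 7337) -> Seq(ShuffleBlockId(0, 0, 2) -> 1024L),
      BlockManagerId("2", "host-b", 7337) -> Seq(ShuffleBlockId(0, 1, 2) -> 2048L))

    def main(args: Array[String]): Unit =
      for ((bm, blocks) <- blocksByAddress; (block, size) <- blocks)
        println(s"${bm.host}:${bm.port} holds ${block.name} ($size bytes)")
  }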

The read method then instantiates a ShuffleBlockFetcherIterator, an iterator over the fetched blocks. Local blocks are read directly via BlockManager.getBlockData; remote blocks are fetched over the network using Netty. Fetched data is kept in memory unless it exceeds the threshold (spark.maxRemoteBlockSizeFetchToMem, default Int.MaxValue - 512), in which case it is streamed to disk. The iterator returns a (BlockId, InputStream) pair for each block.
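
The threshold decision itself is simple. The sketch below is a conceptual, stand-alone version of it (not Spark's actual code): a fetch whose size exceeds the threshold is streamed to temporary files on disk instead of being buffered in memory.

  object FetchToDiskExample {
    // Conceptual sketch only: a fetch larger than the threshold is streamed to
    // disk rather than buffered in memory. The default value mirrors
    // spark.maxRemoteBlockSizeFetchToMem (Int.MaxValue - 512).
    def fetchToDisk(fetchSizeBytes: Long,
                    maxRemoteBlockSizeFetchToMem: Long = Int.MaxValue - 512L): Boolean =
      fetchSizeBytes > maxRemoteBlockSizeFetchToMem

    def main(args: Array[String]): Unit = {
      println(fetchToDisk(48L * 1024 * 1024))       // 48 MB  -> false, kept in memory
      println(fetchToDisk(3L * 1024 * 1024 * 1024)) // ~3 GB  -> true, streamed to disk
    }
  }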

Then, depending on what the dependency requires: if an aggregator is defined, there are two cases according to whether map-side combine was already performed; finally, if keyOrdering is specified, the records are sorted. The result is returned as an Iterator[Product2[K, C]].

  override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      // used for fetching remote blocks
      blockManager.shuffleClient,
      // used for fetching local blocks
      blockManager,
      // the list of blocks to fetch
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // wraps the streams to handle compression and encryption
      serializerManager.wrapStream,
      // max amount of remote data being fetched at once, default 48 MB
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      // max number of fetch requests in flight
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
      // max number of blocks being fetched from any single address
      SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
      // max bytes of shuffle data to keep in memory (larger fetches go to disk)
      SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
      // whether to detect corruption in fetched blocks
      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream, then concatenate them all
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // Update the shuffle read metrics for every record read
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // Wrap in another iterator that supports task cancellation by checking the interruption flag in TaskContext
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    // An aggregator is defined, so aggregation must also be performed on the reduce side.
    // Internally an ExternalAppendOnlyMap is used; it is similar in implementation to the external
    // sorting path but does not sort, and only combines values whose keys are equal.
    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // Map-side combine has already been applied: merge the pre-combined values, involving [K, C]
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // No map-side combine was applied: perform the full aggregation, involving [K, V, C]
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // If an ordering is specified, sort the output using ExternalSorter and return an iterator over the sorted data
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }
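
The two aggregation branches differ only in which Aggregator functions are applied. The following minimal sketch uses plain Scala collections rather than Spark's ExternalAppendOnlyMap, with a hypothetical word-count aggregator where V = Int and C = Int, to illustrate how combineValuesByKey and combineCombinersByKey relate.

  object ReduceSideAggregationExample {
    val createCombiner: Int => Int        = v => v              // first value seen for a key
    val mergeValue: (Int, Int) => Int     = (c, v) => c + v     // fold a raw value into a combiner
    val mergeCombiners: (Int, Int) => Int = (c1, c2) => c1 + c2 // merge two partial results

    // No map-side combine: the shuffle delivers raw (K, V) records,
    // so the reducer must apply createCombiner/mergeValue itself.
    def combineValuesByKey(records: Iterator[(String, Int)]): Map[String, Int] =
      records.foldLeft(Map.empty[String, Int]) { case (acc, (k, v)) =>
        acc.updated(k, acc.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v)))
      }

    // Map-side combine already ran: the shuffle delivers (K, C) partial results,
    // so the reducer only needs to merge combiners.
    def combineCombinersByKey(records: Iterator[(String, Int)]): Map[String, Int] =
      records.foldLeft(Map.empty[String, Int]) { case (acc, (k, c)) =>
        acc.updated(k, acc.get(k).map(mergeCombiners(_, c)).getOrElse(c))
      }

    def main(args: Array[String]): Unit = {
      println(combineValuesByKey(Iterator("a" -> 1, "b" -> 1, "a" -> 1)))    // Map(a -> 2, b -> 1)
      println(combineCombinersByKey(Iterator("a" -> 2, "a" -> 3, "b" -> 1))) // Map(a -> 5, b -> 1)
    }
  }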

ShuffleBlockFetcherIterator

ShuffleBlockFetcherIterator calls its initialize method when it is constructed.

  private[this] def initialize(): Unit = {
    // Callback on task completion, used to release the fetched data
    context.addTaskCompletionListener(_ => cleanup())

    // Split the blocks into local and remote ones
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote block requests to the fetch queue in random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    ...

    // Send requests, keeping the amount of data being fetched under maxBytesInFlight
    fetchUpToMaxBytes()
    // Some of the block fetch requests have already been issued
    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Fetch local blocks, internally via IndexShuffleBlockResolver.getBlockData,
    // then wrap each one in a SuccessFetchResult and append it to the results queue
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
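
fetchUpToMaxBytes keeps issuing queued requests as long as the in-flight limits allow. Below is a simplified, self-contained sketch of that throttling loop (hypothetical types and values; the real method additionally enforces maxReqsInFlight and the per-address block limit).

  import scala.collection.mutable

  object FetchThrottlingExample {
    // Stand-in for Spark's private FetchRequest class; send() represents issuing the RPC.
    case class FetchRequest(address: String, sizeBytes: Long)

    def fetchUpToMaxBytes(queue: mutable.Queue[FetchRequest],
                          bytesInFlight: Long,
                          maxBytesInFlight: Long)(send: FetchRequest => Unit): Long = {
      var inFlight = bytesInFlight
      // Issue the next request only if it fits under the in-flight byte budget
      // (the first request is always allowed so that progress can be made).
      while (queue.nonEmpty &&
             (inFlight == 0 || inFlight + queue.head.sizeBytes <= maxBytesInFlight)) {
        val req = queue.dequeue()
        inFlight += req.sizeBytes
        send(req)
      }
      inFlight
    }

    def main(args: Array[String]): Unit = {
      val queue = mutable.Queue(
        FetchRequest("host-a", 10L << 20),
        FetchRequest("host-b", 20L << 20),
        FetchRequest("host-c", 30L << 20))
      // With a 48 MB budget only the first two requests (10 MB + 20 MB) are sent now.
      fetchUpToMaxBytes(queue, bytesInFlight = 0L, maxBytesInFlight = 48L << 20)(
        req => println(s"fetching ${req.sizeBytes >> 20} MB from ${req.address}"))
    }
  }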

The remote blocks are packaged into an array of FetchRequests:

  private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Each request targets 1/5 of the maximum, so data can be fetched from 5 nodes in parallel instead of blocking on a single node
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    ...
    // Remote blocks are split into multiple FetchRequests to avoid exceeding the limit on data in flight
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      // Blocks whose blockManager is on the same executor are local blocks
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks and record the rest in localBlocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        // Count them towards the total number of blocks to fetch
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Only handle non-empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId // record in remoteBlocks
            numBlocksToFetch += 1 // count towards the total number of blocks
            curRequestSize += size // accumulate the size of the current request
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          // Once the accumulated size or the number of blocks for this address reaches the limit, package them as one FetchRequest
          if (curRequestSize >= targetRequestSize ||
              curBlocks.size >= maxBlocksInFlightPerAddress) {
            remoteRequests += new FetchRequest(address, curBlocks)
            ...
            // Reset the accumulators
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            curRequestSize = 0
          }
        }
        // Package the remaining remote blocks into one final FetchRequest
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    ...
    remoteRequests
  }
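
To see the packing in action, here is a small self-contained simulation of the logic above (simplified stand-in types and made-up sizes): blocks from one remote address are accumulated until either targetRequestSize (maxBytesInFlight / 5) or maxBlocksInFlightPerAddress is reached, then emitted as one FetchRequest.

  import scala.collection.mutable.ArrayBuffer

  object FetchRequestPackingExample {
    case class FetchRequest(address: String, blocks: Seq[(String, Long)])

    def pack(address: String,
             blocks: Seq[(String, Long)],
             maxBytesInFlight: Long,
             maxBlocksInFlightPerAddress: Int): Seq[FetchRequest] = {
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      val requests = new ArrayBuffer[FetchRequest]
      var curBlocks = new ArrayBuffer[(String, Long)]
      var curRequestSize = 0L
      for ((blockId, size) <- blocks if size > 0) {
        curBlocks += ((blockId, size))
        curRequestSize += size
        // Same stopping condition as above: size target reached or too many blocks for this address
        if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) {
          requests += FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(String, Long)]
          curRequestSize = 0L
        }
      }
      if (curBlocks.nonEmpty) requests += FetchRequest(address, curBlocks)
      requests
    }

    def main(args: Array[String]): Unit = {
      // 48 MB in flight => ~9.6 MB per request; four 4 MB blocks are split into
      // one request of [4 MB + 4 MB + 4 MB] (12 MB >= 9.6 MB) and one of [4 MB].
      val blocks = (0 until 4).map(i => (s"shuffle_0_${i}_2", 4L << 20))
      pack("host-b:7337", blocks, maxBytesInFlight = 48L << 20,
           maxBlocksInFlightPerAddress = Int.MaxValue)
        .foreach(r => println(s"${r.address} -> ${r.blocks.map(_._1)}"))
    }
  }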
