[spark] Shuffle Read解析 (Sort Based Shuffle)

Shuffle Write 请看 Shuffle Write解析。

本文将讲解shuffle Reduce部分,shuffle的下游Stage的第一个rdd是ShuffleRDD,通过其compute方法来获取上游Stage Shuffle Write溢写到磁盘文件数据的一个迭代器:

 override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

从SparkEnv中获取shuffleManager(这里是SortShuffleManager),通过manager获取Reader并调用其read方法来得到一个迭代器。

override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext): ShuffleReader[K, C] = {
    new BlockStoreShuffleReader(
      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
  }

getReader方法实例化了一个BlockStoreShuffleReader,参数有需要获取分区对应的partitionId,看看起read方法:

 override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // 获取存储数据位置的元数据
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // 每次远程请求传输的最大大小
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // 用压缩加密来包装流
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }
  
    val serializerInstance = dep.serializer.newInstance()

    // 对每个流生成K/V迭代器
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // 每条记录读取后更新任务度量
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    // 生成完整的迭代器
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // 在map端已经聚合一次了
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // 只在reduce端聚合
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // 若需要全局排序
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

首先实例化了ShuffleBlockFetcherIterator对象,其中一个参数:

mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)

该方法获取reduce端数据的来源的元数据,返回的是 Seq[(BlockManagerId, Seq[(BlockId, Long)])],即数据是来自于哪个节点的哪些block的,并且block的数据大小是多少,看看getMapSizesByExecutorId是怎么实现的:

def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
    // 获取元数据信息
    val statuses = getStatuses(shuffleId)
    // 转换格式并得到指定partition的元数据信息
    statuses.synchronized {
      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
    }
  }
  • 传入shuffleId获取对应shuffle的所有元数据信息
  • 转换格式并获取指定partition的元数据

跟进getStatuses:

private def getStatuses(shuffleId: Int): Array[MapStatus] = {
    // 直接从mapStatuses中获取
    val statuses = mapStatuses.get(shuffleId).orNull
    if (statuses == null) {
      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
      val startTime = System.currentTimeMillis
      var fetchedStatuses: Array[MapStatus] = null
      ......
      if (fetchedStatuses == null) {
        // We won the race to fetch the statuses; do so
        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
          // 从远程获取元数据
          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          // 反序列化
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          // 加入mapStatus
          mapStatuses.put(shuffleId, fetchedStatuses)
        } finally {
          fetching.synchronized {
            fetching -= shuffleId
            fetching.notifyAll()
          }
        }
      } 
     .....
      }
    } else {
      return statuses
    }
  }

若能从mapStatuses获取到则直接返回,若不能则向mapOutputTrackerMaster通信发送GetMapOutputStatuses消息来获取元数据。

我们知道一个Executor对应一个CoarseGrainedExecutorBackend,构建CoarseGrainedExecutorBackend的时候会创建一个SparkEnv,创建SparkEnv的时候会创建一个mapOutputTracker,即mapOutputTracker和Executor一一对应,也就是每一个Executor都有一个mapOutputTracker来维护元数据信息。

这里的mapStatuses就是mapOutputTracker保存元数据信息的,mapOutputTracker和Executor一一对应,在该Executor上完成的Shuffle Write的元数据信息都会保存在其mapStatus里面,另外通过远程获取的其他Executor上完成的Shuffle Write的元数据信息也会在当前的mapStatuses中保存。

Executor对应的是mapOutputTrackerWorker,而Driver对应的是mapOutputTrackerMaster,两者都是在实例化SparkEnv的时候创建的,每个在Executor上完成的Shuffle Task的结果都会注册到driver端的mapOutputTrackerMaster中,即driver端的mapOutputTrackerMaster的mapStatuses保存这所有元数据信息,所以当一个Executor上的任务需要获取一个shuffle的输出时,会先在自己的mapStatuses中查找,找不到再和mapOutputTrackerMaster通信获取元数据。

mapOutputTrackerMaster收到消息后的处理逻辑:

case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))

调用了tracker的post方法:

 def post(message: GetMapOutputMessage): Unit = {
    mapOutputRequests.offer(message)
  }

将该Message加入了mapOutputRequests中,mapOutputRequests是一个链式阻塞队列,在mapOutputTrackerMaster初始化的时候专门启动了一个线程池来执行这些请求:

private val threadpool: ThreadPoolExecutor = {
    val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
    val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
    for (i <- 0 until numThreads) {
      pool.execute(new MessageLoop)
    }
    pool
  }

看看线程处理类MessageLoop的run方法是怎么定义的:

private class MessageLoop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          try {
            // 取出一个GetMapOutputMessage
            val data = mapOutputRequests.take()
             if (data == PoisonPill) {
              // Put PoisonPill back so that other MessageLoops can see it.
              mapOutputRequests.offer(PoisonPill)
              return
            }
            val context = data.context
            val shuffleId = data.shuffleId
            val hostPort = context.senderAddress.hostPort
            logDebug("Handling request to send map output locations for shuffle " + shuffleId +
              " to " + hostPort)
            // 通过shuffleId获取对应序列化后的元数据信息
            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
            // 返回数据
            context.reply(mapOutputStatuses)
          } catch {
            case NonFatal(e) => logError(e.getMessage, e)
          }
        }
      } catch {
        case ie: InterruptedException => // exit
      }
    }
  }

通过shuffleId获取对应序列化后的元数据信息并返回,具体看看getSerializedMapOutputStatuses的实现:

def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
    var statuses: Array[MapStatus] = null
    var retBytes: Array[Byte] = null
    var epochGotten: Long = -1

    // 从cache中检索出MapStatus,若没有则从mapStatuses中获取
    def checkCachedStatuses(): Boolean = {
      epochLock.synchronized {
        if (epoch > cacheEpoch) {
          cachedSerializedStatuses.clear()
          clearCachedBroadcast()
          cacheEpoch = epoch
        }
        cachedSerializedStatuses.get(shuffleId) match {
          case Some(bytes) =>
            retBytes = bytes
            true
          case None =>
            logDebug("cached status not found for : " + shuffleId)
            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
            epochGotten = epoch
            false
        }
      }
    }

    if (checkCachedStatuses()) return retBytes
    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
    if (null == shuffleIdLock) {
      val newLock = new Object()
      // in general, this condition should be false - but good to be paranoid
      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
      shuffleIdLock = if (null != prevLock) prevLock else newLock
    }
    // synchronize so we only serialize/broadcast it once since multiple threads call
    // in parallel
    shuffleIdLock.synchronized {
      if (checkCachedStatuses()) return retBytes

      // 序列化statues
      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
        isLocal, minSizeForBroadcast)
      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
      // Add them into the table only if the epoch hasn't changed while we were working
      epochLock.synchronized {
        if (epoch == epochGotten) {
          cachedSerializedStatuses(shuffleId) = bytes
          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
        } else {
          logInfo("Epoch changed, not caching!")
          removeBroadcast(bcast)
        }
      }
      bytes
    }
  }

大体思路是先从缓存中获取元数据(MapStatuses),获取到直接返回,若没有则从mapStatuses获取,获取到后将其序列化后返回,随后返回给mapOutputTrackerWorker(刚才与之通信的节点),mapOutputTracker收到回复后又将元数据序列化并加入当前Executor的mapStatuses中。

再回到getMapSizesByExecutorId方法中,getStatuses得到shuffleID对应的所有的元数据信息后,通过convertMapStatuses方法将获得的元数据信息转化成形如Seq[(BlockManagerId, Seq[(BlockId, Long)])]格式的位置信息,用来读取指定的分区的数据:

private def convertMapStatuses(
      shuffleId: Int,
      startPartition: Int,
      endPartition: Int,
      statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
    assert (statuses != null)
    // 存储指定partition的元数据
    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex) {
      if (status == null) {
        val errorMessage = s"Missing an output location for shuffle $shuffleId"
        logError(errorMessage)
        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
      } else {
        for (part <- startPartition until endPartition) {
          splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
            ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
        }
      }
    }

    splitsByAddress.toSeq
  }

这里的参数statuses:Array[MapStatus]是前面获取的上游stage所有的shuffle Write 文件的元数据,并且是按map端的partitionId排序的,通过zipWithIndex将元素和这个元素在数组中的ID(索引号)组合成键/值对,这里的索引号即是map端的partitionId,再根据shuffleId、mapPartitionId、reducePartitionId来构建ShuffleBlockId(在map端的ShuffleBlockId构建中的reducePartitionId始终是0,因为一个ShuffleMapTask就一个Block,而这里加入的真正的reducePartitionId在后面通过index文件获取对应reduce端partition偏移量的时候需要用到),并估算得到对应数据的大小,因为后面获取远程数据的时候需要限制大小,最后返回位置信息。

至此mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)方法完成,返回了指定分区对应的元数据MapStatus信息。

在初始化对象ShuffleBlockFetcherIterator的时候调用了其初始化方法initialize():

private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener(_ => cleanup())

    // 区分local blocks和remote blocks并返回远程请求FetchRequest
    val remoteRequests = splitLocalRemoteBlocks()
    // 将远程请求随机的加入到fetchRequests队列中
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

    // 从fetchRequests取出远程请求,并使用sendRequest方法发送请求
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // 获取本地blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }
  • 区分local blocks和remote blocks,并返回远程请求FetchRequest加入到fetchRequests队列中
  • 从fetchRequests取出远程请求,并使用sendRequest方法发送请求,获取远程数据
  • 获取本地blocks

先看是怎么区分local blocks和remote blocks的:

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // 将一次能获取的数据最大大小/5,目的是增加并行度,最大为5个并行度
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // 存储远程请求的数组
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      // 若block所在executor就是当前executor,则判断为本地,否则为远程
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // 过滤掉大小为0的blocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          // 当请求大小超过了限制,则创建一个FetchRequest并加入到remoteRequests中
          if (curRequestSize >= targetRequestSize) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // 将剩余的blocks创建一个FetchRequest并加入到remoteRequests中
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }
  • 为了增加在远程节点获取数据的并行度,将一个请求的大小限制除以5作为最终的大小限制,即每次最多启动5个线程去最多5个节点上读取数据
  • 判断是否是本地blocks的条件是block所在的executor和当前executor是否是同一个
  • 遍历远程数据节点(Executor节点)的blocks,在一个节点上的请求数据超过大小限制则构建一个FetchRequest并加入到remoteRequests中,最后返回远程请求remoteRequests,这里的FetchRequest是对一个请求数据的包装,包括地址和blockId及大小

区分完local remote blocks后加入到了队列fetchRequests中,并调用fetchUpToMaxBytes()来获取远程数据:

private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight
    while (fetchRequests.nonEmpty &&
      (bytesInFlight == 0 ||
        (reqsInFlight + 1 <= maxReqsInFlight &&
          bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
      sendRequest(fetchRequests.dequeue())
    }
  }

从fetchRequests中取出FetchRequest,并调用了sendRequest方法:

 private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // 转成map  Map[blockId,size]
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    // 通过shuffleClient的fetchBlocks方法来获取对应远程节点上的数据
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        // 将结果保存到results中
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          ShuffleBlockFetcherIterator.this.synchronized {
            if (!isZombie) {
              // Increment the ref count because we need to pass this to a different thread.
              // This needs to be released after use.
              buf.retain()
              remainingBlocks -= blockId
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
              logDebug("remainingBlocks: " + remainingBlocks)
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }

通过shuffleClient的fetchBlocks方法来获取对应远程节点上的数据,默认是通过NettyBlockTransferService的fetchBlocks方法实现的,不管是成功还是失败都将构建SuccessFetchResult & FailureFetchResult 结果放入results中。

获取完远程的数据接着通过fetchLocalBlocks()方法来获取本地的blocks信息:

private[this] def fetchLocalBlocks() {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = blockManager.getBlockData(blockId)
        shuffleMetrics.incLocalBlocksFetched(1)
        shuffleMetrics.incLocalBytesRead(buf.size)
        buf.retain()
        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
      } catch {
        case e: Exception =>
          // If we see an exception, stop immediately.
          logError(s"Error occurred while fetching local blocks", e)
          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
          return
      }
    }
  }

迭代需要获取的block,直接从blockManager中获取数据,并通过结果数据构建SuccessFetchResult或者FailureFetchResult放入results中,看看在blockManager.getBlockData(blockId)的实现:

override def getBlockData(blockId: BlockId): ManagedBuffer = {
    if (blockId.isShuffle) {
      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
    } else {
      getLocalBytes(blockId) match {
        case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
        case None =>
          // If this block manager receives a request for a block that it doesn't have then it's
          // likely that the master has outdated block statuses for this block. Therefore, we send
          // an RPC so that this block is marked as being unavailable from this block manager.
          reportBlockStatus(blockId, BlockStatus.empty)
          throw new BlockNotFoundException(blockId.toString)
      }
    }
  }

再看看getBlockData方法:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
    // 根据ShuffleID和MapID获取索引文件
    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
    val in = new DataInputStream(new FileInputStream(indexFile))
    try {
      // 跳到对应Block的数据区
      ByteStreams.skipFully(in, blockId.reduceId * 8)
      // partition对应的开始offset
      val offset = in.readLong()
      // partition对应的结束offset
      val nextOffset = in.readLong()
      new FileSegmentManagedBuffer(
        transportConf,
        getDataFile(blockId.shuffleId, blockId.mapId),
        offset,
        nextOffset - offset)
    } finally {
      in.close()
    }
  }

根据shuffleId和mapId获取index文件,并创建一个读文件的文件流,根据block的reduceId(上面获取对应partition元数据的时候提到过)跳过对应的Block的数据区,先后获取开始和结束的offset,然后在数据文件中读取数据。

得到所有数据结果result后,再回到read()方法中:

 override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // 与mapOutputTrackerMaster通信获取存储数据位置的元数据
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // 每次传输的最大大小
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

    // 用压缩加密来包装流
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      serializerManager.wrapStream(blockId, inputStream)
    }
  
    val serializerInstance = dep.serializer.newInstance()

    // 对每个流生成K/V迭代器
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

    // 每条记录读取后更新任务度量
    val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
    // 生成完整的迭代器
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map { record =>
        readMetrics.incRecordsRead(1)
        record
      },
      context.taskMetrics().mergeShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
        // 在map端已经聚合一次了
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        // 只在reduce端聚合
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // 若需要全局排序
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

这里的ShuffleBlockFetcherIterator继承了Iterator,results可以被迭代,在其next()方法中将FetchResult以(blockId,inputStream)的形式返回:

case SuccessFetchResult(blockId, address, _, buf, _) =>
        try {
          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
        } catch {
          case NonFatal(t) =>
            throwFetchFailedException(blockId, address, t)
        }

在read()方法的后半部分会进行聚合和排序,和Shuffle Write部分很类似,这里大致描述一下。

在需要聚合的前提下,有map端聚合的时候执行combineCombinersByKey,没有则执行combineValuesByKey,但最终都调用了ExternalAppendOnlyMap的insertAll(iter)方法:

def combineCombinersByKey(
      iter: Iterator[_ <: Product2[K, C]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }
def combineValuesByKey(
      iter: Iterator[_ <: Product2[K, V]],
      context: TaskContext): Iterator[(K, C)] = {
    val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
    combiners.insertAll(iter)
    updateMetrics(context, combiners)
    combiners.iterator
  }
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "Cannot insert new elements into a map after calling iterator")
    }
    // An update function for the map that we reuse across entries to avoid allocating
    // a new closure each time
    var curEntry: Product2[K, V] = null
    val update: (Boolean, C) => C = (hadVal, oldVal) => {
      if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
    }

    while (entries.hasNext) {
      curEntry = entries.next()
      val estimatedSize = currentMap.estimateSize()
      if (estimatedSize > _peakMemoryUsedBytes) {
        _peakMemoryUsedBytes = estimatedSize
      }
      if (maybeSpill(currentMap, estimatedSize)) {
        currentMap = new SizeTrackingAppendOnlyMap[K, C]
      }
      currentMap.changeValue(curEntry._1, update)
      addElementsRead()
    }
  }

在里面的迭代最终都会调用上面提到的ShuffleBlockFetcherIterator的next方法来获取数据。

每次update&insert也会估算currentMap的大小,并判断是否需要溢写到磁盘文件,若需要则将map中的数据根据定义的keyComparator对key进行排序后返回一个迭代器,然后写到一个临时的磁盘文件,然后新建一个map来放新的数据。

执行完combiners[ExternalAppendOnlyMap]的insertAll后,调用其iterator来返回一个代表一个完整partition数据(内存及spillFile)的迭代器:

override def iterator: Iterator[(K, C)] = {
    if (currentMap == null) {
      throw new IllegalStateException(
        "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
    }
    if (spilledMaps.isEmpty) {
      CompletionIterator[(K, C), Iterator[(K, C)]](
        destructiveIterator(currentMap.iterator), freeCurrentMap())
    } else {
      new ExternalIterator()
    }
  }

跟进ExternalIterator类的实例化:

// A queue that maintains a buffer for each stream we are currently merging
    // This queue maintains the invariant that it only contains non-empty buffers
    private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

    // Input streams are derived both from the in-memory map and spilled maps on disk
    // The in-memory map is sorted in place, while the spilled maps are already in sorted order
    private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
      currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
    private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

    inputStreams.foreach { it =>
      val kcPairs = new ArrayBuffer[(K, C)]
      readNextHashCode(it, kcPairs)
      if (kcPairs.length > 0) {
        mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
      }
    }

将currentMap中的数据经过排序后和spillFile数据的iterator组合在一起得到inputStreams ,迭代这个inputStreams ,将所有数据都保存在mergeHeadp中,在ExternalIterator方法的next()方法中将被访问到。

最后若需要对数据进行全局的排序,则通过只有排序参数的ExternalSorter的insertAll方法来进行排序,和Shuffle Write一样的这里就不细讲了。

最终返回一个指定partition所有数据的一个迭代器。

你可能感兴趣的:([spark] Shuffle Read解析 (Sort Based Shuffle))