Spark Source Code Reading: Shuffle Read

When the DAGScheduler splits a job into stages, it cuts the stage graph at shuffle boundaries: the upstream stage writes the results of its RDD computation to local disk (the write side will be analyzed in a later post).

Next the shuffled results need to be consumed (say, by a count()). Downstream of the original RDD there is a ShuffledRDD that processes the post-shuffle data; it actually belongs to a new stage.
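
As a concrete, minimal example (assuming an existing SparkContext named sc), the following job contains exactly one shuffle:

// reduceByKey introduces a ShuffleDependency, so the DAGScheduler cuts the
// job into a ShuffleMapStage (which writes map output to local disk) and a
// ResultStage whose ResultTasks read that output through a ShuffledRDD.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
val reduced = pairs.reduceByKey(_ + _)   // reduced is a ShuffledRDD
println(reduced.count())                 // count() triggers the shuffle read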

This new stage is likewise split into tasks that are sent to the executors.

The tasks produced here are ResultTasks, and they are quite simple: once deserialized, a ResultTask ends up calling ShuffledRDD.iterator -> ShuffledRDD.compute:

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
    SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
      .read()
      .asInstanceOf[Iterator[(K, C)]]
  }

The ShuffleManager hands out a reader that fetches the data written by the shuffle's map side and feeds it to the computation. Concretely, this reader is a BlockStoreShuffleReader.
What does this class do?

  • First, to compute this reduce partition it must find out where the map outputs it depends on are located (MapOutputTracker).
  • Fetch the dependent data from those locations (BlockManager).
  • If a combine is defined, run the aggregation logic (a rough sketch follows this list).
  • If a key ordering is defined, sort the data (ExternalSorter).
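
Before the source, a simplified sketch of the two combine paths the reader chooses between. The function names mirror Aggregator's, but the signatures here are simplified and Map-based (the real Aggregator aggregates into an ExternalAppendOnlyMap that can spill to disk); C = List[Int] plays the combiner type:

// Hypothetical stand-ins for Aggregator's three functions.
val createCombiner: Int => List[Int] = v => List(v)
val mergeValue: (List[Int], Int) => List[Int] = (c, v) => v :: c
val mergeCombiners: (List[Int], List[Int]) => List[Int] = _ ::: _

// combineValuesByKey path: raw (K, V) records arrive from the map side.
def combineValues(records: Iterator[(String, Int)]): Map[String, List[Int]] =
  records.foldLeft(Map.empty[String, List[Int]]) { case (m, (k, v)) =>
    m.updated(k, m.get(k).fold(createCombiner(v))(mergeValue(_, v)))
  }

// combineCombinersByKey path: partial combiners (K, C) arrive instead,
// so only mergeCombiners is needed.
def combineCombiners(combiners: Iterator[(String, List[Int])]): Map[String, List[Int]] =
  combiners.foldLeft(Map.empty[String, List[Int]]) { case (m, (k, c)) =>
    m.updated(k, m.get(k).fold(c)(mergeCombiners(_, c)))
  }

Now the reader itself: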
/**
 * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by
 * requesting them from other nodes' block stores.
 */
private[spark] class BlockStoreShuffleReader[K, C](
    handle: BaseShuffleHandle[K, _, C],
    startPartition: Int,
    endPartition: Int,
    context: TaskContext,
    serializerManager: SerializerManager = SparkEnv.get.serializerManager,
    blockManager: BlockManager = SparkEnv.get.blockManager,
    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
  extends ShuffleReader[K, C] with Logging {

override def read(): Iterator[Product2[K, C]] = {
    val wrappedStreams = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      serializerManager.wrapStream,
      ...
   )

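    // dep is a field of the class: private val dep = handle.dependency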
    val serializerInstance = dep.serializer.newInstance()

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       ...
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }

       ...
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
      if (dep.mapSideCombine) {
         ...
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {
       ...      
       interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {
      case Some(keyOrd: Ordering[K]) =>
        // Create an ExternalSorter to sort the data.
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
        sorter.insertAll(aggregatedIter)
        ...
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
returns a Seq[(BlockManagerId, Seq[(BlockId, Long)])]: in other words, for each executor, which blocks hold the needed data and how many bytes each block is.
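
An illustrative value of that shape (made up for this post; real instances are built by the MapOutputTracker, and BlockManagerId's factory is private[spark]), showing two map outputs for reduce partition 2 on the same executor:

import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// Hypothetical example only. ShuffleBlockId encodes (shuffleId, mapId,
// reduceId); the Long is the block size in bytes.
val mapSizes: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = Seq(
  BlockManagerId("exec-1", "host-a", 7337) -> Seq(
    (ShuffleBlockId(0, 0, 2), 1024L),
    (ShuffleBlockId(0, 1, 2), 2048L)))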

(Every executor has a SparkEnv, and every SparkEnv contains a BlockManager. The driver-side BlockManager acts as the master and the executor-side BlockManagers as slaves; together the BlockManagers in the cluster form a master-slave cluster of their own. More on this in a later post.)

ShuffleBlockFetcherIterator

This object is the one actually responsible for pulling all the needed blocks, including those on remote nodes.
Execution starts in its initialize method, which:

  • registers a task-completion callback so memory (ByteBuffers) is cleaned up when the task finishes,
  • splits the requests, since some of the data may actually live locally,
  • starts fetching the remote blocks,
  • fetches the local blocks.
private[this] def initialize(): Unit = {
    
    context.addTaskCompletionListener(_ => cleanup())
    ...
    val remoteRequests = splitLocalRemoteBlocks()
    ...
    fetchRequests ++= Utils.randomize(remoteRequests)
    ...
    fetchUpToMaxBytes()
    ...
    fetchLocalBlocks()
    ...
  }

Here the per-node block lists are turned into a list of fetch requests. There is an optimization: to speed up fetching, the blocks held by one node are broken into multiple smaller requests (capped at targetRequestSize each), so data can be pulled from up to five nodes in parallel instead of blocking on one big request to a single node.

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.

    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize
      + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize ||
              curBlocks.size >= maxBlocksInFlightPerAddress) {
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    remoteRequests
  }
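
To make the sizing concrete, a quick check of the arithmetic, assuming the default spark.reducer.maxSizeInFlight of 48m:

// Assumed default: spark.reducer.maxSizeInFlight = 48m.
val maxBytesInFlight = 48L * 1024 * 1024                    // 50331648 bytes
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  // 10066329 (~9.6 MiB)
// A node holding, say, 100 MiB of blocks for this reducer yields ~10
// FetchRequests, while the in-flight caps keep roughly 5 requests' worth
// of data (48 MiB total) outstanding at any moment.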

Once the requests are built, fetching starts. A rate-limiting mechanism governs this, based on maxBytesInFlight and maxReqsInFlight (plus the per-node cap maxBlocksInFlightPerAddress): if a request can be sent within the limits it is sent right away; otherwise it is parked in a deferred queue and retried the next time this method is called.


private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
    // immediately, defer the request until the next time it can be processed.

    // Process any outstanding deferred fetch requests if possible.
    if (deferredFetchRequests.nonEmpty) {
      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
        while (isRemoteBlockFetchable(defReqQueue) &&
            !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
          val request = defReqQueue.dequeue()
          logDebug(s"Processing deferred fetch request for $remoteAddress with "
            + s"${request.blocks.length} blocks")
          send(remoteAddress, request)
          if (defReqQueue.isEmpty) {
            deferredFetchRequests -= remoteAddress
          }
        }
      }
    }

    // Process any regular fetch requests if possible.
    while (isRemoteBlockFetchable(fetchRequests)) {
      val request = fetchRequests.dequeue()
      val remoteAddress = request.address
      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
        defReqQueue.enqueue(request)
        deferredFetchRequests(remoteAddress) = defReqQueue
      } else {
        send(remoteAddress, request)
      }
    }

    def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
      sendRequest(request)
      numBlocksInFlightPerAddress(remoteAddress) =
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
    }

    def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
      fetchReqQueue.nonEmpty &&
        (bytesInFlight == 0 ||
          (reqsInFlight + 1 <= maxReqsInFlight &&
            bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
    }

    // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
    // given remote address.
    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
        maxBlocksInFlightPerAddress
    }
  }

Sending a request is delegated to the ShuffleClient (by default a NettyBlockTransferService), which issues the block-fetch request to the remote BlockManager.
The whole exchange is asynchronous, and a callback is attached to each request; when a fetch finishes, the callback pushes the result onto a LinkedBlockingQueue.
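
This is a classic producer/consumer queue. A minimal standalone sketch of the pattern (not the actual Spark API; all names here are made up):

import java.util.concurrent.LinkedBlockingQueue

// Stand-ins for Spark's SuccessFetchResult / FailureFetchResult.
sealed trait DemoFetchResult
case class DemoSuccess(blockId: String, data: Array[Byte]) extends DemoFetchResult
case class DemoFailure(blockId: String, e: Throwable) extends DemoFetchResult

val results = new LinkedBlockingQueue[DemoFetchResult]()

// Invoked on a network (Netty event-loop) thread when a fetch completes.
def onSuccess(blockId: String, data: Array[Byte]): Unit =
  results.put(DemoSuccess(blockId, data))

def onFailure(blockId: String, e: Throwable): Unit =
  results.put(DemoFailure(blockId, e))

// The task thread blocks here until some result (success or failure) arrives.
def takeNextResult(): DemoFetchResult = results.take()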

After that come the local blocks, which are read directly from the local BlockManager.

Note that this class is itself an Iterator, and its next() method is called repeatedly from outside.
next() is effectively blocking: if the results queue is empty, it waits on it.
Each result taken from the queue is processed and wrapped into an InputStream.
Every call to next() therefore yields a BlockId plus the corresponding stream, without the caller having to care whether the block was remote or local; and because every call to next() also invokes fetchUpToMaxBytes(), remote data keeps flowing in the background.

override def next(): (BlockId, InputStream) = {
    if (!hasNext) {
      throw new NoSuchElementException
    }
    numBlocksProcessed += 1

    var result: FetchResult = null
    var input: InputStream = null
    while (result == null) {
           ...
      result = results.take()
           ...
      result match {
        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
          if (address != blockManager.blockManagerId) {
            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
           ...
          }
          bytesInFlight -= size
          if (isNetworkReqDone) {
            reqsInFlight -= 1
          }

          val in = try {
            buf.createInputStream()
          } catch {
            case e: IOException =>
              buf.release()
              throwFetchFailedException(blockId, address, e)
          }

          input = streamWrapper(blockId, in)
           ...
        case FailureFetchResult(blockId, address, e) =>
          throwFetchFailedException(blockId, address, e)
      }
      fetchUpToMaxBytes()   // <-----------------
    }

    currentResult = result.asInstanceOf[SuccessFetchResult]
    (currentResult.blockId, new BufferReleasingInputStream(input, this))
  }

Flow summary

A ResultTask runs on an executor and calls the ShuffledRDD's iterator method. That method obtains a BlockStoreShuffleReader from the ShuffleManager; the reader takes care of fetching the remote shuffle output files, then applies combining and sorting as needed, and hands the resulting iterator to the rest of the computation.
