From Spark Shuffle RDD to Shuffle Service on Yarn: Source Code Reading, Part 2

This series covers everything from Task execution to RDD reads to fetching shuffle data. This chapter digs into the second part of the outline:
The Task system
1. Reads and writes in ShuffleMapTask
2. Reading and writing Shuffle Blocks
3. The design of the External Shuffle Service

Introduction

The previous chapter walked from ShuffledRDD down to reading the ShuffleBlocks. This chapter focuses on how the Spark Executor, acting as the client of the ExternalShuffleService, actually completes the shuffle data reads.

Background

Netty

Spark uses Netty as its underlying data-transport framework, so this part assumes basic Netty knowledge; without it, the code is hard to follow.
The official guide is a good place to start:
https://netty.io/3.7/guide/#architecture

Spark

ExternalShuffleService RPC message design

Spark defines the following message types for reading shuffle data:
Message.java

  enum Type implements Encodable {
    ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
    RpcRequest(3), RpcResponse(4), RpcFailure(5),
    StreamRequest(6), StreamResponse(7), StreamFailure(8),
    OneWayMessage(9), UploadStream(10), User(-1);
    // ... constructor and Encodable boilerplate omitted
  }
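Every message (and, as seen above, the Type enum itself) implements the Encodable interface from the spark-network-common module, which is what makes a message writable to a Netty ByteBuf. A rough excerpt:

  // org.apache.spark.network.protocol.Encodable (rough excerpt)
  public interface Encodable {
    /** Number of bytes the encoded form of this object occupies. */
    int encodedLength();

    /** Serializes this object by writing into the given ByteBuf. */
    void encode(ByteBuf buf);
  }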

Below is an excerpt from the class comment in TransportClient.java:

 For example, a typical workflow might be:
   client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100
   client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
   client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
   client.sendRPC(new CloseStream(100))
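The three client calls in that comment map directly onto the message types above. Their signatures in TransportClient.java are roughly the following (Spark 2.4, bodies trimmed):

  // org.apache.spark.network.client.TransportClient (signatures only)
  public long sendRpc(ByteBuffer message, RpcResponseCallback callback);                 // -> RpcRequest
  public void fetchChunk(long streamId, int chunkIndex, ChunkReceivedCallback callback); // -> ChunkFetchRequest
  public void stream(String streamId, StreamCallback callback);                          // -> StreamRequest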

Only a few of the key message types are explained here.

  1. RpcRequest. This is a generic message type, meaning many kinds of messages can be wrapped inside it: any object of the Encodable subclasses shown below can be sent this way. Why define a generic RPC request type here at all? Because these are control-plane messages only and never carry the actual file data. Compare with the other message types: all of them involve sending and receiving large volumes of data, hence the separation.
    [Figure 1: the Encodable subclasses (the BlockTransferMessage hierarchy)]
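The hierarchy in the figure can also be read off the Type enum in BlockTransferMessage.java (as of Spark 2.4; older versions have fewer entries):

  // org.apache.spark.network.shuffle.protocol.BlockTransferMessage (excerpt)
  public enum Type {
    OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3),
    REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6);
    // ... id bookkeeping omitted
  }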

Let's take OpenBlocks as an example.
From OneForOneBlockFetcher.java we can derive the following relationship: an OpenBlocks request is answered with a StreamHandle.

ShuffleClient                        ExternalShuffleService
      |----------- OpenBlocks ----------->|
      |                                   |
      |<---------- StreamHandle ----------|

  public OpenBlocks(String appId, String execId, String[] blockIds) {
    this.appId = appId;
    this.execId = execId;
    this.blockIds = blockIds;
  }
  public StreamHandle(long streamId, int numChunks) {
    this.streamId = streamId;
    this.numChunks = numChunks;
  }
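For context, on the server side ExternalShuffleBlockHandler turns an OpenBlocks request into a registered stream and answers with a StreamHandle. A simplified sketch (Spark 2.4 naming; blocksAsManagedBufferIterator is a hypothetical stand-in for the real per-block ManagedBuffer iterator construction):

  // ExternalShuffleBlockHandler.handleMessage (simplified sketch)
  if (msgObj instanceof OpenBlocks) {
    OpenBlocks msg = (OpenBlocks) msgObj;
    // Register one ManagedBuffer per requested block under a fresh streamId.
    long streamId = streamManager.registerStream(
        client.getClientId(), blocksAsManagedBufferIterator(msg));  // hypothetical helper
    // The client will then fetch chunks with (streamId, chunkIndex) pairs.
    callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
  }

Back on the client, start() sends the OpenBlocks message and, once the StreamHandle arrives, immediately fetches every chunk: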
  public void start() {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }

    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        try {
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
          logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

          // Immediately request all chunks -- we expect that the total size of the request is
          // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
          for (int i = 0; i < streamHandle.numChunks; i++) {
            if (downloadFileManager != null) {
              client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                new DownloadCallback(i));
            } else {
              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
            }
          }
        } catch (Exception e) {
          logger.error("Failed while starting block fetches after success", e);
          failRemainingBlocks(blockIds, e);
        }
      }

      @Override
      public void onFailure(Throwable e) {
        logger.error("Failed while starting block fetches", e);
        failRemainingBlocks(blockIds, e);
      }
    });
  }
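One small detail from the download-to-file branch above: genStreamChunkId simply fuses the stream id and the chunk index into a string-typed stream id, roughly:

  // OneForOneStreamManager.genStreamChunkId (rough excerpt)
  public static String genStreamChunkId(long streamId, int chunkId) {
    return String.format("%d_%d", streamId, chunkId);
  }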

Note the use of RpcResponseCallback in the code above: why does a callback like this exist here at all? A brief sketch follows.
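Briefly: sendRpc is asynchronous. The client stamps each outgoing RpcRequest with a requestId and remembers the callback; when an RpcResponse with the same requestId comes back over the Netty channel, TransportResponseHandler looks the callback up and fires it. A simplified sketch of that dispatch (Spark 2.4 naming, heavily trimmed):

  // org.apache.spark.network.client.TransportResponseHandler (simplified sketch)
  private final Map<Long, RpcResponseCallback> outstandingRpcs = new ConcurrentHashMap<>();

  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
    outstandingRpcs.put(requestId, callback);  // registered before the request is written out
  }

  public void handle(ResponseMessage message) throws Exception {
    if (message instanceof RpcResponse) {
      RpcResponse resp = (RpcResponse) message;
      RpcResponseCallback listener = outstandingRpcs.remove(resp.requestId);
      if (listener != null) {
        listener.onSuccess(resp.body().nioByteBuffer());  // runs on the Netty event loop
      }
    }
    // ChunkFetchSuccess/Failure, StreamResponse etc. are dispatched the same way.
  }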

  2. ChunkFetchRequest. A stream is a piece of state that a remote ExternalShuffleService endpoint opens for all the blocks the current executor needs to fetch from it; the executor uses the same streamId to access every block it needs on that particular ExternalShuffleService, and each block is called a chunk. The StreamHandle from the previous step returned a streamId and the number of chunks it contains; a ChunkFetchRequest then carries a StreamChunkId to request the data of one specific chunk.
  public ChunkFetchRequest(StreamChunkId streamChunkId) {
    this.streamChunkId = streamChunkId;
  }
  public StreamChunkId(long streamId, int chunkIndex) {
    this.streamId = streamId;
    this.chunkIndex = chunkIndex;
  }

There are two corresponding responses: ChunkFetchFailure and ChunkFetchSuccess.

  public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
    super(buffer, true);
    this.streamChunkId = streamChunkId;
  }

ChunkFetchSuccess carries an object of type ManagedBuffer, which represents the data being returned.
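ManagedBuffer abstracts over where the bytes actually live (an NIO ByteBuffer, a Netty buffer, or a file segment) and carries reference counting, which is why the fetch path later calls retain() and release(). A rough excerpt:

  // org.apache.spark.network.buffer.ManagedBuffer (rough excerpt)
  public abstract class ManagedBuffer {
    public abstract long size();
    public abstract ByteBuffer nioByteBuffer() throws IOException;
    public abstract InputStream createInputStream() throws IOException;
    public abstract ManagedBuffer retain();   // increment the reference count
    public abstract ManagedBuffer release();  // decrement it; free the buffer at zero
    public abstract Object convertToNetty() throws IOException;
  }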
Now think again about OneForOneBlockFetcher.java above: why is a chunkCallback object passed as the callback when the ChunkFetchRequest is issued? A rough answer is sketched below.
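The chunkCallback translates "chunk i arrived" back into "block blockIds[i] arrived" for the BlockFetchingListener, since block ids and chunk indexes line up one-for-one in this protocol. From OneForOneBlockFetcher.java (Spark 2.4, lightly trimmed):

  // OneForOneBlockFetcher.ChunkCallback
  private class ChunkCallback implements ChunkReceivedCallback {
    @Override
    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
      // On receipt of a chunk, pass it upwards as a block.
      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
    }

    @Override
    public void onFailure(int chunkIndex, Throwable e) {
      // On failure, fail this block and every remaining one in the stream.
      String[] remaining = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
      failRemainingBlocks(remaining, e);
    }
  }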

  3. StreamRequest. This request fetches the data as a stream so that it can be downloaded to a file; the corresponding response is StreamResponse. The client-side contract is sketched below.
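That contract is StreamCallback: data arrives as raw byte chunks rather than as one ManagedBuffer, which is what makes writing straight into a download file possible. A rough excerpt:

  // org.apache.spark.network.client.StreamCallback (rough excerpt)
  public interface StreamCallback {
    /** Called upon receipt of stream data. */
    void onData(String streamId, ByteBuffer buf) throws IOException;

    /** Called when all data from the stream has been received. */
    void onComplete(String streamId) throws IOException;

    /** Called if there's an error reading data from the stream. */
    void onFailure(String streamId, Throwable cause) throws IOException;
  }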

Code flow for reading shuffle files

ShuffledRDD.iterator() --> compute() --> SparkEnv.get.shuffleManager.getReader().read() --> SortShuffleManager.getReader().read() --> BlockStoreShuffleReader.read() [this is where the communication with the MapOutputTrackerMaster, the data reads from the shuffle service, and the local sort-merge all happen] --> ShuffleBlockFetcherIterator.next(), which yields one block per call. The analysis below starts from ShuffleBlockFetcherIterator.
ShuffleBlockFetcherIterator.scala

final class ShuffleBlockFetcherIterator(
    context: TaskContext,
    shuffleClient: ShuffleClient,
    blockManager: BlockManager,
    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
    streamWrapper: (BlockId, InputStream) => InputStream,
    maxBytesInFlight: Long,
    maxReqsInFlight: Int,
    maxBlocksInFlightPerAddress: Int,
    maxReqSizeShuffleToMem: Long,
    detectCorrupt: Boolean)
  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {

Data fetching starts right at initialization, via fetchUpToMaxBytes():

  private[this] def initialize(): Unit = {
    // Add a task completion callback (called in both success case and failure case) to cleanup.
    context.addTaskCompletionListener[Unit](_ => cleanup())

    // Split local and remote blocks.
    val remoteRequests = splitLocalRemoteBlocks()
    // Add the remote requests into our queue in a random order
    fetchRequests ++= Utils.randomize(remoteRequests)
    assert ((0 == reqsInFlight) == (0 == bytesInFlight),
      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
      ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

    // Send out initial requests for blocks, up to our maxBytesInFlight
    fetchUpToMaxBytes()

    val numFetches = remoteRequests.size - fetchRequests.size
    logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

    // Get Local Blocks
    fetchLocalBlocks()
    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
  }

Some flow control happens here before the RPC requests go out:

  private def fetchUpToMaxBytes(): Unit = {
    // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
    // immediately, defer the request until the next time it can be processed.

    // Process any outstanding deferred fetch requests if possible.
    if (deferredFetchRequests.nonEmpty) {
      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
        while (isRemoteBlockFetchable(defReqQueue) &&
            !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
          val request = defReqQueue.dequeue()
          logDebug(s"Processing deferred fetch request for $remoteAddress with "
            + s"${request.blocks.length} blocks")
          send(remoteAddress, request)
          if (defReqQueue.isEmpty) {
            deferredFetchRequests -= remoteAddress
          }
        }
      }
    }

    // Process any regular fetch requests if possible.
    while (isRemoteBlockFetchable(fetchRequests)) {
      val request = fetchRequests.dequeue()
      val remoteAddress = request.address
      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
        logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
        defReqQueue.enqueue(request)
        deferredFetchRequests(remoteAddress) = defReqQueue
      } else {
        send(remoteAddress, request)
      }
    }

    def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
      sendRequest(request)
      numBlocksInFlightPerAddress(remoteAddress) =
        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
    }

The most critical piece: the code below carries out the entire process of sending requests to the ExternalShuffleService and receiving all the data back.

  private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
    val blockIds = req.blocks.map(_._1.toString)
    val address = req.address

    val blockFetchingListener = new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
              remainingBlocks.isEmpty))
            logDebug("remainingBlocks: " + remainingBlocks)
          }
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }

    // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
    // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
    // the data and write it to file directly.
    if (req.size > maxReqSizeShuffleToMem) {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, this)
    } else {
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        blockFetchingListener, null)
    }
  }
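When the external shuffle service is enabled, the shuffleClient above is an ExternalShuffleClient, whose fetchBlocks brings us back to the OneForOneBlockFetcher this chapter started from. A simplified sketch (Spark 2.4 naming; retry and error handling trimmed):

  // ExternalShuffleClient.fetchBlocks (simplified sketch)
  @Override
  public void fetchBlocks(String host, int port, String execId, String[] blockIds,
      BlockFetchingListener listener, DownloadFileManager downloadFileManager) {
    checkInit();
    // Each (re)try creates a client to the shuffle service and starts a one-for-one fetch.
    RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (ids, lsnr) -> {
      TransportClient client = clientFactory.createClient(host, port);
      new OneForOneBlockFetcher(
          client, appId, execId, ids, lsnr, conf, downloadFileManager).start();
    };
    if (conf.maxIORetries() > 0) {
      new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
    } else {
      blockFetchStarter.createAndStart(blockIds, listener);
    }
  }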
