Spark Data Transfer

This article is based on Spark 2.11.

1. Introduction

In Section 3.1.1 of the article on Spark shuffle write and read, a ShuffleBlockFetcherIterator is created to fetch data from all upstream partitions. As a recap, ShuffleBlockFetcherIterator#sendRequest is called to request data blocks; the code is as follows:

private[this] def sendRequest(req: FetchRequest) {
    logDebug("Sending request for %d blocks (%s) from %s".format(
      req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
    bytesInFlight += req.size
    reqsInFlight += 1

    // so we can look up the size of each blockID
    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
    val blockIds = req.blocks.map(_._1.toString)

    val address = req.address
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      new BlockFetchingListener {
        override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
          // Only add the buffer to results queue if the iterator is not zombie,
          // i.e. cleanup() has not been called yet.
          ShuffleBlockFetcherIterator.this.synchronized {
            if (!isZombie) {
              // Increment the ref count because we need to pass this to a different thread.
              // This needs to be released after use.
              buf.retain()
              remainingBlocks -= blockId
              results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                remainingBlocks.isEmpty))
              logDebug("remainingBlocks: " + remainingBlocks)
            }
          }
          logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
        }

        override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
          logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
          results.put(new FailureFetchResult(BlockId(blockId), address, e))
        }
      }
    )
  }
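
Note the bytesInFlight and reqsInFlight bookkeeping at the top of sendRequest: ShuffleBlockFetcherIterator uses these counters to cap how much data is on the wire at once (bounded by spark.reducer.maxSizeInFlight). Below is a minimal sketch of that flow control; FetchReq and fetchUpToLimit are hypothetical names, simplified from the real iterator.

object FlowControlSketch {
  final case class FetchReq(size: Long)

  // Issue queued requests only while the bytes already in flight stay under
  // the limit; everything else waits until earlier requests complete.
  def fetchUpToLimit(requests: List[FetchReq], maxBytesInFlight: Long)
                    (send: FetchReq => Unit): List[FetchReq] = {
    var inFlight = 0L
    var remaining = requests
    while (remaining.nonEmpty && inFlight + remaining.head.size <= maxBytesInFlight) {
      send(remaining.head)             // corresponds to sendRequest(req) above
      inFlight += remaining.head.size  // corresponds to bytesInFlight += req.size
      remaining = remaining.tail
    }
    remaining                          // deferred requests
  }
}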

The shuffleClient in this code is obtained from the BlockManager; when no external shuffle service is used, it is an instance of NettyBlockTransferService, a Netty-based block service described below. The parameters of fetchBlocks show what a request needs:

  1. address.host, the address of the node where the data resides
  2. port, the port of the NettyBlockTransferService server, specified by the configuration spark.driver.blockManager.port.
  3. executorId, the id of the executor on that node that produced the blocks
  4. blockIds, the ids of the batch of blocks being requested.
  5. BlockFetchingListener, the callback invoked when block data comes back, used to store the blocks (its contract is sketched below).
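
To make the listener contract concrete: both callbacks run on Netty client threads, not on the task thread, which is why sendRequest synchronizes on the iterator before touching shared state. Below is a minimal sketch of such a listener feeding a blocking results queue, mirroring the results queue in ShuffleBlockFetcherIterator; all types here are stand-ins, not Spark's ManagedBuffer or FetchResult classes.

import java.util.concurrent.LinkedBlockingQueue

// Stand-in result types; Spark uses SuccessFetchResult / FailureFetchResult.
sealed trait FetchResult
final case class Success(blockId: String, data: Array[Byte]) extends FetchResult
final case class Failure(blockId: String, error: Throwable) extends FetchResult

final class QueueingListener {
  // Netty client threads put(); the task thread blocks on take().
  val results = new LinkedBlockingQueue[FetchResult]()

  def onBlockFetchSuccess(blockId: String, data: Array[Byte]): Unit =
    results.put(Success(blockId, data))

  def onBlockFetchFailure(blockId: String, error: Throwable): Unit =
    results.put(Failure(blockId, error))
}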

Once NettyBlockTransferService is clear, the process of transferring data between executors becomes clear as well.

2. NettyBlockTransferService

NettyBlockTransferService is a data transfer service built on Netty; the article on the Spark RPC implementation introduced how the Spark RPC framework is likewise built on Netty. NettyBlockTransferService is created when SparkEnv is initialized and is then initialized inside BlockManager. The following methods show the services it provides (a blocking adapter built on top of them is sketched after the list):

- override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener)
  fetches blocks from a remote node
- override def uploadBlock(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_])
  uploads blockData to a remote node
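
Besides the asynchronous fetchBlocks, the real BlockTransferService also exposes a blocking wrapper (fetchBlockSync) that turns the listener callbacks into a synchronous result. The sketch below shows that adaptation pattern with stand-in types; fetchOneSync and the Listener trait are hypothetical, simplified from the real signatures.

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

object SyncFetchSketch {
  trait Listener[B] {
    def onSuccess(blockId: String, buf: B): Unit
    def onFailure(blockId: String, error: Throwable): Unit
  }

  // Adapt a callback-style fetch into a blocking call: complete a Promise
  // from the callbacks, then wait on its future.
  def fetchOneSync[B](blockId: String)(fetch: (Array[String], Listener[B]) => Unit): B = {
    val result = Promise[B]()
    fetch(Array(blockId), new Listener[B] {
      def onSuccess(id: String, buf: B): Unit = result.trySuccess(buf)
      def onFailure(id: String, error: Throwable): Unit = result.tryFailure(error)
    })
    Await.result(result.future, Duration.Inf)
  }
}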

Below is the initialization code of NettyBlockTransferService:

override def init(blockDataManager: BlockDataManager): Unit = {
    val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
    var serverBootstrap: Option[TransportServerBootstrap] = None
    var clientBootstrap: Option[TransportClientBootstrap] = None
    if (authEnabled) {
      serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
      clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
    }
    transportContext = new TransportContext(transportConf, rpcHandler)
    clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
    server = createServer(serverBootstrap.toList)
    appId = conf.getAppId
    logInfo(s"Server created on ${hostName}:${server.getPort}")
  }

  /** Creates and binds the TransportServer, possibly trying multiple ports. */
  private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = {
    def startService(port: Int): (TransportServer, Int) = {
      val server = transportContext.createServer(bindAddress, port, bootstraps.asJava)
      (server, server.getPort)
    }

    Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1
  }
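
createServer hands the actual binding to Utils.startServiceOnPort, which retries on successive ports when the requested one is already taken (the retry budget comes from spark.port.maxRetries). The sketch below is a simplified rendering of that behavior, not Spark's actual implementation.

import java.net.BindException

object PortRetrySketch {
  def startOnPort[T](startPort: Int, maxRetries: Int)(start: Int => T): (T, Int) = {
    var result: Option[(T, Int)] = None
    var lastError: Throwable = null
    var offset = 0
    while (result.isEmpty && offset <= maxRetries) {
      // Port 0 means "let the OS pick a free port", so there is nothing to increment.
      val tryPort = if (startPort == 0) 0 else startPort + offset
      try result = Some((start(tryPort), tryPort))
      catch { case e: BindException => lastError = e } // port taken, try the next one
      offset += 1
    }
    result.getOrElse(throw new RuntimeException(
      s"Could not bind on ports $startPort to ${startPort + maxRetries}", lastError))
  }
}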

As with Spark RPC, where creating NettyRpcEnv builds the Netty server and client through a TransportContext, the difference here is that the registered handler is NettyBlockRpcServer instead of NettyRpcHandler. The overall NettyBlockTransferService framework is shown in the figure below:

Figure 1. NettyBlockTransferService framework

2.1 The client initiates a request

Back in the source, the fetchBlocks method that requests blocks looks like this:

override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener): Unit = {
    logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
    try {
      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
          // Create a Netty-based client; see the TransportClient part of the figure above
          val client = clientFactory.createClient(host, port)
          // OneForOneBlockFetcher does the actual fetching
          new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
        }
      }

      val maxRetries = transportConf.maxIORetries()
      if (maxRetries > 0) {
        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
        // a bug in this code. We should remove the if statement once we're sure of the stability.
        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
      } else {
        blockFetchStarter.createAndStart(blockIds, listener)
      }
    } catch {
      case e: Exception =>
        logError("Exception while beginning fetchBlocks", e)
        blockIds.foreach(listener.onBlockFetchFailure(_, e))
    }
  }
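
When spark.shuffle.io.maxRetries is greater than zero, the fetch above is wrapped in a RetryingBlockFetcher, which on an IO error waits spark.shuffle.io.retryWait and re-issues the fetch for only the blocks that have not yet succeeded. A minimal sketch of that policy; fetchWithRetry is a hypothetical helper, not the real class.

import java.io.IOException

object RetrySketch {
  @annotation.tailrec
  def fetchWithRetry(blockIds: Seq[String], retriesLeft: Int)
                    (fetchAll: Seq[String] => Seq[String]): Unit = {
    // fetchAll attempts every id and returns the ids that failed.
    val failed = fetchAll(blockIds)
    if (failed.nonEmpty) {
      if (retriesLeft <= 0) throw new IOException(s"blocks still failing: $failed")
      fetchWithRetry(failed, retriesLeft - 1)(fetchAll) // retry only the remainder
    }
  }
}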

The core of the fetch in the code above is OneForOneBlockFetcher, whose start method performs the remote read. First, the parameters specified when creating OneForOneBlockFetcher (the OpenBlocks message built from them is sketched after the list):

  1. client, used to issue the network requests
  2. execId, the remote executorId
  3. blockIds, the batch of block ids for this request
  4. listener, the listener created in ShuffleBlockFetcherIterator#sendRequest in Section 1, called back on success or failure.
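
The openMessage sent by start() below is an OpenBlocks message built from exactly these parameters. An equivalent Scala sketch of its shape (the real class is Java, in org.apache.spark.network.shuffle.protocol):

// Serialized with toByteBuffer() and sent via client.sendRpc in start() below.
final case class OpenBlocks(appId: String, execId: String, blockIds: Array[String])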

Next, into the OneForOneBlockFetcher#start method:

 private class ChunkCallback implements ChunkReceivedCallback {
    @Override
    public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
      // On receipt of a chunk, pass it upwards as a block.
      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
    }

    @Override
    public void onFailure(int chunkIndex, Throwable e) {
       ... 
    }
  }

public void start() {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }

    // Send the OpenBlocks RPC request (openMessage is built in the constructor from appId, execId and blockIds)
    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
      @Override
      public void onSuccess(ByteBuffer response) {
        try {
          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
          logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

          // Called back on success: the response is a StreamHandle, after which fetchChunk is called to request the data.
          // In other words, the initial OpenBlocks request does not bring any data back; the server merely prepares the
          // data and returns its description wrapped in a StreamHandle.
          // numChunks normally equals the number of blocks requested in the OpenBlocks message.
          // chunkCallback is the ChunkCallback instance above, invoked when the server responds; note that its onSuccess
          // in turn calls the listener passed in when this OneForOneBlockFetcher was created.
          for (int i = 0; i < streamHandle.numChunks; i++) {
            client.fetchChunk(streamHandle.streamId, i, chunkCallback);
          }
        } catch (Exception e) {
          logger.error("Failed while starting block fetches after success", e);
          failRemainingBlocks(blockIds, e);
        }
      }

      @Override
      public void onFailure(Throwable e) {
        logger.error("Failed while starting block fetches", e);
        failRemainingBlocks(blockIds, e);
      }
    });
  }
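
Putting the two phases together: OpenBlocks returns a StreamHandle, then one fetchChunk call per chunk pulls the data over. The following self-contained sketch simulates that exchange with plain method calls in place of the network; every name in it is a stand-in, not a Spark class.

object TwoPhaseFetchDemo extends App {
  final case class StreamHandle(streamId: Long, numChunks: Int)

  // "Server" side: openBlocks prepares the buffers and returns a handle;
  // fetchChunk hands one prepared buffer back by index.
  private val prepared = scala.collection.mutable.Map.empty[Long, Vector[String]]
  private var nextStreamId = 0L

  def openBlocks(blockIds: Seq[String]): StreamHandle = {
    nextStreamId += 1
    prepared(nextStreamId) = blockIds.map(id => s"<data of $id>").toVector
    StreamHandle(nextStreamId, blockIds.size)
  }

  def fetchChunk(streamId: Long, chunkIndex: Int): String =
    prepared(streamId)(chunkIndex)

  // "Client" side: phase 1 opens the stream, phase 2 fetches every chunk,
  // mirroring the loop over streamHandle.numChunks in start() above.
  val handle = openBlocks(Seq("shuffle_0_1_2", "shuffle_0_3_2"))
  (0 until handle.numChunks).foreach { i =>
    println(s"chunk $i -> ${fetchChunk(handle.streamId, i)}")
  }
}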

2.2 The server receives the request

In Section 2.1 the client issued the OpenBlocks request. As Figure 1 shows, after the server receives it, the message is decoded (MessageDecoder) and handed to TransportRequestHandler, which invokes NettyBlockRpcServer to process it; the processing logic is in its receive() method:

override def receive(
      client: TransportClient,
      rpcMessage: ByteBuffer,
      responseContext: RpcResponseCallback): Unit = {
    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
    logTrace(s"Received request: $message")

    message match {
      case openBlocks: OpenBlocks =>
        // Use the blockManager to fetch the block data for every requested block id
        val blocks: Seq[ManagedBuffer] =
          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
        // Register the block data with the streamManager
        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
        // Return the StreamHandle that fetchBlocks in Section 2.1 receives
        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

      case uploadBlock: UploadBlock =>
        ...
    }
  }

After the server receives an OpenBlocks request, it fetches the block data through the BlockManager, registers it with the StreamManager, and returns a StreamHandle to the client; the client then uses that StreamHandle to request the actual data.
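
The streamManager used here is a OneForOneStreamManager. Conceptually its job is small: registerStream stores the buffers under a fresh streamId, and getChunk(streamId, chunkIndex) hands one of them back out. A minimal sketch of that bookkeeping (simplified: the real class also ties streams to channels, checks authorization, and releases buffers when a stream ends):

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

final class SimpleStreamManager[Buf] {
  private val nextId = new AtomicLong(0)
  private val streams = new ConcurrentHashMap[Long, IndexedSeq[Buf]]()

  // Store the buffers and hand back the id the client will quote later.
  def registerStream(buffers: IndexedSeq[Buf]): Long = {
    val id = nextId.incrementAndGet()
    streams.put(id, buffers)
    id
  }

  // Look a buffer up again when a ChunkFetchRequest arrives for it
  // (no missing-stream handling in this sketch).
  def getChunk(streamId: Long, chunkIndex: Int): Buf =
    streams.get(streamId)(chunkIndex)
}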

In the fetchBlocks code of Section 2.1, once the client receives the StreamHandle it calls client.fetchChunk to send a ChunkFetchRequest to the server. Where does the server handle this message?

In Figure 1 the server also registers a TransportChannelHandler; an incoming RequestMessage again goes through TransportRequestHandler. Below is this handler's handle method:

 public void handle(RequestMessage request) {
    // Receive the ChunkFetchRequest sent by the client, carrying the streamId and chunkIndex
    if (request instanceof ChunkFetchRequest) {
      processFetchRequest((ChunkFetchRequest) request);
    } else if (request instanceof RpcRequest) {
      processRpcRequest((RpcRequest) request);
    } else if (request instanceof OneWayMessage) {
      processOneWayMessage((OneWayMessage) request);
    } else if (request instanceof StreamRequest) {
      processStreamRequest((StreamRequest) request);
    } else {
      throw new IllegalArgumentException("Unknown request type: " + request);
    }
  }

  private void processFetchRequest(final ChunkFetchRequest req) {
    if (logger.isTraceEnabled()) {
      logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
        req.streamChunkId);
    }

    ManagedBuffer buf;
    try {
      streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
      streamManager.registerChannel(channel, req.streamChunkId.streamId);
      // After the server received the OpenBlocks message it registered the block data with the streamManager;
      // here the data is looked up by streamId and chunkIndex and returned to the client.
      buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
    } catch (Exception e) {
      logger.error(String.format("Error opening block %s for request from %s",
        req.streamChunkId, getRemoteAddress(channel)), e);
      respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
      return;
    }

    respond(new ChunkFetchSuccess(req.streamChunkId, buf));
  }
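
To close, the whole exchange between client and server boils down to four message shapes on the wire. The sketch below condenses them into one sealed trait; the Msg suffix marks these as stand-ins for the real classes in org.apache.spark.network.protocol and org.apache.spark.network.shuffle.protocol (where ChunkFetchRequest actually wraps a StreamChunkId).

sealed trait WireMessage
// Phase 1: client -> server, asking the server to prepare the blocks.
final case class OpenBlocksMsg(appId: String, execId: String, blockIds: Seq[String]) extends WireMessage
// Phase 1 reply: server -> client, describing the prepared stream.
final case class StreamHandleMsg(streamId: Long, numChunks: Int) extends WireMessage
// Phase 2: client -> server, repeated once per chunk.
final case class ChunkFetchRequestMsg(streamId: Long, chunkIndex: Int) extends WireMessage
// Phase 2 reply: server -> client, carrying the block data.
final case class ChunkFetchSuccessMsg(streamId: Long, chunkIndex: Int, data: Array[Byte]) extends WireMessage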
