This article is based on Spark 2.11.
1. Introduction
In section 3.1.1 of the article on Spark shuffle write and read, a ShuffleBlockFetcherIterator is created to fetch data from all upstream partitions. Recall that ShuffleBlockFetcherIterator#sendRequest is called to request data blocks; the code is as follows:
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  reqsInFlight += 1

  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address

  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
              remainingBlocks.isEmpty))
            logDebug("remainingBlocks: " + remainingBlocks)
          }
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}
In this code, shuffleClient is obtained from the BlockManager; when no external shuffle service is in use it is an instance of NettyBlockTransferService, the Netty-based block transfer service discussed below. The parameters of fetchBlocks spell out what is needed:

- address.host: the address of the node where the data resides
- port: the port the NettyBlockTransferService server listens on, configured via spark.blockManager.port (spark.driver.blockManager.port for the driver)
- executorId: the id of the executor on that node that produced the blocks
- blockIds: the ids of the batch of blocks being requested
- BlockFetchingListener: a callback invoked when blocks come back, used to store the fetched data

Once NettyBlockTransferService is understood, how data moves between executors becomes clear. A minimal sketch of such a fetch call follows.
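For orientation, here is a minimal, hypothetical sketch (not Spark source) of issuing such a fetch through a ShuffleClient; the host, port, executor id and block id are made-up placeholders, and the listener only prints the outcome, whereas the real listener in sendRequest retains the buffer and enqueues a SuccessFetchResult or FailureFetchResult:

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}

// Sketch only: shuffleClient is the NettyBlockTransferService held by the BlockManager
// when no external shuffle service is enabled; all literal values are placeholders.
def fetchExample(shuffleClient: ShuffleClient): Unit = {
  val listener = new BlockFetchingListener {
    override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit =
      println(s"fetched $blockId, ${buf.size()} bytes")   // real listener enqueues a SuccessFetchResult
    override def onBlockFetchFailure(blockId: String, e: Throwable): Unit =
      println(s"failed to fetch $blockId: $e")            // real listener enqueues a FailureFetchResult
  }
  shuffleClient.fetchBlocks("remote-node", 7337, "3", Array("shuffle_0_1_2"), listener)
}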
2. NettyBlockTransferService
NettyBlockTransferService is a Netty-based data transfer service; as described in the article on the Spark RPC implementation, the Spark RPC framework is likewise built on Netty. NettyBlockTransferService is created when SparkEnv is initialized and is initialized in the BlockManager. The following methods show the services it provides:
- override def fetchBlocks(
      host: String,
      port: Int,
      execId: String,
      blockIds: Array[String],
      listener: BlockFetchingListener)

  Fetches blocks from a remote node.

- override def uploadBlock(
      hostname: String,
      port: Int,
      execId: String,
      blockId: BlockId,
      blockData: ManagedBuffer,
      level: StorageLevel,
      classTag: ClassTag[_])

  Uploads blockData to a remote node.
Below is the initialization code of NettyBlockTransferService:
override def init(blockDataManager: BlockDataManager): Unit = {
  val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
  var serverBootstrap: Option[TransportServerBootstrap] = None
  var clientBootstrap: Option[TransportClientBootstrap] = None
  if (authEnabled) {
    serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
    clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
  }
  transportContext = new TransportContext(transportConf, rpcHandler)
  clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
  server = createServer(serverBootstrap.toList)
  appId = conf.getAppId
  logInfo(s"Server created on ${hostName}:${server.getPort}")
}

/** Creates and binds the TransportServer, possibly trying multiple ports. */
private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = {
  def startService(port: Int): (TransportServer, Int) = {
    val server = transportContext.createServer(bindAddress, port, bootstraps.asJava)
    (server, server.getPort)
  }

  Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1
}
As with Spark RPC, where NettyRpcEnv creates the Netty server and client through a TransportContext, the difference here is that the registered handler is NettyBlockRpcServer instead of NettyRpcHandler. The overall framework of NettyBlockTransferService is shown in Figure 1 below.

(Figure 1: the NettyBlockTransferService framework, i.e. the TransportServer/TransportClient created by TransportContext, with NettyBlockRpcServer as the registered handler)
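The swap is possible because NettyRpcHandler and NettyBlockRpcServer both implement the RpcHandler contract that TransportContext expects. As a rough sketch of that contract (simplified; RpcHandler is a Java abstract class with a few more optional hooks such as channelActive/channelInactive), a trivial handler could look like this:

import java.nio.ByteBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

// Sketch of the RpcHandler contract: receive() answers RPC messages (this is where
// NettyBlockRpcServer handles OpenBlocks / UploadBlock) and getStreamManager() exposes
// the StreamManager later used to serve fetchChunk requests. This toy handler just
// echoes the request back to the caller.
class EchoRpcHandler extends RpcHandler {
  private val streamManager = new OneForOneStreamManager()

  override def receive(
      client: TransportClient,
      message: ByteBuffer,
      callback: RpcResponseCallback): Unit = {
    callback.onSuccess(message)
  }

  override def getStreamManager(): StreamManager = streamManager
}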
2.1 The client initiates the request
Back in the source, the fetchBlocks method that requests blocks looks like this:
override def fetchBlocks(
    host: String,
    port: Int,
    execId: String,
    blockIds: Array[String],
    listener: BlockFetchingListener): Unit = {
  logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
  try {
    val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
      override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
        // Create a Netty-based client; see the TransportClient part of Figure 1.
        val client = clientFactory.createClient(host, port)
        // OneForOneBlockFetcher does the actual fetching.
        new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
      }
    }

    val maxRetries = transportConf.maxIORetries()
    if (maxRetries > 0) {
      // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
      // a bug in this code. We should remove the if statement once we're sure of the stability.
      new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
    } else {
      blockFetchStarter.createAndStart(blockIds, listener)
    }
  } catch {
    case e: Exception =>
      logError("Exception while beginning fetchBlocks", e)
      blockIds.foreach(listener.onBlockFetchFailure(_, e))
  }
}
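The retry wrapper is driven by the transport configuration: for the shuffle module, transportConf.maxIORetries() typically corresponds to spark.shuffle.io.maxRetries, and the wait between attempts to spark.shuffle.io.retryWait. A small sketch of tuning them, assuming these standard property names:

import org.apache.spark.SparkConf

// Assumed standard properties behind RetryingBlockFetcher: how many times a failed block
// fetch is retried, and how long to wait between attempts. Setting maxRetries to 0 takes
// the direct createAndStart branch above instead of the retrying wrapper.
val conf = new SparkConf()
  .set("spark.shuffle.io.maxRetries", "6")
  .set("spark.shuffle.io.retryWait", "10s")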
The core of the fetching in the code above is OneForOneBlockFetcher, whose start method performs the remote read. First, the parameters specified when creating a OneForOneBlockFetcher (see the note after this list on the message they are bundled into):

- client: used to issue the network requests
- execId: the remote executorId
- blockIds: the batch of block ids requested in one call
- listener: the listener created in ShuffleBlockFetcherIterator#sendRequest in section 1, called back when data is returned successfully or when the fetch fails
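The start() method shown next references an openMessage field that the excerpt does not declare; in the Spark source the constructor simply bundles the parameters above into an OpenBlocks message, roughly as in this sketch:

import org.apache.spark.network.shuffle.protocol.OpenBlocks

// Sketch (the real constructor is Java): the appId, the remote executor id and the
// requested block ids are wrapped into an OpenBlocks message, which start() serializes
// with toByteBuffer() and sends via client.sendRpc.
def buildOpenMessage(appId: String, execId: String, blockIds: Array[String]): OpenBlocks =
  new OpenBlocks(appId, execId, blockIds)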
Next, into the OneForOneBlockFetcher#start method:
private class ChunkCallback implements ChunkReceivedCallback {
  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // On receipt of a chunk, pass it upwards as a block.
    listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    ...
  }
}

public void start() {
  if (blockIds.length == 0) {
    throw new IllegalArgumentException("Zero-sized blockIds array");
  }

  // Send the OpenBlocks RPC request.
  client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
    @Override
    public void onSuccess(ByteBuffer response) {
      try {
        streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
        logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

        // Called back on success: the response is a StreamHandle, and fetchChunk is then used
        // to request the actual data. In other words, the initial OpenBlocks request does not
        // bring the data back; the server only prepares it and returns its description wrapped
        // in a StreamHandle. numChunks normally equals the number of blocks requested in the
        // OpenBlocks message. chunkCallback is the ChunkCallback instance above; it is invoked
        // when the server responds, and its onSuccess in turn calls the listener passed in when
        // the OneForOneBlockFetcher was created.
        for (int i = 0; i < streamHandle.numChunks; i++) {
          client.fetchChunk(streamHandle.streamId, i, chunkCallback);
        }
      } catch (Exception e) {
        logger.error("Failed while starting block fetches after success", e);
        failRemainingBlocks(blockIds, e);
      }
    }

    @Override
    public void onFailure(Throwable e) {
      logger.error("Failed while starting block fetches", e);
      failRemainingBlocks(blockIds, e);
    }
  });
}
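The StreamHandle the server sends back is deliberately small; conceptually it carries just what the chunk loop above needs, roughly:

// Conceptual shape of the response (the real class is
// org.apache.spark.network.shuffle.protocol.StreamHandle): streamId addresses the buffers
// the server has prepared, and numChunks tells the client how many fetchChunk calls to
// issue, normally one per requested block.
case class StreamHandleInfo(streamId: Long, numChunks: Int)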
2.2 The server receives the request
In section 2.1 the client issued an OpenBlocks request. As Figure 1 shows, after the server receives the OpenBlocks request it is decoded (MessageDecoder) and handed to TransportRequestHandler, which calls NettyBlockRpcServer to process the message; the processing logic is in its receive() method:
override def receive(
    client: TransportClient,
    rpcMessage: ByteBuffer,
    responseContext: RpcResponseCallback): Unit = {
  val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
  logTrace(s"Received request: $message")

  message match {
    case openBlocks: OpenBlocks =>
      // Use the blockManager to look up the block data for every requested blockId.
      val blocks: Seq[ManagedBuffer] =
        openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
      // Register the block data with the streamManager.
      val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
      logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
      // Return the StreamHandle that the fetchBlocks client in section 2.1 receives.
      responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

    case uploadBlock: UploadBlock =>
      ...
  }
}
When the server receives an OpenBlocks request, it fetches the block data through the blockManager, registers it with the StreamManager, and returns a StreamHandle to the client; the client then requests the data according to that StreamHandle.
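In NettyBlockRpcServer the streamManager is a OneForOneStreamManager. Its core idea can be illustrated with a simplified sketch (not the real implementation, which additionally ties the stream to the client channel, checks authorization and releases the buffers when the stream is exhausted):

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong

// Simplified sketch of the OneForOneStreamManager idea: registerStream stores the buffers
// of one fetch under a fresh streamId, and getChunk(streamId, chunkIndex) returns the
// chunkIndex-th buffer, which is exactly how a ChunkFetchRequest is answered.
class SimpleStreamManager[Buf] {
  private val nextStreamId = new AtomicLong(0L)
  private val streams = new ConcurrentHashMap[Long, IndexedSeq[Buf]]()

  def registerStream(buffers: IndexedSeq[Buf]): Long = {
    val streamId = nextStreamId.getAndIncrement()
    streams.put(streamId, buffers)
    streamId
  }

  def getChunk(streamId: Long, chunkIndex: Int): Buf = {
    val buffers = streams.get(streamId)
    require(buffers != null, s"unknown streamId $streamId")
    buffers(chunkIndex)
  }
}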
In the fetchBlocks flow of section 2.1, once the client has received the StreamHandle it calls client.fetchChunk, which sends a ChunkFetchRequest to the server. Where does the server handle this message? In Figure 1 the server side also registers a TransportChannelHandler, and incoming RequestMessages are processed by TransportRequestHandler; below is that handler's handle method:
public void handle(RequestMessage request) {
  // A ChunkFetchRequest sent by the client, carrying the streamId and chunkIndex.
  if (request instanceof ChunkFetchRequest) {
    processFetchRequest((ChunkFetchRequest) request);
  } else if (request instanceof RpcRequest) {
    processRpcRequest((RpcRequest) request);
  } else if (request instanceof OneWayMessage) {
    processOneWayMessage((OneWayMessage) request);
  } else if (request instanceof StreamRequest) {
    processStreamRequest((StreamRequest) request);
  } else {
    throw new IllegalArgumentException("Unknown request type: " + request);
  }
}

private void processFetchRequest(final ChunkFetchRequest req) {
  if (logger.isTraceEnabled()) {
    logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
      req.streamChunkId);
  }

  ManagedBuffer buf;
  try {
    streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
    streamManager.registerChannel(channel, req.streamChunkId.streamId);
    // The server registered the block data with the streamManager when it handled OpenBlocks;
    // here the buffer is looked up by streamId and chunkIndex and sent back to the client.
    buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex);
  } catch (Exception e) {
    logger.error(String.format("Error opening block %s for request from %s",
      req.streamChunkId, getRemoteAddress(channel)), e);
    respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e)));
    return;
  }

  respond(new ChunkFetchSuccess(req.streamChunkId, buf));
}