The material in this series spans task execution, RDD reads, and the fetching of shuffle data. This chapter mainly tackles the second part of the outline below.

The Task system:
1. Reads and writes in ShuffleMapTask
2. Reads and writes of Shuffle Blocks
3. The design of the External Shuffle Service
The previous chapter traced the read path from ShuffledRDD down to ShuffleBlock. This chapter focuses on how a Spark Executor, acting as the client of the ExternalShuffleService, actually fetches shuffle data.
Spark uses Netty as its underlying data-transfer framework, so following this part requires some basic Netty knowledge; without it the code is hard to make sense of. The official guide is a good starting point:
https://netty.io/3.7/guide/#architecture
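To see why everything below is callback-driven, here is a minimal, made-up Netty handler (EchoHandler is an illustration, not Spark code): Netty delivers incoming data by invoking channelRead on an event-loop thread, so any processing that depends on a response has to be registered as a callback rather than written as a blocking call.

import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}

// A made-up example, not Spark code: Netty calls channelRead on an
// event-loop thread whenever bytes arrive, so downstream processing
// must be asynchronous and callback-driven.
class EchoHandler extends ChannelInboundHandlerAdapter {
  override def channelRead(ctx: ChannelHandlerContext, msg: AnyRef): Unit = {
    ctx.writeAndFlush(msg) // echo the received buffer back to the peer
  }
}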
Spark defines the following message types for reading shuffle data:
Message.java
enum Type implements Encodable {
  ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
  RpcRequest(3), RpcResponse(4), RpcFailure(5),
  StreamRequest(6), StreamResponse(7), StreamFailure(8),
  OneWayMessage(9), UploadStream(10), User(-1);
  // ... id field and encoding/decoding helpers omitted
}
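Every such message travels inside a length-prefixed frame. The sketch below approximates the layout produced by Spark's MessageEncoder (the header contents vary by message type, so treat this as an illustration of the idea rather than the exact wire format):

import java.nio.ByteBuffer

// Approximate layout: [8-byte frame length][1-byte type][encoded message].
// The frame length counts itself; per-type header details are omitted.
def encodeFrame(typeId: Byte, encodedMessage: Array[Byte]): ByteBuffer = {
  val frameLength = 8 + 1 + encodedMessage.length
  val buf = ByteBuffer.allocate(frameLength)
  buf.putLong(frameLength.toLong) // read back by the frame decoder on the other side
  buf.put(typeId)                 // one byte identifying the Type enum value above
  buf.put(encodedMessage)
  buf.flip()
  buf
}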
Here is an excerpt from the comments in TransportClient.java:
For example, a typical workflow might be:
  client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100
  client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
  client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
  client.sendRPC(new CloseStream(100))
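In code, the fetchChunk half of that workflow looks roughly as follows. This is a sketch against the TransportClient and ChunkReceivedCallback API; the client and streamId are assumed to come from an earlier RPC exchange such as the OpenBlocks request shown below.

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{ChunkReceivedCallback, TransportClient}

// Fetch every chunk of an already-opened stream; the callback fires
// asynchronously on a Netty thread as each chunk arrives.
def fetchAllChunks(client: TransportClient, streamId: Long, numChunks: Int): Unit = {
  val callback = new ChunkReceivedCallback {
    override def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit =
      println(s"chunk $chunkIndex: ${buffer.size()} bytes")
    override def onFailure(chunkIndex: Int, e: Throwable): Unit =
      e.printStackTrace()
  }
  (0 until numChunks).foreach(i => client.fetchChunk(streamId, i, callback))
}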
Only a few of the key message types are covered here; let's take OpenBlocks as an example.
From OneForOneBlockFetcher.java we can read off the following relationship: an OpenBlocks request is answered with a StreamHandle.
ShuffleClient                ExternalShuffleService
      |                              |
      |-------- OpenBlocks --------->|
      |                              |
      |<------- StreamHandle --------|
      |                              |
OpenBlocks.java
public OpenBlocks(String appId, String execId, String[] blockIds) {
  this.appId = appId;
  this.execId = execId;
  this.blockIds = blockIds;
}

StreamHandle.java
public StreamHandle(long streamId, int numChunks) {
  this.streamId = streamId;
  this.numChunks = numChunks;
}
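For context, the server side pairs the two messages roughly as sketched below. This is a simplification, not the literal Spark source (the real logic lives in ExternalShuffleBlockHandler, and resolveBlock is a hypothetical stand-in for the resolver that maps a block id to its file segment):

import scala.collection.JavaConverters._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.server.OneForOneStreamManager
import org.apache.spark.network.shuffle.protocol.{OpenBlocks, StreamHandle}

// Hypothetical stand-in for the block resolver.
def resolveBlock(appId: String, execId: String, blockId: String): ManagedBuffer = ???

// Simplified sketch of the OpenBlocks -> StreamHandle pairing.
def handleOpenBlocks(msg: OpenBlocks, streamManager: OneForOneStreamManager): StreamHandle = {
  val blocks: Iterator[ManagedBuffer] =
    msg.blockIds.iterator.map(id => resolveBlock(msg.appId, msg.execId, id))
  // Each registered buffer becomes one fetchable chunk of the new stream.
  val streamId = streamManager.registerStream(msg.appId, blocks.asJava)
  new StreamHandle(streamId, msg.blockIds.length)
}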
OneForOneBlockFetcher.java
public void start() {
  if (blockIds.length == 0) {
    throw new IllegalArgumentException("Zero-sized blockIds array");
  }
  client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
    @Override
    public void onSuccess(ByteBuffer response) {
      try {
        streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
        logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
        // Immediately request all chunks -- we expect that the total size of the request is
        // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
        for (int i = 0; i < streamHandle.numChunks; i++) {
          if (downloadFileManager != null) {
            client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
              new DownloadCallback(i));
          } else {
            client.fetchChunk(streamHandle.streamId, i, chunkCallback);
          }
        }
      } catch (Exception e) {
        logger.error("Failed while starting block fetches after success", e);
        failRemainingBlocks(blockIds, e);
      }
    }

    @Override
    public void onFailure(Throwable e) {
      logger.error("Failed while starting block fetches", e);
      failRemainingBlocks(blockIds, e);
    }
  });
}
Note the RpcResponseCallback in the code above; why does such a callback need to exist here? We will come back to that shortly. Once the StreamHandle response arrives, each chunk is pulled with a ChunkFetchRequest, which carries a StreamChunkId:
ChunkFetchRequest.java
public ChunkFetchRequest(StreamChunkId streamChunkId) {
  this.streamChunkId = streamChunkId;
}

StreamChunkId.java
public StreamChunkId(long streamId, int chunkIndex) {
  this.streamId = streamId;
  this.chunkIndex = chunkIndex;
}
It has two corresponding responses: ChunkFetchFailure and ChunkFetchSuccess.
ChunkFetchSuccess.java
public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
  super(buffer, true);
  this.streamChunkId = streamChunkId;
}
Note that ChunkFetchSuccess carries a ManagedBuffer object, which holds the returned data.
Now consider the question raised above: why does OneForOneBlockFetcher.java pass a chunkCallback object when it issues a ChunkFetchRequest?
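The answer is that the channel is fully asynchronous: requests and responses are decoupled on the wire, so the client must remember which callback belongs to which outstanding request and complete it when the matching response arrives. Spark's TransportResponseHandler keeps exactly such a registry, keyed by StreamChunkId for chunk fetches. The class below is a simplified model of that idea, not the literal implementation:

import java.util.concurrent.ConcurrentHashMap

// Simplified model of TransportResponseHandler's bookkeeping (not Spark source).
// For chunk fetches the key would be (streamId, chunkIndex), i.e. StreamChunkId.
class ResponseDemultiplexer[K, V] {
  private val outstanding = new ConcurrentHashMap[K, V => Unit]()

  // Called when a request is sent, e.g. by fetchChunk.
  def register(key: K, callback: V => Unit): Unit = outstanding.put(key, callback)

  // Called from the Netty event loop when a response frame is decoded:
  // find the matching callback and complete it.
  def complete(key: K, value: V): Unit = {
    val cb = outstanding.remove(key)
    if (cb != null) cb(value)
  }
}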
ShuffledRDD.iterator() --> compute() --> SparkEnv.get.shuffleManager.getReader().read() --> SortShuffleManager.getReader().read() --> BlockStoreShuffleReader.read() (this is where the reducer talks to the MapOutputTrackerMaster, reads data from the shuffle service, and performs the local sort merge) --> ShuffleBlockFetcherIterator.next(), which yields one block per call. The analysis below starts from ShuffleBlockFetcherIterator.
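As a concrete anchor for that call chain, a job as small as the one below exercises it end to end: reduceByKey produces a ShuffledRDD, and every reduce task pulls its partition through BlockStoreShuffleReader and ShuffleBlockFetcherIterator. (A minimal runnable example; the app name and master are arbitrary.)

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[2]"))
    // reduceByKey inserts a shuffle; collect() runs the reduce stage,
    // whose tasks read shuffle blocks via ShuffleBlockFetcherIterator.
    val counts = sc.parallelize(1 to 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println)
    sc.stop()
  }
}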
ShuffleBlockFetcherIterator.scala
final class ShuffleBlockFetcherIterator(
    context: TaskContext,
    shuffleClient: ShuffleClient,
    blockManager: BlockManager,
    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
    streamWrapper: (BlockId, InputStream) => InputStream,
    maxBytesInFlight: Long,
    maxReqsInFlight: Int,
    maxBlocksInFlightPerAddress: Int,
    maxReqSizeShuffleToMem: Long,
    detectCorrupt: Boolean)
  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
Fetching begins as early as initialization: initialize() already calls fetchUpToMaxBytes() to start pulling data.
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener[Unit](_ => cleanup())

  // Split local and remote blocks.
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  assert ((0 == reqsInFlight) == (0 == bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
    ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

  // Send out initial requests for blocks, up to our maxBytesInFlight
  fetchUpToMaxBytes()

  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

  // Get Local Blocks
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
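splitLocalRemoteBlocks(), called above, separates the blocks served by this executor's own BlockManager from the remote ones, and greedily packs the remote ones into FetchRequests of roughly maxBytesInFlight / 5 bytes each, so that several hosts can be queried in parallel without exceeding the in-flight byte budget. A simplified model of that packing rule (not the literal source):

import scala.collection.mutable.ArrayBuffer

// Pack (blockId, size) pairs greedily into requests of ~maxBytesInFlight / 5 bytes.
def packIntoRequests(blocks: Seq[(String, Long)],
                     maxBytesInFlight: Long): Seq[Seq[(String, Long)]] = {
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  val requests = ArrayBuffer[Seq[(String, Long)]]()
  var current = ArrayBuffer[(String, Long)]()
  var currentSize = 0L
  for ((id, size) <- blocks) {
    current += ((id, size))
    currentSize += size
    if (currentSize >= targetRequestSize) { // close this request, start a new one
      requests += current.toSeq
      current = ArrayBuffer()
      currentSize = 0L
    }
  }
  if (current.nonEmpty) requests += current.toSeq
  requests.toSeq
}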
fetchUpToMaxBytes() applies the flow-control limits and then issues the RPC requests:
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
  // immediately, defer the request until the next time it can be processed.

  // Process any outstanding deferred fetch requests if possible.
  if (deferredFetchRequests.nonEmpty) {
    for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
      while (isRemoteBlockFetchable(defReqQueue) &&
          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
        val request = defReqQueue.dequeue()
        logDebug(s"Processing deferred fetch request for $remoteAddress with "
          + s"${request.blocks.length} blocks")
        send(remoteAddress, request)
        if (defReqQueue.isEmpty) {
          deferredFetchRequests -= remoteAddress
        }
      }
    }
  }

  // Process any regular fetch requests if possible.
  while (isRemoteBlockFetchable(fetchRequests)) {
    val request = fetchRequests.dequeue()
    val remoteAddress = request.address
    if (isRemoteAddressMaxedOut(remoteAddress, request)) {
      logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
      val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
      defReqQueue.enqueue(request)
      deferredFetchRequests(remoteAddress) = defReqQueue
    } else {
      send(remoteAddress, request)
    }
  }
}
def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
  sendRequest(request)
  numBlocksInFlightPerAddress(remoteAddress) =
    numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
}
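The two throttling predicates used above are worth spelling out. Paraphrased from the same file (so treat the exact form as approximate): a queue is fetchable while the in-flight request and byte counts stay under maxReqsInFlight and maxBytesInFlight, and an address is maxed out once its in-flight block count would exceed maxBlocksInFlightPerAddress.

// Paraphrased from ShuffleBlockFetcherIterator, not quoted verbatim.
private def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
  fetchReqQueue.nonEmpty &&
    (bytesInFlight == 0 ||
      (reqsInFlight + 1 <= maxReqsInFlight &&
        bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
}

private def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId,
                                    request: FetchRequest): Boolean = {
  numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
    maxBlocksInFlightPerAddress
}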
Now the most critical piece: the code below carries out the entire exchange with the ExternalShuffleService, from sending the request to receiving all of the data.
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  reqsInFlight += 1

  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address

  val blockFetchingListener = new BlockFetchingListener {
    override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
      // Only add the buffer to results queue if the iterator is not zombie,
      // i.e. cleanup() has not been called yet.
      ShuffleBlockFetcherIterator.this.synchronized {
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          remainingBlocks -= blockId
          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
            remainingBlocks.isEmpty))
          logDebug("remainingBlocks: " + remainingBlocks)
        }
      }
      logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
    }

    override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
      logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
      results.put(new FailureFetchResult(BlockId(blockId), address, e))
    }
  }

  // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
  // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
  // the data and write it to file directly.
  if (req.size > maxReqSizeShuffleToMem) {
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, this)
  } else {
    shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
      blockFetchingListener, null)
  }
}
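Note that the BlockFetchingListener above only enqueues results. It is the task thread, inside ShuffleBlockFetcherIterator.next(), that blocks on the results queue, wraps the buffer in an InputStream, releases the ref count after use, and calls fetchUpToMaxBytes() again as in-flight bytes drain. A heavily simplified model of that producer/consumer shape (hypothetical types, not Spark source):

import java.util.concurrent.LinkedBlockingQueue

// Hypothetical, heavily simplified model of the iterator's core.
sealed trait FetchResult
case class Success(blockId: String, bytes: Array[Byte]) extends FetchResult
case class Failure(blockId: String, cause: Throwable) extends FetchResult

class ResultQueueModel {
  private val results = new LinkedBlockingQueue[FetchResult]()

  // Netty callback threads act as producers (cf. onBlockFetchSuccess above).
  def onFetchComplete(r: FetchResult): Unit = results.put(r)

  // The task thread is the consumer; it blocks until a result arrives,
  // mirroring ShuffleBlockFetcherIterator.next().
  def next(): FetchResult = results.take()
}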