SortShuffleManager provides only one ShuffleReader implementation: BlockStoreShuffleReader.
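For reference, SortShuffleManager.getReader simply constructs this reader for the requested range of reduce partitions (shown roughly as in Spark 2.x; the exact signature varies slightly between versions):

override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  // Called on executors by reduce tasks; the handle carries the ShuffleDependency
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}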
The first step is to find out where the data to read is located. When a ShuffleMapTask finishes, it triggers the DAGScheduler.handleTaskCompletion callback, and the MapOutputTracker inside the driver records the corresponding (shuffleId: Int, mapId: Int, status: MapStatus). The MapOutputTrackerWorker on the executor side can later query the driver's records via getMapSizesByExecutorId, which after some processing returns Seq[(BlockManagerId, Seq[(BlockId, Long)])]. The BlockManagerId identifies the location of a BlockManager, and Seq[(BlockId, Long)] lists that BlockManager's blocks within the requested partition range; the block sizes are recorded so that memory usage can be controlled.
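As a purely illustrative example (hypothetical executor IDs, hosts, ports and sizes; it only assumes spark-core on the classpath), the returned value has this shape:

import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// Hypothetical locations and sizes for reduce partition 2 of shuffle 0
val mapSizes: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = Seq(
  BlockManagerId("exec-1", "host-a", 7337) -> Seq(
    (ShuffleBlockId(0, 0, 2), 1L << 20),    // 1 MB block written by map task 0
    (ShuffleBlockId(0, 1, 2), 512L << 10)), // 512 KB block written by map task 1
  BlockManagerId("exec-2", "host-b", 7337) -> Seq(
    (ShuffleBlockId(0, 2, 2), 2L << 20)))   // 2 MB block written by map task 2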
The read method then instantiates a ShuffleBlockFetcherIterator, the iterator that fetches the blocks. Local blocks are read directly through BlockManager.getBlockData; remote blocks are fetched over the network with Netty. Fetched blocks are held in memory unless they exceed a threshold (spark.maxRemoteBlockSizeFetchToMem, default Int.MaxValue - 512), and the iterator yields a (BlockId, InputStream) pair per block. If an aggregator is defined, there are two cases depending on whether aggregation already happened on the map side; finally, if keyOrdering is specified, the data is sorted, and an Iterator[Product2[K, C]] is returned.
override def read(): Iterator[Product2[K, C]] = {
  val wrappedStreams = new ShuffleBlockFetcherIterator(
    context,
    // fetches remote blocks
    blockManager.shuffleClient,
    // fetches local blocks
    blockManager,
    // the list of blocks to fetch
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // wraps the streams for compression and encryption
    serializerManager.wrapStream,
    // maximum amount of remote data being fetched at once, 48m by default
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    // maximum number of requests in flight
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
    // maximum number of blocks being fetched per address
    SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
    // maximum size in bytes of a shuffle block that is fetched into memory
    SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
    // whether to detect corruption in fetched blocks
    SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each stream, then concatenate them
  val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update read metrics for every record read
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // Wrap in another iterator that supports task cancellation by checking the
  // interrupt flag in TaskContext
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  // If an aggregator is defined, aggregation is also required on the reduce side.
  // Internally ExternalAppendOnlyMap is used, similar to the external sort machinery
  // but without sorting: values are only combined when their keys are equal.
  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // The map side has already aggregated, so merge the pre-combined values; types are [K, C]
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // No map-side aggregation, so perform the full aggregation here; types are [K, V, C]
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // If an ordering is defined, sort the output with ExternalSorter and return the
  // sorted iterator
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      // Create an ExternalSorter to sort the data.
      val sorter =
        new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }
}
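A minimal sketch (not Spark source) of the two aggregation branches, using the kind of Aggregator functions that reduceByKey(_ + _) would supply; a plain HashMap stands in for ExternalAppendOnlyMap, which can additionally spill to disk:

import scala.collection.mutable

object ReduceSideAggregationSketch {
  val createCombiner: Int => Int = v => v
  val mergeValue: (Int, Int) => Int = _ + _
  val mergeCombiners: (Int, Int) => Int = _ + _

  // mapSideCombine = false: the reduce side sees raw (K, V) records (combineValuesByKey)
  def combineValues(records: Iterator[(String, Int)]): Iterator[(String, Int)] = {
    val map = mutable.HashMap.empty[String, Int]
    records.foreach { case (k, v) =>
      map.update(k, map.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v)))
    }
    map.iterator
  }

  // mapSideCombine = true: the reduce side sees pre-combined (K, C) pairs
  // (combineCombinersByKey), so only mergeCombiners is needed
  def combineCombiners(combiners: Iterator[(String, Int)]): Iterator[(String, Int)] = {
    val map = mutable.HashMap.empty[String, Int]
    combiners.foreach { case (k, c) =>
      map.update(k, map.get(k).map(mergeCombiners(_, c)).getOrElse(c))
    }
    map.iterator
  }

  def main(args: Array[String]): Unit = {
    println(combineValues(Iterator("a" -> 1, "b" -> 2, "a" -> 3)).toMap)    // Map(a -> 4, b -> 2)
    println(combineCombiners(Iterator("a" -> 4, "a" -> 5, "b" -> 2)).toMap) // Map(a -> 9, b -> 2)
  }
}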
ShuffleBlockFetcherIterator
ShuffleBlockFetcherIterator calls its initialize method when it is constructed:
private[this] def initialize(): Unit = {
  // Register a task-completion callback to clean up the fetched data
  context.addTaskCompletionListener(_ => cleanup())
  // Split the blocks into local and remote ones
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests to the request queue in random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  ...
  // Send requests, making sure the amount of data in flight does not exceed maxBytesInFlight
  fetchUpToMaxBytes()
  // Some of the remote requests have already been issued
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
  // Fetch local blocks via IndexShuffleBlockResolver.getBlockData, wrapping each one
  // in a SuccessFetchResult that is appended to the results queue
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
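fetchUpToMaxBytes itself is not shown above; the following is a simplified sketch (not Spark source) of the throttling idea only, omitting the per-address limits and deferred-request handling of the real implementation:

import scala.collection.mutable

class FetchThrottleSketch(maxBytesInFlight: Long, maxReqsInFlight: Int) {
  final case class Request(address: String, bytes: Long)

  private val queue = mutable.Queue.empty[Request]
  private var bytesInFlight = 0L
  private var reqsInFlight = 0

  def enqueue(reqs: Seq[Request]): Unit = queue ++= reqs

  // Dequeue the requests that may be sent right now. At least one request is always
  // allowed when nothing is in flight, even if it alone exceeds maxBytesInFlight.
  def drainSendable(): Seq[Request] = {
    val sendable = mutable.ArrayBuffer.empty[Request]
    while (queue.nonEmpty && reqsInFlight < maxReqsInFlight &&
           (bytesInFlight == 0 || bytesInFlight + queue.head.bytes <= maxBytesInFlight)) {
      val req = queue.dequeue()
      bytesInFlight += req.bytes
      reqsInFlight += 1
      sendable += req
    }
    sendable.toSeq
  }

  // Called when a fetch completes, freeing capacity so more queued requests can go out
  def onComplete(req: Request): Unit = {
    bytesInFlight -= req.bytes
    reqsInFlight -= 1
  }
}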
splitLocalRemoteBlocks wraps the remote blocks into an array of FetchRequests:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Each request targets at most 1/5 of the limit, so data can be fetched from up to
  // 5 nodes in parallel instead of blocking on a single node
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  ...
  // Remote blocks are split into multiple FetchRequests to stay below the limit on
  // the amount of data in flight
  val remoteRequests = new ArrayBuffer[FetchRequest]
  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    // The blockManager is on the same executor, so these are local blocks
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Filter out zero-sized blocks and record the rest in localBlocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      // Count the blocks to fetch
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Only handle non-empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId   // record in remoteBlocks
          numBlocksToFetch += 1     // count the blocks to fetch
          curRequestSize += size    // accumulate the request size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        // Once the accumulated size, or the number of blocks for this address,
        // reaches the limit, wrap them into one FetchRequest
        if (curRequestSize >= targetRequestSize ||
            curBlocks.size >= maxBlocksInFlightPerAddress) {
          remoteRequests += new FetchRequest(address, curBlocks)
          ...
          // Reset the accumulators
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          curRequestSize = 0
        }
      }
      // Wrap the remaining remote blocks into one final FetchRequest
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  ...
  remoteRequests
}
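To make the grouping concrete, here is a standalone sketch (not Spark source, with hypothetical block IDs and sizes) of how one address's blocks end up in FetchRequests; with the default spark.reducer.maxSizeInFlight of 48m, targetRequestSize is roughly 9.6m, so up to five requests can be in flight in parallel:

import scala.collection.mutable.ArrayBuffer

object GroupBlocksSketch {
  def group(blocks: Seq[(String, Long)],
            targetRequestSize: Long,
            maxBlocksPerAddress: Int): Seq[Seq[(String, Long)]] = {
    val requests = ArrayBuffer.empty[Seq[(String, Long)]]
    var cur = Vector.empty[(String, Long)]
    var curSize = 0L
    for ((id, size) <- blocks if size > 0) {
      cur = cur :+ (id -> size)
      curSize += size
      // Same trigger as the real code: close the request once it is large enough
      // or holds too many blocks for this address
      if (curSize >= targetRequestSize || cur.size >= maxBlocksPerAddress) {
        requests += cur
        cur = Vector.empty
        curSize = 0L
      }
    }
    if (cur.nonEmpty) requests += cur   // remaining blocks form the last request
    requests.toSeq
  }

  def main(args: Array[String]): Unit = {
    val target = (48L << 20) / 5   // default 48m in flight => target of about 9.6m per request
    val blocks = Seq("shuffle_0_0_2" -> (4L << 20), "shuffle_0_1_2" -> (7L << 20),
                     "shuffle_0_2_2" -> (3L << 20))
    // The first two blocks reach the target and close the first request;
    // the third block forms the trailing request
    group(blocks, target, maxBlocksPerAddress = Int.MaxValue).foreach(println)
  }
}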