上一篇文章我们分析了Shuffle的write部分,本文中我们来继续分析Shuffle的read部分。
我们来看ShuffledRDD中的compute方法:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
可以看到首先调用的是ShuffleManager的getReader方法来获得ShuffleReader,然后再调用ShuffleReader的read方法来读取map阶段输出的中间数据,而不管是HashShuffleManager还是SortShuffleManager,其getReader方法内部都是实例化了BlockStoreShuffleReader,而BlockStoreShuffleReader正是实现了ShuffleReader接口:
override def getReader[K, C](
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
然后来看BlockStoreShuffleReader的read方法是具体如何工作的,即如何读取Map阶段输出的中间结果的:
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
// 首先实例化ShuffleBlockFetcherIterator
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
首先实例化ShuffleBlockFetcherIterator,实例化的时候传入了几个参数,这里介绍一下几个重要的:
- blockManager.shuffleClient
- mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
- SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
shuffleClient就是用来读取其他executors上的shuffle文件的,有可能是ExternalShuffleClient或者BlockTransferService:
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
securityManager.isSaslEncryptionEnabled())
} else {
blockTransferService
}
而默认使用的是BlockTransferService,因为externalShuffleServiceEnabled默认为false:
private[spark]
val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
接下来的mapOutputTracker.getMapSizesByExecutorId就是获得该reduce task的数据来源(数据的元数据信息),传入的参数是shuffle的Id和partition的起始位置,返回的是Seq[(BlockManagerId, Seq[(BlockId, Long)])],也就是说数据是来自于哪个节点的哪些block的,并且block的数据大小是多少。
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
// 获得Map阶段输出的中间计算结果的元数据信息
val statuses = getStatuses(shuffleId)
// Synchronize on the returned array because, on the driver, it gets mutated in place
// 将获得的元数据信息转化成形如Seq[(BlockManagerId, Seq[(BlockId, Long)])]格式的位置信息,用来读取指定的Map阶段产生的数据
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
}
}
最后的getSizeAsMb获取的是一项配置参数,代表一次从Map端获取的最大的数据量。
获取元数据信息
下面我们来着重分析一下是怎样通过getStatuses来获取元数据信息的:
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
// 根据shuffleId获得MapStatus组成的数组:Array[MapStatus]
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
// 如果没有获取到就进行fetch操作
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
val startTime = System.currentTimeMillis
// 用来保存fetch来的MapStatus
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
// 有可能有别的任务正在进行fetch,所以这里使用synchronized关键字保证同步
// Someone else is fetching it; wait for them to be done
while (fetching.contains(shuffleId)) {
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
// 等待过后继续尝试获取
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
if (fetchedStatuses == null) {
// 如果得到了fetch的权利就进行抓取
// We won the race to fetch the statuses; do so
logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
// This try-finally prevents hangs due to timeouts:
try {
// 调用askTracker方法发送消息,消息的格式为GetMapOutputStatuses(shuffleId)
val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
// 将得到的序列化后的数据进行反序列化
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
// 保存到本地的mapStatuses中
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
s"${System.currentTimeMillis - startTime} ms")
if (fetchedStatuses != null) {
// 最后将抓取到的元数据信息返回
return fetchedStatuses
} else {
logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
}
} else {
// 如果获取到了Array[MapStatus]就直接返回
return statuses
}
}
来看一下用来发送消息的askTracker方法,发送的消息是一个case class:GetMapOutputStatuses(shuffleId)
protected def askTracker[T: ClassTag](message: Any): T = {
try {
trackerEndpoint.askWithRetry[T](message)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
throw new SparkException("Error communicating with MapOutputTracker", e)
}
}
MapOutputTrackerMasterEndpoint在接收到该消息后的处理:
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
// 获得Map阶段的输出数据的序列化后的元数据信息
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
// 序列化后的大小
val serializedSize = mapOutputStatuses.length
// 判断是否超过maxAkkaFrameSize的限制
if (serializedSize > maxAkkaFrameSize) {
val msg = s"Map output statuses were $serializedSize bytes which " +
s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
/* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
* A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
val exception = new SparkException(msg)
logError(msg, exception)
context.sendFailure(exception)
} else {
// 如果没有超过限制就将获得的元数据信息返回
context.reply(mapOutputStatuses)
}
case StopMapOutputTracker =>
logInfo("MapOutputTrackerMasterEndpoint stopped!")
context.reply(true)
stop()
}
接着来看tracker(MapOutputTrackerMaster)的getSerializedMapOutputStatuses方法:
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
if (epoch > cacheEpoch) {
cachedSerializedStatuses.clear()
cacheEpoch = epoch
}
// 判断是否已经有缓存的数据
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
// 如果有的话就直接返回缓存数据
return bytes
case None =>
// 如果没有的话就从mapStatuses中获得
statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "statuses"; let's serialize and return that
// 序列化操作
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
if (epoch == epochGotten) {
// 缓存操作
cachedSerializedStatuses(shuffleId) = bytes
}
}
// 返回序列化后的数据
bytes
}
获得序列化后的数据后会到getStatuses方法,将得到的序列化后的数据进行反序列化,并将反序列化后的数据保存到该Executor(也就是本地)的mapStatuses中,下次再使用的时候就不必重复的进行fetch操作,最后将获得的元数据信息转化成形如Seq[(BlockManagerId, Seq[(BlockId, Long)])]格式的位置信息,用来读取指定的Map阶段产生的数据。
根据得到的元数据信息抓取数据(分为远程和本地)
说完了这些参数后我们回到ShuffleBlockFetcherIterator的实例化过程,ShuffleBlockFetcherIterator实例化的时候会执行一个initialize()方法,用来进行一系列的初始化操作:
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
// 不管最后task是success还是failure,都要进行cleanup操作
context.addTaskCompletionListener(_ => cleanup())
// Split local and remote blocks.
// 将local和remote Blocks分离开,并将remote的返回给remoteRequests
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
// 这里的fetchRequests是一个队列,我们将远程的请求以随机的顺序加入到该队列,然后使用下面的
// fetchUpToMaxBytes方法取出队列中的远程请求,同时对大小进行限制
fetchRequests ++= Utils.randomize(remoteRequests)
// Send out initial requests for blocks, up to our maxBytesInFlight
// 从fetchRequests取出远程请求,并使用sendRequest方法发送请求
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
// 获取本地的Blocks
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
首先来看一下splitLocalRemoteBlocks方法是如何将remote和local的blocks分离开来的:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
// 为了将大小控制在maxBytesInFlight以下,可以增加并行度,即从1个节点增加到5个
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
// Tracks total number of blocks (including zero sized blocks)
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
// 这里就是判断所要获取的是本地的block还是远程的block
if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
// 过滤掉大小为0的blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
} else {
// 这里就是远程的部分,主要就是构建remoteRequests
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
while (iterator.hasNext) {
val (blockId, size) = iterator.next()
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
remoteBlocks += blockId
numBlocksToFetch += 1
curRequestSize += size
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
// 满足大小的限制就构建一个FetchRequest并加入到remoteRequests中
if (curRequestSize >= targetRequestSize) {
// Add this FetchRequest
remoteRequests += new FetchRequest(address, curBlocks)
curBlocks = new ArrayBuffer[(BlockId, Long)]
logDebug(s"Creating fetch request of $curRequestSize at $address")
curRequestSize = 0
}
}
// Add in the final request
// 最后将剩余的blocks构成一个FetchRequest
if (curBlocks.nonEmpty) {
remoteRequests += new FetchRequest(address, curBlocks)
}
}
}
logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
// 最后返回remoteRequests
remoteRequests
}
这里的FetchRequest是一个数据结构,保存了要获取的blocks的位置信息,而remoteRequests就是这些FetchRequest组成的ArrayBuffer:
case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
然后使用fetchUpToMaxBytes()方法来获取远程的blocks信息:
private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
}
可以看出内部就是从上一步获取的remoteRequests中取出一个FetchRequest并使用sendRequest发送该请求:
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
// 1、首先获得要fetch的blocks的信息
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
// 2、然后通过shuffleClient的fetchBlocks方法来获取对应远程节点上的数据
// 默认是通过NettyBlockTransferService的fetchBlocks方法实现的
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
// 3、最后,不管成功还是失败,都将结果保存在results中
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
)
}
请求的结果最终保存在了results中,成功的就是SuccessFetchResult,失败的就是FailureFetchResult,具体是怎么fetchBlocks的就不在此说明,本文最后会给出一张图进行简要的概述,有兴趣的可以继续进行追踪,其实底层是通过NettyBlockTransferService实现的,通过index文件查找到data文件。
接下来看一下使用fetchLocalBlocks()方法来获取本地的blocks信息的过程:
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
while (iter.hasNext) {
val blockId = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
} catch {
case e: Exception =>
// If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
return
}
}
}
这里就相对简单了,进行迭代,如果获取到就将SuccessFetchResult保存到results中,如果没有就将FailureFetchResult保存到results中,至此ShuffleBlockFetcherIterator的实例化及初始化过程结束,接下来我们再回到BlockStoreShuffleReader的read方法中:
override def read(): Iterator[Product2[K, C]] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
// Wrap the streams for compression based on configuration
// 将上面获取的信息进行压缩处理
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}
val ser = Serializer.getSerializer(dep.serializer)
val serializerInstance = ser.newInstance()
// Create a key/value iterator for each stream
// 为每个stream创建一个key/value的iterator
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// 统计系统相关的部分
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
// 判断是否需要进行map端的聚合操作
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// 是否需要进行排序
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabl
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.sto
case None =>
aggregatedIter
}
}
上面代码中aggregator和keyOrdering的部分在分析Shuffle Write的时候已经分析过了,这里我们再简单看一下相关的部分。
aggregator
首先来看aggregator部分:不管是否进行聚合操作,即不管最后执行的是combineCombinersByKey方法还是combineValuesByKey方法,最后都会执行ExternalAppendOnlyMap的insertAll方法:
combineCombinersByKey方法的实现:
def combineCombinersByKey(
iter: Iterator[_ <: Product2[K, C]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
combineValuesByKey方法的实现:
def combineValuesByKey(
iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
combiners.insertAll(iter)
updateMetrics(context, combiners)
combiners.iterator
}
而这个insertAll方法在Shuffle Write的部分已经介绍过了:
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
if (currentMap == null) {
throw new IllegalStateException(
"Cannot insert new elements into a map after calling iterator")
}
// An update function for the map that we reuse across entries to avoid allocating
// a new closure each time
var curEntry: Product2[K, V] = null
val update: (Boolean, C) => C = (hadVal, oldVal) => {
if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
}
while (entries.hasNext) {
curEntry = entries.next()
val estimatedSize = currentMap.estimateSize()
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
if (maybeSpill(currentMap, estimatedSize)) {
currentMap = new SizeTrackingAppendOnlyMap[K, C]
}
currentMap.changeValue(curEntry._1, update)
addElementsRead()
}
}
而这里的next方法会最终调用ShuffleBlockFetcherIterator的next方法:
override def next(): (BlockId, InputStream) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size
case _ =>
}
// Send fetch requests up to maxBytesInFlight
// 这里就是关键的代码,即不断的去抓去数据,直到抓去到所有的数据
fetchUpToMaxBytes()
result match {
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
case SuccessFetchResult(blockId, address, _, buf) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
} catch {
case NonFatal(t) =>
throwFetchFailedException(blockId, address, t)
}
}
}
可以看到该方法中就是一直抓数据,直到所有的数据都抓取到,然后就是执行combiners.iterator:
override def iterator: Iterator[(K, C)] = {
if (currentMap == null) {
throw new IllegalStateException(
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
if (spilledMaps.isEmpty) {
CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
} else {
new ExternalIterator()
}
}
接下来就看一下ExternalIterator的实例化都做了什么工作:
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
// 按照key的hashcode进行排序
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
// 将map中的数据和spillFile中的数据的iterator组合在一起
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
// 不断迭代,直到将所有数据都读出来,最后将所有的数据保存在mergeHeap中
inputStreams.foreach { it =>
val kcPairs = new ArrayBuffer[(K, C)]
readNextHashCode(it, kcPairs)
if (kcPairs.length > 0) {
mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
}
}
最后将所有读取的数据都保存在了mergeHeap中,再来看一下有排序的情况。
keyOrdering
// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
可以看到使用了ExternalSorter的insertAll方法,和Shuffle Write的时候操作是一样的,这里我们就不进行重复说明了,具体的内容可以参考上一篇文章,最后还是用张图来总结一下Shuffle Read的流程:
本文参照的是Spark 1.6.3版本的源码,同时给出Spark 2.1.0版本的连接:
Spark 1.6.3 源码
Spark 2.1.0 源码
本文为原创,欢迎转载,转载请注明出处、作者,谢谢!