Most map tasks and reduce tasks run on different nodes, so at reduce time a task has to pull ShuffleMapTask output from other nodes across the network, which can put heavy pressure on the cluster's network resources. We therefore want the Shuffle process to avoid unnecessary network consumption as far as possible.
This article walks through the reduce-side Shuffle read path from that angle, covering BlockStoreShuffleReader, ShuffleBlockFetcherIterator, NettyBlockTransferService, and TransportClient.
Overview
BlockStoreShuffleReader reads the data in the requested range [startPartition, endPartition) of a shuffle by fetching the blocks from other nodes' block stores over the network.
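For context, here is roughly how this reader is obtained and driven at runtime (a sketch based on ShuffledRDD.compute in Spark 2.x, slightly abridged): the shuffle manager hands back a reader for the reduce partition, and read() yields that partition's records.
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  // Ask the shuffle manager for a reader covering just this reduce partition, then pull
  // the (possibly aggregated and sorted) records through read().
  SparkEnv.get.shuffleManager
    .getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}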
Member variables
private[spark] class BlockStoreShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C], //shuffleHandle
startPartition: Int, // the requested startPartition
endPartition: Int, // the requested endPartition
context: TaskContext,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
private val dep = handle.dependency //ShuffleDependency
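For reference, in the sort-based shuffle it is SortShuffleManager.getReader that constructs this class (Spark 2.x, abridged), passing along the handle and the requested partition range:
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  // Every shuffle in the sort-based manager is read through BlockStoreShuffleReader.
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}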
The read() method
An important tuning parameter for the reduce-side fetch is the following:
spark.reducer.maxSizeInFlight
Default: 48m
Meaning: this parameter sets the buffer size of each shuffle read task, and that buffer determines how much data can be pulled in a single fetch.
Tuning advice: if the job has ample memory available, consider increasing this value (for example to 96m) to reduce the number of fetches and therefore the number of network transfers, which can improve performance; in practice, tuning this parameter well has been observed to give a 1%-5% gain. If memory is tight and reduce-side OOMs occur, reduce the amount of data pulled per fetch instead, for example by setting spark.reducer.maxSizeInFlight to 24m.
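As a quick illustration, the parameter can be set on the SparkConf (or via --conf on spark-submit); the values below are examples, not universal recommendations.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Plenty of executor memory available: fetch more per round, fewer network round trips.
  .set("spark.reducer.maxSizeInFlight", "96m")
// If reduce tasks OOM while fetching, go the other way instead:
// conf.set("spark.reducer.maxSizeInFlight", "24m")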
/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
// Create a ShuffleBlockFetcherIterator to fetch the shuffle blocks
val wrappedStreams = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
serializerManager.wrapStream,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
// Buffer size for the shuffle read task (spark.reducer.maxSizeInFlight), in bytes
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
val aggregatedIter: Iterator[Product2[K, C]] =
// The ShuffleDependency defines an aggregator
if (dep.aggregator.isDefined) {
// Map-side combine has already been performed
if (dep.mapSideCombine) {
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
// Map-side combine was not performed
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure its compatible w/ this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
// No aggregator is defined
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
// Sort the output if there is a sort ordering defined.
val resultIter = dep.keyOrdering match {
// The ShuffleDependency defines a key ordering
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
// Use completion callback to stop sorter if task was finished/cancelled.
context.addTaskCompletionListener[Unit](_ => {
sorter.stop()
})
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
resultIter match {
case _: InterruptibleIterator[Product2[K, C]] => resultIter
case _ =>
// Use another interruptible iterator here to support task cancellation as aggregator
// or(and) sorter may have consumed previous interruptible iterator.
new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
}
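To make the branches above concrete, here is a rough mapping from common shuffle operations to the path taken in read(); `sc` is assumed to be an existing SparkContext.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

pairs.reduceByKey(_ + _) // aggregator defined, mapSideCombine = true  -> combineCombinersByKey
pairs.groupByKey()       // aggregator defined, mapSideCombine = false -> combineValuesByKey
pairs.sortByKey()        // no aggregator, keyOrdering defined         -> ExternalSorter path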
Overview
ShuffleBlockFetcherIterator fetches multiple kinds of blocks. For local blocks it reads the data from the local BlockManager; for remote blocks it fetches them over the network through a BlockTransferService.
It yields an iterator of (BlockId, InputStream) tuples, so the caller can process the blocks in a pipelined fashion as they arrive.
It throttles remote fetches so that the bytes being fetched at any given time never exceed maxBytesInFlight, preventing the fetch side from using too much memory. maxBytesInFlight corresponds to the tuning parameter spark.reducer.maxSizeInFlight discussed above.
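To keep several remote nodes busy in parallel without exceeding that budget, the iterator splits the remote blocks into bounded requests; in the Spark 2.x source, splitLocalRemoteBlocks() caps each request at roughly one fifth of maxBytesInFlight.
// Simplified from splitLocalRemoteBlocks(): blocks from one executor are grouped into
// FetchRequests of at most ~maxBytesInFlight / 5 bytes, so that up to five remote nodes
// can be fetched from concurrently while staying under the overall in-flight limit.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)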
Member variables
/**
 * @param context [[TaskContext]], used for metrics update
* @param shuffleClient [[ShuffleClient]] for fetching remote blocks
* @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage. Note that zero-sized blocks are
* already excluded, which happened in
* [[MapOutputTracker.convertMapStatuses]].
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
* for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
* the number of bytes in flight is limited to maxBytesInFlight.
*/
private[this] val fetchRequests = new Queue[FetchRequest]
Constructor
When a ShuffleBlockFetcherIterator is constructed, it calls the initialize() method:
initialize()
private[this] def initialize(): Unit = {
// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener[Unit](_ => cleanup())
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
"expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
// Send out initial requests for blocks, up to our maxBytesInFlight
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
The fetchUpToMaxBytes() method
private def fetchUpToMaxBytes(): Unit = {
// Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
// immediately, defer the request until the next time it can be processed.
// Process any outstanding deferred fetch requests if possible.
if (deferredFetchRequests.nonEmpty) {
for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
while (isRemoteBlockFetchable(defReqQueue) &&
!isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
val request = defReqQueue.dequeue()
logDebug(s"Processing deferred fetch request for $remoteAddress with "
+ s"${request.blocks.length} blocks")
send(remoteAddress, request)
if (defReqQueue.isEmpty) {
deferredFetchRequests -= remoteAddress
}
}
}
}
// Process any regular fetch requests if possible.
// Keep dequeuing requests as long as the one at the head of the queue fits within the in-flight limits
while (isRemoteBlockFetchable(fetchRequests)) {
// Dequeue the next FetchRequest
val request = fetchRequests.dequeue()
// The remote address this FetchRequest targets
val remoteAddress = request.address
if (isRemoteAddressMaxedOut(remoteAddress, request)) {
logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks")
val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
deferredFetchRequests(remoteAddress) = defReqQueue
} else {
send(remoteAddress, request)
}
}
// Send a fetch request and track the number of blocks now in flight for its address
def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
sendRequest(request)
numBlocksInFlightPerAddress(remoteAddress) =
numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
}
// Whether the request at the head of the queue can be fetched now, i.e. whether sending it would stay within the throttling limits
def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
fetchReqQueue.nonEmpty &&
(bytesInFlight == 0 ||
(reqsInFlight + 1 <= maxReqsInFlight &&
// The bytes already in flight plus the size of the request at the head of the queue must not exceed maxBytesInFlight
bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
}
// Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a
// given remote address.
def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = {
numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size >
maxBlocksInFlightPerAddress
}
}
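A worked example of the isRemoteBlockFetchable check (illustrative numbers; the reqsInFlight condition is omitted for brevity):
// 40 MB already in flight, the request at the head of the queue is 12 MB, and
// maxBytesInFlight is 48 MB: since 40 + 12 > 48, the request stays queued until
// completed fetches release some in-flight bytes.
val maxBytesInFlight = 48L * 1024 * 1024
val bytesInFlight    = 40L * 1024 * 1024
val frontRequestSize = 12L * 1024 * 1024
val fetchable = bytesInFlight == 0 ||
  bytesInFlight + frontRequestSize <= maxBytesInFlight   // false here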
The sendRequest() method
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
bytesInFlight += req.size
reqsInFlight += 1
// so we can look up the size of each blockID
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
val blockIds = req.blocks.map(_._1.toString)
val address = req.address
val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
// i.e. cleanup() has not been called yet.
ShuffleBlockFetcherIterator.this.synchronized {
if (!isZombie) {
// Increment the ref count because we need to pass this to a different thread.
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
}
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
}
// Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
// already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch
// the data and write it to file directly.
if (req.size > maxReqSizeShuffleToMem) {
// Asynchronously fetch the blocks from the remote node; passing `this` as the DownloadFileManager makes the fetched data stream to disk files
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, this)
} else {
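// The request is small enough: fetch the blocks straight into memory (no DownloadFileManager)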
shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
blockFetchingListener, null)
}
}
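The threshold used in the branch above comes from spark.maxRemoteBlockSizeFetchToMem (config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM). As an illustrative example, lowering it makes very large remote blocks stream to disk instead of being buffered in memory (the value below is only an example):
import org.apache.spark.SparkConf

// Remote blocks larger than 200m are written to temporary files via the DownloadFileManager
// instead of being held in memory, trading extra disk I/O for a smaller fetch-side footprint.
val conf = new SparkConf().set("spark.maxRemoteBlockSizeFetchToMem", "200m")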
Class hierarchy of NettyBlockTransferService
NettyBlockTransferService extends BlockTransferService, which in turn extends ShuffleClient; it is the Netty-based implementation that the BlockManager uses as its shuffleClient when the external shuffle service is not enabled.
The fetchBlocks() method
override def fetchBlocks(
host: String,
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener,
tempFileManager: DownloadFileManager): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
// Create a TransportClient connected to the target host and port
val client = clientFactory.createClient(host, port)
// Create a OneForOneBlockFetcher and start fetching the blocks
new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
transportConf, tempFileManager).start()
}
}
val maxRetries = transportConf.maxIORetries()
if (maxRetries > 0) {
// Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
// a bug in this code. We should remove the if statement once we're sure of the stability.
new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
} else {
blockFetchStarter.createAndStart(blockIds, listener)
}
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
blockIds.foreach(listener.onBlockFetchFailure(_, e))
}
}
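The retry behaviour above is governed by the shuffle I/O settings read through transportConf: maxIORetries() corresponds to spark.shuffle.io.maxRetries (default 3), and the wait between attempts to spark.shuffle.io.retryWait (default 5s). An illustrative tuning sketch for flaky networks:
import org.apache.spark.SparkConf

// Example values only: more retries and a longer wait make fetches more tolerant of
// transient network failures, at the cost of slower failure detection.
val conf = new SparkConf()
  .set("spark.shuffle.io.maxRetries", "6")
  .set("spark.shuffle.io.retryWait", "10s")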
Overview
TransportClient fetches consecutive chunks of a given stream. It is designed to transfer large amounts of data efficiently by breaking them into chunks ranging from a few hundred KB to a few MB.
TransportClient#sendRPC() is used to set up the transfer stream between the client and the server.
TransportClient#fetchChunk() is used to pull individual chunks from that stream.
A typical workflow looks like this:
client.sendRPC(new OpenFile("/foo")) returns StreamId = 100
client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
...
client.sendRPC(new CloseStream(100))
A TransportClient is created through a TransportClientFactory. One TransportClient may serve multiple streams, but a given stream must be tied to a single TransportClient so that responses do not arrive out of order.
TransportClient is responsible for sending requests to the server, while TransportResponseHandler processes the responses coming back from the server.
// The underlying Netty Channel
private final Channel channel;
private final TransportResponseHandler handler;
@Nullable private String clientId;
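As a minimal usage sketch (hypothetical: the `client` and `streamId` arguments are assumed to exist already, e.g. obtained from a TransportClientFactory and a stream-setup RPC), fetching a chunk hands the result to a ChunkReceivedCallback:
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.client.{ChunkReceivedCallback, TransportClient}

// Hypothetical helper: pull the first chunk of an already negotiated stream and report
// the outcome through the callback interface.
def fetchFirstChunk(client: TransportClient, streamId: Long): Unit = {
  client.fetchChunk(streamId, 0, new ChunkReceivedCallback {
    override def onSuccess(chunkIndex: Int, buffer: ManagedBuffer): Unit =
      println(s"chunk $chunkIndex fetched, ${buffer.size()} bytes")
    override def onFailure(chunkIndex: Int, e: Throwable): Unit =
      println(s"chunk $chunkIndex failed: ${e.getMessage}")
  })
}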
References:
Spark性能调优之Shuffle调优
Executor是如何fetch shuffle的数据文件