shuffle是作业执行过程中的一个重要阶段,对作业性能有很大影响,不管是对hadoop还是spark,shuffle都是一个核心环节,spark的shuffle和hadoop的shuffle的原理大致相同,shuffle发生在ShuffleMapTask中,在一个task处理partition数据时,需要对外输出作为下个stage的数据源,这个输出可能是不落盘的,但如果数据量很大,导致内存放不下,这时候就需要shuffle了,shuffle包含spill、和merge两个阶段,最终会输出为一个shuffle文件(spark1.3.1)和一个包含个partition索引的index文件。如果包含map端的aggregate和order则会对partition和KV数据进行排序,负责会按partition id进行排序,这样在reduce阶段,通过索引和该shuffle文件就能找到reduce需要的数据。
上一篇提到task的启动在runTask函数中,我们从这里开始看整个shuffle的过程
override def runTask(context: TaskContext): MapStatus = {
// 反序列化获得rdd和其依赖
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
metrics = Some(context.taskMetrics)
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
//获得SortShuffleWriter实例
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
//中间数据的操作再这里写出,有先写内存,不够时写磁盘
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
return writer.stop(success = true).get
} catch {
.....
}
}
写出时用到了ExternalSorter,根据是否启用map端合并创建不用类型的实例,如果指定了aggregator和keyOrdering则会按顺序输出
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
sorter.insertAll(records)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
sorter = new ExternalSorter[K, V, V](
None, Some(dep.partitioner), None, dep.serializer)
sorter.insertAll(records)
}
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val outputFile = shuffleBlockManager.getDataFile(dep.shuffleId, mapId)
val blockId = shuffleBlockManager.consolidateId(dep.shuffleId, mapId)//合并spill文件
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)//写索引文件
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
下面是遍历一个partition中的数据并做输出
def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
// TODO: stop combining if we find that the reduction factor isn't high
val shouldCombine = aggregator.isDefined
//如果执行combine则会对map数据进行merge操作
if (shouldCombine) {
// Combine values in-memory first using our AppendOnlyMap
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
//merge函数
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
//遍历partition数据
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
//对外输出,可能产生spill,至于何时spill,看下面分析
maybeSpillCollection(usingMap = true)
}
} else if (bypassMergeSort) {
// SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
if (records.hasNext) {
spillToPartitionFiles(records.map { kv =>
((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
})
}
} else {
// Stick values into our buffer
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
}
这里包含了spill的触发条件,最关键条件就是内存不够,当当前内粗不足时会首先申请内存,如果申请后仍然不够则需要spill,那么内存够的情况下,数据放哪里呢?继续看后面分析
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
myMemoryThreshold += granted
if (myMemoryThreshold <= currentMemory) {
// We were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold); spill the current collection
_spillCount += 1
logSpillage(currentMemory)
spill(collection)
_elementsRead = 0
// Keep track of spills, and release memory
_memoryBytesSpilled += currentMemory
releaseMemoryForThisThread()
return true
}
}
false
}
如果内存够用,则把数据放入一个SizeTrackingAppendOnlyMap数据结构中,这是一个只能append的kv map集合,由spark本身实现,注意并不是java的collection,有兴趣的朋友可以翻下代码看看,内部存储是一个Array
返回头看看spill写文件的过程,上面讲到了spill的触发条件,那么条件满足后文件怎么写的,写到哪里呢,看下面分析
/**
* Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
*/
override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
if (bypassMergeSort) {
spillToPartitionFiles(collection)
} else {
spillToMergeableFile(collection)
}
}
下面操作会真正的写出数据,并且是按key有序的输出,输出形式是 key1,value1,key2,value2.....,真正写文件的时候还是比较简单的,直接用BlockObjectWriter把一个个对象写到文件
/**
* Spill our in-memory collection to a sorted file that we can merge later (normal code path).
* We add this file into spilledFiles to find it later.
*
* Alternatively, if bypassMergeSort is true, we spill to separate files for each partition.
* See spillToPartitionedFiles() for that code path.
*
* @param collection whichever collection we're using (map or buffer)
*/
private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
assert(!bypassMergeSort)
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var objectsWritten = 0 // Objects written since the last flush
// List of batch sizes (bytes) in the order they are written to disk
val batchSizes = new ArrayBuffer[Long]
// How many elements we have in each partition
val elementsPerPartition = new Array[Long](numPartitions)
// Flush the disk writer's contents to disk, and update relevant variables.
// The writer is closed at the end of this process, and cannot be reused.
def flush() = {
val w = writer
writer = null
w.commitAndClose()
_diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
batchSizes.append(curWriteMetrics.shuffleBytesWritten)
objectsWritten = 0
}
var success = false
try {
val it = collection.destructiveSortedIterator(partitionKeyComparator)
while (it.hasNext) {
val elem = it.next()
val partitionId = elem._1._1
val key = elem._1._2
val value = elem._2
writer.write(key)
writer.write(value)
elementsPerPartition(partitionId) += 1
objectsWritten += 1
if (objectsWritten == serializerBatchSize) {
flush()
curWriteMetrics = new ShuffleWriteMetrics()
writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
}
}
if (objectsWritten > 0) {
flush()
} else if (writer != null) {
val w = writer
writer = null
w.revertPartialWritesAndClose()
}
success = true
} finally {
if (!success) {
// This code path only happens if an exception was thrown above before we set success;
// close our stuff and let the exception be thrown further
if (writer != null) {
writer.revertPartialWritesAndClose()
}
if (file.exists()) {
file.delete()
}
}
}
//把spillwenjian加入集合中,以便后期合并
spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
}
最后spill文件会被合并成一个shuffle文件,spill文件的名字类似于temp_shuffle_&^%%$%之类的名字,合并为一个shuffle文件后,其他文件会被删除,上面就是整个shuffle过程,中间还有很多细节,感兴趣的朋友可以深入分析下。