spark SortShuffleWriter的实现

SortShuffleWriter是spark中一种shuffle的方式,一下是其write()方法。

override def write(records: Iterator[Product2[K, V]]): Unit = {
  sorter = if (dep.mapSideCombine) {
    new ExternalSorter[K, V, C](
      context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
  } else {
    // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
    // care whether the keys get sorted in each partition; that will be done on the reduce side
    // if the operation being run is sortByKey.
    new ExternalSorter[K, V, V](
      context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
  }
  sorter.insertAll(records)

  // Don't bother including the time to open the merged output file in the shuffle write time,
  // because it just opens a single file, so is typically too fast to measure accurately
  // (see SPARK-3570).
  val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
  val tmp = Utils.tempFileWith(output)
  try {
    val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
    val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
    shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
      writeMetrics.recordsWritten)
  } finally {
    if (tmp.exists() && !tmp.delete()) {
      logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
    }
  }
}

首先,根据进来的数据是否需要聚合来选择不同的ExternalSorter构造方式,之后将所有数据通过ExternalSorter的insertAll()方法进行排序。

def insertAll(records: Iterator[Product2[K, V]]): Unit = {
  // TODO: stop combining if we find that the reduction factor isn't high
  val shouldCombine = aggregator.isDefined

  if (shouldCombine) {
    // Combine values in-memory first using our AppendOnlyMap
    val mergeValue = aggregator.get.mergeValue
    val createCombiner = aggregator.get.createCombiner
    var kv: Product2[K, V] = null
    val update = (hadValue: Boolean, oldValue: C) => {
      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
    }
    while (records.hasNext) {
      addElementsRead()
      kv = records.next()
      map.changeValue((getPartition(kv._1), kv._1), update)
      maybeSpillCollection(usingMap = true)
    }
  } else {
    // Stick values into our buffer
    while (records.hasNext) {
      addElementsRead()
      val kv = records.next()
      buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
      maybeSpillCollection(usingMap = false)
    }
  }
}

在insertAll()方法中,会根据之前是否需要聚合选择不同的容器来缓存数据,其中需要聚合选择map否则选择buffer。

Map和buffer的实现如下:

@volatile private var map = new PartitionedAppendOnlyMap[K, C]
@volatile private var buffer = new PartitionedPairBuffer[K, C]

看到这两者的区别,首先是PartitionedPairBuffer的insert()方法。

def insert(partition: Int, key: K, value: V): Unit = {
  if (curSize == capacity) {
    growArray()
  }
  data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
  data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
  curSize += 1
  afterUpdate()
}

其底层的存储结构是一个Array,其中每条记录会占用两个存储空间,第一个则是(分区号,key)的元组,而第二个则是具体的value,因为不需要考虑聚合操作,所以新纪录的加入直接加入到了数组的末端,需要扩容则扩容。

下面是PartitionedAppendOnlyMap()的insert()方法。

def insert(partition: Int, key: K, value: V): Unit = {
  update((partition, key), value)
}
def update(key: K, value: V): Unit = {
  assert(!destroyed, destructionMessage)
  val k = key.asInstanceOf[AnyRef]
  if (k.eq(null)) {
    if (!haveNullValue) {
      incrementSize()
    }
    nullValue = value
    haveNullValue = true
    return
  }
  var pos = rehash(key.hashCode) & mask
  var i = 1
  while (true) {
    val curKey = data(2 * pos)
    if (curKey.eq(null)) {
      data(2 * pos) = k
      data(2 * pos + 1) = value.asInstanceOf[AnyRef]
      incrementSize()  // Since we added a new key
      return
    } else if (k.eq(curKey) || k.equals(curKey)) {
      data(2 * pos + 1) = value.asInstanceOf[AnyRef]
      return
    } else {
      val delta = i
      pos = (pos + delta) & mask
      i += 1
    }
  }
}

虽然相比前者也是Array结构,但在key的分配方式上使用了hash的方式,并且一旦发生碰撞,将会通过线性探测法的方式解决冲突,适合需要聚合操作的场景。

 

回到insertAll()方法,当完成数据在map或者buffer上的插入后,将会通过maybeSpillCollection()方法判断是否需要创建临时文件。

private def maybeSpillCollection(usingMap: Boolean): Unit = {
  var estimatedSize = 0L
  if (usingMap) {
    estimatedSize = map.estimateSize()
    if (maybeSpill(map, estimatedSize)) {
      map = new PartitionedAppendOnlyMap[K, C]
    }
  } else {
    estimatedSize = buffer.estimateSize()
    if (maybeSpill(buffer, estimatedSize)) {
      buffer = new PartitionedPairBuffer[K, C]
    }
  }

  if (estimatedSize > _peakMemoryUsedBytes) {
    _peakMemoryUsedBytes = estimatedSize
  }
}
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
  var shouldSpill = false
  if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
    // Claim up to double our current memory from the shuffle memory pool
    val amountToRequest = 2 * currentMemory - myMemoryThreshold
    val granted = acquireMemory(amountToRequest)
    myMemoryThreshold += granted
    // If we were granted too little memory to grow further (either tryToAcquire returned 0,
    // or we already had more memory than myMemoryThreshold), spill the current collection
    shouldSpill = currentMemory >= myMemoryThreshold
  }
  shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
  // Actually spill
  if (shouldSpill) {
    _spillCount += 1
    logSpillage(currentMemory)
    spill(collection)
    _elementsRead = 0
    _memoryBytesSpilled += currentMemory
    releaseMemory()
  }
  shouldSpill
}

在maybeSpill()方法中,如果当前扩大内存不足当前的二倍,或者当前内存已经超过规定的需要spill的内存,则会开始spill凑走。具体的spill操作在spill()方法中。

override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
  val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
  val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
  spills += spillFile
}

在此处的spill中会对数据进行相应的排序,并写入到相应的临时文件中。在destructiveSortedWritablePartitionedIterator()方法中。

def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
  : WritablePartitionedIterator = {
  val it = partitionedDestructiveSortedIterator(keyComparator)
  new WritablePartitionedIterator {
    private[this] var cur = if (it.hasNext) it.next() else null

    def writeNext(writer: DiskBlockObjectWriter): Unit = {
      writer.write(cur._1._2, cur._2)
      cur = if (it.hasNext) it.next() else null
    }

    def hasNext(): Boolean = cur != null

    def nextPartition(): Int = cur._1._1
  }
}

此处会通过partitionedDestructiveSortedIterator()方法得到已经排序完毕的迭代器。

具体的排序逻辑如下:

def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
  override def compare(a: (Int, K), b: (Int, K)): Int = {
    a._1 - b._1
  }
}

/**
 * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
 */
def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
  new Comparator[(Int, K)] {
    override def compare(a: (Int, K), b: (Int, K)): Int = {
      val partitionDiff = a._1 - b._1
      if (partitionDiff != 0) {
        partitionDiff
      } else {
        keyComparator.compare(a._2, b._2)
      }
    }
  }
}

实现还是相对简单,先比较分区号,在比较key的大小,来确定排序的顺序。

 

在得到排好序的迭代器之后一次写入到临时文件中,释放掉当前内存,重新刷新缓存存储数据,完成了spill的目的。

 

 

回到SortShuffleSorter的write()方法,当通过insertAll()将数据写入内存和临时文件之后,需要将其merge成一个大的文件。

val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
  val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
  val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
  shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
  mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths,
    writeMetrics.recordsWritten)
} finally {
  if (tmp.exists() && !tmp.delete()) {
    logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
  }
}

Merge的核心操作在sorter的writePartitionedFile()方法中。

在writePartitionedFile()方法中,在之前的操作中如果数据量并没有到达spill的地步,那么所有需要merge的数据都存在当前的内存中,就只需要重复类似之前spill的操作把当前内存的数据写到最后的文件中,但是如果之前已经存在spill操作,那么就需要把临时文件的数据和当前内存中的数据一起merge到最后的文件中,代码如下:

for ((id, elements) <- this.partitionedIterator) {
  if (elements.hasNext) {
    for (elem <- elements) {
      writer.write(elem._1, elem._2)
    }
    val segment = writer.commitAndGet()
    lengths(id) = segment.length
  }
}

此处可以看到,这里会得到一个分区迭代器,根据分区迭代器的顺序依次将各个分区中的数据依次顺序写入到结果文件中。

其核心逻辑在merge()方法中。

private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
    : Iterator[(Int, Iterator[Product2[K, C]])] = {
  val readers = spills.map(new SpillReader(_))
  val inMemBuffered = inMemory.buffered
  (0 until numPartitions).iterator.map { p =>
    val inMemIterator = new IteratorForPartition(p, inMemBuffered)
    val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
    if (aggregator.isDefined) {
      // Perform partial aggregation across partitions
      (p, mergeWithAggregation(
        iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
    } else if (ordering.isDefined) {
      // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
      // sort the elements without trying to merge them
      (p, mergeSort(iterators, ordering.get))
    } else {
      (p, iterators.iterator.flatten)
    }
  }
}

在这里,会根据数据的分区数量构造相应数量的分区迭代器。

分区迭代器会依次在spill文件中读取所有当前分区的数据,如果定义了聚合或者排序,将会在这里进行操作,否则直接返回。

这里的排序值得一看。

private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
    : Iterator[Product2[K, C]] =
{
  val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
  type Iter = BufferedIterator[Product2[K, C]]
  val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
    // Use the reverse order because PriorityQueue dequeues the max
    override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1)
  })
  heap.enqueue(bufferedIters: _*)  // Will contain only the iterators with hasNext = true
  new Iterator[Product2[K, C]] {
    override def hasNext: Boolean = !heap.isEmpty

    override def next(): Product2[K, C] = {
      if (!hasNext) {
        throw new NoSuchElementException
      }
      val firstBuf = heap.dequeue()
      val firstPair = firstBuf.next()
      if (firstBuf.hasNext) {
        heap.enqueue(firstBuf)
      }
      firstPair
    }
  }
}

这里各个spill的同一分区的数据已经能进行排序,所以不断获取各个spill的第一个数据就可完成排序。

最后得到所有分区上已经排序好的迭代器,一次顺序写入到最后的文件中。

回到write()方法通过writeIndexFileAndComiit()方法。

索引文件的实现很简单,讲各个分区在文件中的起始偏移量写入索引文件即可。

Utils.tryWithSafeFinally {
  // We take in lengths of each block, need to convert it to offsets.
  var offset = 0L
  out.writeLong(offset)
  for (length <- lengths) {
    offset += length
    out.writeLong(offset)
  }
}

 

你可能感兴趣的:(spark)