Spark是如何实现排序的?

Abstract

昨天丢人现眼的写QuickSort用了40分钟, 当时感觉整个人都不好了.
(╯°□°)╯︵┻━┻ 看孩子一天睡4小时大脑不转哇 d(・`ω´・d*)

External Sort的标准做法是一个QuickSort后边跟一个n-way MergeSort, 理论上的复杂度也是nlogn.

但是由于存在文件IO, 所以实际速度要慢于带内排序很多.

在分布式环境里, 这个问题进一步复杂化, 每台机器持有的是数据的一部分, 如果需要执行经典的外排序, 则需要不断的把所有节点的数据向一个中心节点进行shuffle. 磁盘IO进一步衰退为网络IO.

更进一步分析这个问题, 可以在一开始处理数据的时候, 把数据分为多份. RANGE(0, 1e3)的第一台机器, RANGE(1e3, 2e3)第二台.....

通过HASH的方法, 让每台机器天然有序, 继而每台机器内部跑外排.

这样就需要对数据的分布有一定的了解, 通过抽样来理解数据的整体排布方式, 然后决定每台机器处理的数据范围是一个大的思路.

在下面这个网站, 可以找到排序算法的效率排行榜. 你会发现BAT三家都在打这个榜...

http://sortbenchmark.org/

TeraSort 原理

Spark是如何实现排序的?_第1张图片
TeraSort流程图

TeraSort的核心在于第一步的map(), 这一步 任何一台机器上的Partition i 里的所有对象一定小于 任何一台机器上的Partition i+1,也就是保证了Parition之间的有序性. 继而在reduce阶段, 可以保证shuffle后每个任务收集到的数据的有序性.

这里可以非常直观的看到两个难点

  1. 如何确定每个Partition的范围, 它负责的Range(X, Y)里的X和Y是多少
  2. 如何快速的把一个值映射到它对应的Partition里, 这里需要考虑待排序的是任何实现了Comparable接口的对象. 不一定是个数.

抽象的解决思路是

  1. 对数据进行抽样, 根据抽样结果来构筑每个Partition应该承载什么范围内的数据
  2. 通过Trie Tree来构筑索引, 当一个String或者Long或者任何能够被转义成Char Sequence的对象进来后, 利用Trie来找到它对应的那个Partition.实现中, 对字典树有微弱高的改造, 类似下图中daz会被分到Parition3, 在最后一层中z > b
    Spark是如何实现排序的?_第2张图片
    image.png

Spark源码

执行结构

2.RDD, 执行入口

/spark/core/src/main/scala/org/apache/spark/rdd/RDD.scala

  /**
   * 按照输入的key function, 对这个RDD进行排序
   */
  def sortBy[K](
      // f 执行在key上的 funtion, 返回K型对象, 这里K需要时可以compare的
      f: (T) => K,

      // 默认是正序
      ascending: Boolean = true,

      // 维持当前的partition数量, 这个对抽样后到底怎么分区有影响
      numPartitions: Int = this.partitions.length)

      // 可以看到这里对K的类型进行了隐式转换
      // 保证它是scala.math.Ordering接口兼容的, 以便能够排序
      (implicit ord: Ordering[K], ctag: ClassTag[K]): RDD[T] = withScope {
    this.keyBy[K](f)

         // 对所有的Key进行处理后, 就可以运行排序了, 排序方法在下面
        .sortByKey(ascending, numPartitions)
        .values
  }

3. OrderedRDDFunctions 调用方法

/spark/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

/**
 * 输入的KEY是能够支持Scala Math排序的
 * 对于没有实现对应接口的.用户可以自己实现, 或者对已有的对象覆盖自己的逻辑
 */
class OrderedRDDFunctions[K : Ordering : ClassTag,
                          V: ClassTag,
                          P <: Product2[K, V] : ClassTag] @DeveloperApi() (
    self: RDD[P])
  extends Logging with Serializable {
  private val ordering = implicitly[Ordering[K]]

  /**
   * 实现了对每个partition执行sort, 由于partition相互之间是有序的
   * 调用`collect`或者`save`可以获得全局有序的对象.
   */
  // TODO: this currently doesn't work on P other than Tuple2!
  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
      : RDD[(K, V)] = self.withScope
  {
    // 初始化一个RangePartitioner对象, 这个对象负责管理告诉RDD应该如何分配数据
    // 以及每个Range应该是多少
    val part = new RangePartitioner(numPartitions, self, ascending)
    // 对数据进行分片, 然后每片内部再进行排序
    new ShuffledRDD[K, V, V](self, part)
      .setKeyOrdering(if (ascending) ordering else ordering.reverse)
  }

  /**
   * 使用传入的Partition分区方法来切割数据, 然后每个Partition内部再排序
   * 这个方法在特定条件下可以用customer的方法来提升TeraSort的性能
   * 相关论文很多, 核心思想主要是提升locality, 或者针对已经部分有效的数据,直接增加分配的有效性.
   */
  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = ...
  }

  /**
   *  由于RangePartition中有每个partition的最大值和最小值, 也就是Range的信息
   *  所以给定一个lower和一个uppder值, 可以快速的用getPartition方法定位到最小和最大的PartitionIndex是多少. 从而实现快速的过滤.
   */
  def filterByRange(lower: K, upper: K): RDD[P]  = ...
}

4. Partitioner 实现有序分片

Partitioner在实现中有 HashPartitioner RangePartitioner, 后者直接对应需要内部有序的各种情况.

/spark/core/src/main/scala/org/apache/spark/Partitioner.scala

/**
 * 通过抽样, 把对象映射到RANGE范围大致相同的分片里. 
 * 分片多少和输入的分片数, 以及采样数有关
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    // 期望的分片
    partitions: Int,

    // 这里对RDD进行了约束
    rdd: RDD[_ <: Product2[K, V]],

    // 默认正序
    private var ascending: Boolean = true,
    
    // 默认采样20
    val samplePointsPerPartitionHint: Int = 20)
  extends Partitioner {

  // 构造函数 
  def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = ...

  // 条件检查
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")
  require(samplePointsPerPartitionHint > 0,
    s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint")

  private var ordering = implicitly[Ordering[K]]

  // 计算每个Partition应该存储的Range
  private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // 确定最大取样数, 封顶1M
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)

      // 假设Dependecy RDD中各个partition里的items数量是大致相同的, 采用常规的采样
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition)
      if (numItems == 0L) {
        Array.empty
      } else {
        // 如果分片数据倾斜的太严重, 就需要对这个分片做重新采样
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }

  def numPartitions: Int = rangeBounds.length + 1

  // 利用二分查找用来快速的定位分片
  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  def getPartition(key: Any): Int = 
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // 直接顺序查找
    } else {
      // 利用二分查找寻找partition, 在实现中需要考虑几个细节:  小于第一个分片的range, 大于最后一个分片的range, 以及倒序排列.
  }
  
}

你可能感兴趣的:(Spark是如何实现排序的?)