关于spark中rdd.sortByKey的简单分析

基本介绍

最近在做一个文件archive的事情,其中需要对目录文件下的索引排序,最开始是用的内部归并排序,这在目录里面文件还比较少的时候,没什么大问题;但是发现有一个目录下的文件数太多,无法正常排序,因为那样会OOM;所以就打算先通过rdd里面的sortByKey来先将文件分段排序然后再整合到目标文件中。所以我写下了这么一段代码:

sc.parallelize(data)
.flatMap(dealFunction)
.sortByKey(_._1)
.someOtherOperations

sortByKey 主要用途就是将目标tuples根据key值在不同的range段排序:比如有原始数据((5, 5), (4, 4), (3, 3), (3, 3), (2, 2), (4, 4), (5, 5), (6, 6))。我们希望在两个range partition中排序,那么最终的结果为(((2, 2), (3,3), (3,3)), ((4, 4),(4, 4),(5, 5), (5, 5), (6, 6)))。那么如何确定各个段的边界呢?那么这里就有一个统计学原理知识,就是先对数据抽样,然后根据抽样数据来决定各个段的边界以此来保证段中的数据尽量均匀。
sortByKey的基本用法我就不介绍了,这里主要来讲讲里面的一些具体实现及一些为新手所不能理解的地方:

  1. sortByKey被认为是一个transformation, 但是我在看spark UI的时候,却为sortByKey产生了一个job,因为稍微了解spark的同学都知道只有action才会产生job。
  2. 我的flatMap中的dealFunction函数被调用了两次。

原理分析

经过Google及阅读源码发现:

def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.length)
      : RDD[(K, V)] = self.withScope
  {
    val part = new RangePartitioner(numPartitions, self, ascending)
    new ShuffledRDD[K, V, V](self, part)
      .setKeyOrdering(if (ascending) ordering else ordering.reverse)
  }
def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length <= 128) {
      // If we have less than 128 partitions naive search
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // Determine which binary search method to use only once.
      partition = binarySearch(rangeBounds, k)
      // binarySearch either returns the match location or -[insertion point]-1
      if (partition < 0) {
        partition = -partition-1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

在上述代码中我们可以看到在Partitioner 里面的getPartition函数中,是根据key在rangeBounds里面的位置来判断对应的key是处于哪一个range partition中的,那么我们来看一下rangeBounds的生成。其基本思路就是根据一些参数来决定抽样的样本数量,并获取样本数来划分range段的边界。
代码如下:

private var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty
    } else {
      // This is the sample size we need to have roughly balanced output partitions, capped at 1M.
      // Cast to double to avoid overflowing ints or longs
      val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6)
      // Assume the input partitions are roughly balanced and over-sample a little bit.
      val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt
      val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) // colect sample map key from target rdd
      if (numItems == 0L) {
        Array.empty
      } else {
        // If a partition contains much more than the average number of items, we re-sample from it
        // to ensure that enough items are collected from that partition.
        val fraction = math.min(sampleSize / math.max(numItems, 1L), 1.0)
        val candidates = ArrayBuffer.empty[(K, Float)]
        val imbalancedPartitions = mutable.Set.empty[Int]
        sketched.foreach { case (idx, n, sample) =>
          if (fraction * n > sampleSizePerPartition) {
            imbalancedPartitions += idx
          } else {
            // The weight is 1 over the sampling probability.
            val weight = (n.toDouble / sample.length).toFloat
            for (key <- sample) {
              candidates += ((key, weight))
            }
          }
        }
        if (imbalancedPartitions.nonEmpty) {
          // Re-sample imbalanced partitions with the desired sampling probability.
          val imbalanced = new PartitionPruningRDD(rdd.map(_._1), imbalancedPartitions.contains)
          val seed = byteswap32(-rdd.id - 1)
          val reSampled = imbalanced.sample(withReplacement = false, fraction, seed).collect()
          val weight = (1.0 / fraction).toFloat
          candidates ++= reSampled.map(x => (x, weight))
        }
        RangePartitioner.determineBounds(candidates, math.min(partitions, candidates.size))
      }
    }
  }

def sketch[K : ClassTag](
      rdd: RDD[K],
      sampleSizePerPartition: Int): (Long, Array[(Int, Long, Array[K])]) = {
    val shift = rdd.id
    // val classTagK = classTag[K] // to avoid serializing the entire partitioner object
    val sketched = rdd.mapPartitionsWithIndex { (idx, iter) =>
      val seed = byteswap32(idx ^ (shift << 16))
      val (sample, n) = SamplingUtils.reservoirSampleAndCount(
        iter, sampleSizePerPartition, seed)
      Iterator((idx, n, sample))
    }.collect() // here generated a job
    val numItems = sketched.map(_._2).sum
    (numItems, sketched)
  }

获取样本sample的函数为sketch,我们可以看到这段代码中有一个collect操作,所以这就不难解释我们的疑惑1了,因为在sample的过程中有一个collect会产生一个job。
那么第二个疑惑是为啥呢?产生job就产生job呗,为啥我之前的flatMap里面的操作你要执行两次?
那这里就要回到spark中的stage概念来了,在spark中一个job会划分为多个stage,而stage的划分是跟wide transformation有关。flatMap是一个narrow transformation,这样的话由于在同一个stage中,所以sortByKey中的sample job会把其所在的stage中的操作跑一遍,而外层的job会把整个所有stage都跑一遍这样你sortByKey所在的stage中的操作就会跑两遍,具体见图:


sortByKey Sample Job
Total Job

在上述两个图中,stage1中的flatMap和stage2中的flatMap其实是同一个flatmap操作,这样就可以解释为啥我的flatMap中的操作为啥执行两次了。

解决方案

如果我的flatMap的操作比较重,都是一些访问文件的操作,那么有什么好的方法可以避免因为sample而导致的两次执行问题吗?
那么这里就可以介绍一下spark中的stage cache:就是在shuffle结束一个stage的时候,spark会cache住stage中的结果数据,这样下一次如果遇到要重新运行该stage的时候可以直接拿最终的结果,而不需要重新运行完整的stage过程。所以结合上图我们可以在stage1中的flatMap后面加一个shuffle操作来拆分一个stage,这样下一次执行stage1的时候就可以直接获取数据了,我们可以通过添加一个repartition来切分一些stage,以保证sortByKey的sample执行时是在一个新的stage中,这样sample job 和 原始job可以复用一个stage1中的数据,
代码如下:

sc.parallelize(data)
.flatMap(dealFunction)
.repartition(partitions)
.sortByKey(_._1)
.someOtherOperations

最终的结果如下:


sortByKey Sample Job
Total Job

注:因为shuffle的过程是比较耗时的,对于内存和IO也是有较高损耗的,目前这个方法是我目前能想到的比较好的解决比较重的重复操作的方法,我玩Spark的时间也比较短,如果大神能能指教一下最好了

shuffle performance-impactl#performance-impact

最后又发现可以在sortByKey之前调用一次cache或者persist进行rdd缓存。

你可能感兴趣的:(关于spark中rdd.sortByKey的简单分析)