spark2以后对limit的优化和存在问题

问题

假如我们在spark-shell上执行:
spark.sql("select * from table limit 1000").collect()
spark会开多少多个任务去跑这个任务呢?

实验

OK,我们来做一个实验吧!


job

通过实验结果我们可以看到就开了一个Task执行,but, 是这样的嘛?
其实开多少Task还真不是固定的,这个取决于我们take的条数和这张表底层每个分区数据量的大小,怎么说呢,我们举个。
首先spark2后,spark默认会先去读取一个分区的数据,假如我limit 1000条,那我就从这个分区去取1000条数据就好了,但是如果这个分区的数据不过1000条怎么办,这时spark会通一个公式去计算出下次读取的分区个数。

limit 操作最终会调用 SparkPlan.executeTake(n: Int) 来获取至多 n 条 records, 待我贴出源码

def executeTake(n: Int): Array[InternalRow] = {
    if (n == 0) {
      return new Array[InternalRow](0)
    }

    val childRDD = getByteArrayRdd(n).map(_._2)

    val buf = new ArrayBuffer[InternalRow]
    val totalParts = childRDD.partitions.length
    var partsScanned = 0
    # 通过while循环去runJob获取records, 直到获取的records达到take条数
    while (buf.size < n && partsScanned < totalParts) {
      // The number of partitions to try in this iteration. It is ok for this number to be
      // greater than totalParts because we actually cap it at totalParts in runJob.
      var numPartsToTry = 1L
      if (partsScanned > 0) {
        // If we didn't find any rows after the previous iteration, quadruple and retry.
        // Otherwise, interpolate the number of partitions we need to try, but overestimate
        // it by 50%. We also cap the estimation in the end.
        val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
        if (buf.isEmpty) {
          numPartsToTry = partsScanned * limitScaleUpFactor
        } else {
          val left = n - buf.size
          // As left > 0, numPartsToTry is always >= 1
          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
          numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
        }
      }

      val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
      val sc = sqlContext.sparkContext
      val res = sc.runJob(childRDD,
        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)

      buf ++= res.flatMap(decodeUnsafeRows)

      partsScanned += p.size
    }

    if (buf.size > n) {
      buf.take(n).toArray
    } else {
      buf.toArray
    }
  }

默认情况下每次 runJob 扫描的 partitions 数:

1
4
20
100
500
2500
6875

通过读取的partitions的个数我们可以发现最初读取的partition数量太少,后面读取的partition数据量太多。

其实这边我们可以通过计算每次读取partitions得到的records估算出下去应该读取的分区,这样会比较靠谱些。

你可能感兴趣的:(spark2以后对limit的优化和存在问题)