spark aggregate & treeAggregate

aggregate和treeAggregate都是org.apache.spark.rdd包下的RDD类的方法。

aggregate

首先来看这个方法的签名

abstract class RDD[T: ClassTag](
    @transient private var _sc: SparkContext,
    @transient private var deps: Seq[Dependency[_]]
  ) extends Serializable with Logging {
...
...
/**
   * Aggregate the elements of each partition, and then the results for all the partitions, using
   * given combine functions and a neutral "zero value". This function can return a different result
   * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
   * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
   * allowed to modify and return their first argument instead of creating a new U to avoid memory
   * allocation.
   *
   * @param zeroValue the initial value for the accumulated result of each partition for the
   *                  `seqOp` operator, and also the initial value for the combine results from
   *                  different partitions for the `combOp` operator - this will typically be the
   *                  neutral element (e.g. `Nil` for list concatenation or `0` for summation)
   * @param seqOp an operator used to accumulate results within a partition
   * @param combOp an associative operator used to combine results from different partitions
   */
  def aggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U, 
      combOp: (U, U) => U): U = withScope {...}
...
...
}

可以看到aggregate方法接受三个参数:aggregate(zeroValue)(seqOp, combOp),这里参数被分在两组括号中,这个写法涉及到了柯里化(Currying),感兴趣的同学可以去进一步了解一下。

OK,现在简单翻译一下这个方法的注释:

该方法(function)首先对每个partition的元素执行聚合(aggregate)操作,然后对所有partition的结果再次执行聚合操作。
聚合操作使用了传入参数中的combOp作为聚合函数(combine functions),使用zeroValue作为聚合操作中的零元
本方法返回值的类型U可不同于所属RDD对象的类型:T。
因此我们需要一个函数来将一个T对象转换(merge into)为一个U对象,以及一个函数来将两个U对象合并(merge)为一个U对象。如同scala.TraversableOnce一样。这两个方法都支持对其第一个入参修改,从而避免频繁申请空间创建新的U对象。

zeroValue : 零元,seqOp方法的初始值,也是combOp方法的初始值(如list拼接中的Nil、加法中的0
seqOp : 单partition做聚合操作的方法
combOp : 多partition之间做合并的方法

下面来看一个具体的例子(参考了这篇回答):

scala> val listRDD = spark.sparkContext.parallelize(Seq(1,2,3,4), 2)
scala> def seqOp(localResult: Seq[Int], listElement: Int) = {Seq(localResult(0) + listElement, localResult(1) + 1) }
scala> def combOp(localResultA: Seq[Int], localResultB: Seq[Int]) = {Seq(localResultA(0)+localResultB(0), localResultA(1)+localResultB(1))}
scala> listRDD.aggregate(Seq(0, 0))(seqOp, combOp)
res1: Seq[Int] = List(10, 4)

这里新建了一个序列Seq(1, 2, 3, 4),并划分到两个partition中:

partition0: Seq(1, 2)
partition1: Seq(3, 4)

最终想统计一个数对Seq(序列的和, 序列元素个数)。序列的和为1+2+3+4=10,序列个数显然是4
计算方法如下:

  1. 对每个partition:
    a. 初始化聚合结果为Seq(0, 0)
    b. 对当前partition的序列元素,依次执行聚合操作seqOp
    c. 得到当前partition的聚合结果Seq(partition_sum, partition_count)

  2. 对所有partition:
    a. 依次合并各partition的聚合结果,合并方法为combOp
    b. 得到合并结果Seq(total_sum, total_count)

计算过程如下图所示:

            (0, 0) <-- zeroValue

[1, 2]                  [3, 4]

0 + 1 = 1               0 + 3 = 3
0 + 1 = 1               0 + 1 = 1

1 + 2 = 3               3 + 4 = 7
1 + 1 = 2               1 + 1 = 2       
    |                       |
    v                       v
  (3, 2)                  (7, 2)
      \                    / 
        \                /
          \            /
           ------------
           |  combOp  |
           ------------
                |
                v
             (10, 4)

这里需要注意,当我们把zeroValue改为(1, 0)的时候,我们其实是无法通过上面的图示来预期结果的,结果并不一定会变为(12, 4),因为在spark内部计算的时候,可能会多次使用该值做初始化。因此在选择zeroValue的时候应谨慎。
OK,到此为止我们就了解了aggregate的使用方法,下面来看treeAggregate

treeAggregate

/**
   * Aggregates the elements of this RDD in a multi-level tree pattern.
   * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]].
   *
   * @param depth suggested depth of the tree (default: 2)
   */
  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U = withScope {...}

其实基本上和aggregate是一样的,但是在aggregate中,需要把各partition的结果汇总发到driver上使用combOp进行最后一步汇总合并,这里有时会成为瓶颈(带宽、依次遍历各partition结果并合并),而treeAggregate就是用来优化这一环节的,按照树结构来reduce,提升性能。

treeAggregate提供了一个新的参数depth,就是用来指定这个reduce树的深度的,默认为2。

了解了aggregatetreeAggregate后,我们就知道了,在实际使用中,尽量还是使用treeAggregate吧。

你可能感兴趣的:(spark aggregate & treeAggregate)