aggregate和treeAggregate都是org.apache.spark.rdd
包下的RDD
类的方法。
aggregate
首先来看这个方法的签名
abstract class RDD[T: ClassTag](
@transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
...
...
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using
* given combine functions and a neutral "zero value". This function can return a different result
* type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
* and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*
* @param zeroValue the initial value for the accumulated result of each partition for the
* `seqOp` operator, and also the initial value for the combine results from
* different partitions for the `combOp` operator - this will typically be the
* neutral element (e.g. `Nil` for list concatenation or `0` for summation)
* @param seqOp an operator used to accumulate results within a partition
* @param combOp an associative operator used to combine results from different partitions
*/
def aggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U): U = withScope {...}
...
...
}
可以看到aggregate
方法接受三个参数:aggregate(zeroValue)(seqOp, combOp)
,这里参数被分在两组括号中,这个写法涉及到了柯里化(Currying),感兴趣的同学可以去进一步了解一下。
OK,现在简单翻译一下这个方法的注释:
该方法(function)首先对每个partition的元素执行聚合(aggregate)操作,然后对所有partition的结果再次执行聚合操作。
聚合操作使用了传入参数中的combOp
作为聚合函数(combine functions),使用zeroValue
作为聚合操作中的零元。
本方法返回值的类型U可不同于所属RDD对象的类型:T。
因此我们需要一个函数来将一个T对象转换(merge into)为一个U对象,以及一个函数来将两个U对象合并(merge)为一个U对象。如同scala.TraversableOnce
一样。这两个方法都支持对其第一个入参修改,从而避免频繁申请空间创建新的U对象。
zeroValue
: 零元,seqOp
方法的初始值,也是combOp
方法的初始值(如list拼接中的Nil
、加法中的0
)
seqOp
: 单partition做聚合操作的方法
combOp
: 多partition之间做合并的方法
下面来看一个具体的例子(参考了这篇回答):
scala> val listRDD = spark.sparkContext.parallelize(Seq(1,2,3,4), 2)
scala> def seqOp(localResult: Seq[Int], listElement: Int) = {Seq(localResult(0) + listElement, localResult(1) + 1) }
scala> def combOp(localResultA: Seq[Int], localResultB: Seq[Int]) = {Seq(localResultA(0)+localResultB(0), localResultA(1)+localResultB(1))}
scala> listRDD.aggregate(Seq(0, 0))(seqOp, combOp)
res1: Seq[Int] = List(10, 4)
这里新建了一个序列Seq(1, 2, 3, 4)
,并划分到两个partition中:
partition0: Seq(1, 2)
partition1: Seq(3, 4)
最终想统计一个数对Seq(序列的和, 序列元素个数)
。序列的和为1+2+3+4=10
,序列个数显然是4
。
计算方法如下:
对每个partition:
a. 初始化聚合结果为Seq(0, 0)
b. 对当前partition的序列元素,依次执行聚合操作seqOp
c. 得到当前partition的聚合结果Seq(partition_sum, partition_count)
对所有partition:
a. 依次合并各partition的聚合结果,合并方法为combOp
b. 得到合并结果Seq(total_sum, total_count)
计算过程如下图所示:
(0, 0) <-- zeroValue
[1, 2] [3, 4]
0 + 1 = 1 0 + 3 = 3
0 + 1 = 1 0 + 1 = 1
1 + 2 = 3 3 + 4 = 7
1 + 1 = 2 1 + 1 = 2
| |
v v
(3, 2) (7, 2)
\ /
\ /
\ /
------------
| combOp |
------------
|
v
(10, 4)
这里需要注意,当我们把zeroValue
改为(1, 0)
的时候,我们其实是无法通过上面的图示来预期结果的,结果并不一定会变为(12, 4)
,因为在spark内部计算的时候,可能会多次使用该值做初始化。因此在选择zeroValue
的时候应谨慎。
OK,到此为止我们就了解了aggregate
的使用方法,下面来看treeAggregate
。
treeAggregate
/**
* Aggregates the elements of this RDD in a multi-level tree pattern.
* This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]].
*
* @param depth suggested depth of the tree (default: 2)
*/
def treeAggregate[U: ClassTag](zeroValue: U)(
seqOp: (U, T) => U,
combOp: (U, U) => U,
depth: Int = 2): U = withScope {...}
其实基本上和aggregate
是一样的,但是在aggregate
中,需要把各partition的结果汇总发到driver上使用combOp
进行最后一步汇总合并,这里有时会成为瓶颈(带宽、依次遍历各partition结果并合并),而treeAggregate
就是用来优化这一环节的,按照树结构来reduce,提升性能。
treeAggregate
提供了一个新的参数depth
,就是用来指定这个reduce树的深度的,默认为2。
了解了aggregate
和treeAggregate
后,我们就知道了,在实际使用中,尽量还是使用treeAggregate
吧。