Aggregate Implementation in Spark

Data preparation

Seq((1, 1, 1, 1), (1, 2, 1, 1)).toDF("a", "b", "c", "d").createTempView("v")
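
A minimal sketch of how the plans in this post can be reproduced (assuming a spark-shell, i.e. a SparkSession named spark with its implicits in scope; the shuffle-partition value of 5 is inferred from the hashpartitioning(..., 5) nodes, and the GlobalLimit/LocalLimit/CollectLimit 21 nodes come from df.show(), which fetches at most 21 rows):

// assumed setup: 5 shuffle partitions, matching hashpartitioning(..., 5) in the plans
spark.conf.set("spark.sql.shuffle.partitions", "5")

val df = spark.sql("SELECT count(b) FROM v GROUP BY a")
df.explain("extended")  // prints the analyzed, optimized and physical plans
df.show()               // show() is what introduces the limit-21 nodes seen below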

Aggregate without DISTINCT

Query: SELECT count(b) FROM v GROUP BY a

== Analyzed Logical Plan ==
count(b): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(b)#296L as string) AS count(b)#299]
      +- Aggregate [a#287], [count(b#288) AS count(b)#296L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(1) as string) AS count(b)#299]
      +- Project [_1#278 AS a#287]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(2) HashAggregate(keys=[a#287], functions=[count(1)], output=[count(b)#299])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#108]
      +- *(1) HashAggregate(keys=[a#287], functions=[partial_count(1)], output=[a#287, count#302L])
         +- *(1) Project [_1#278 AS a#287]
            +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]

Generating the physical Aggregate

  • Convert all of the original groupingExpressions into NamedExpressions, producing the new namedGroupingExpressions

  • Extract the AggregateExpressions from the result expressions, giving aggregateExpressions (count(1) in this example, since the optimizer has already rewritten count(b) to count(1))

  • Rewrite all result expressions in terms of the new namedGroupingExpressions and the aggregate result attributes; the two physical stages built from these pieces are sketched below
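
Putting these pieces together, AggUtils.planAggregateWithoutDistinct builds a two-stage physical plan. The sketch below uses the same notation as the distinct case later in this post and is a paraphrase of the Spark source, not the literal code:

partialAggregate (
  groupingExpressions = namedGroupingExpressions,
  aggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
)
          |
          |
         \|/
finalAggregate (
  requiredChildDistributionExpressions = Some(groupingAttributes),  // shuffle by the grouping attributes
  groupingExpressions = groupingAttributes,
  aggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)),
  resultExpressions = rewrittenResultExpressions  // the rewritten result expressions from the last step above
)

This matches the physical plan above: a partial HashAggregate, an Exchange on a#287, and a final HashAggregate.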

One Distinct Aggregate

Query: SELECT count(distinct b), sum(c) FROM v GROUP BY a

Plan

== Analyzed Logical Plan ==
count(DISTINCT b): string, sum(c): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b)#297L as string) AS count(DISTINCT b)#303, cast(sum(c)#298L as string) AS sum(c)#304]
      +- Aggregate [a#287], [count(distinct b#288) AS count(DISTINCT b)#297L, sum(c#289, None) AS sum(c)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(distinct b#288) as string) AS count(DISTINCT b)#303, cast(sum(c#289, None) as string) AS sum(c)#304]
      +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[sum(c#289, None), count(distinct b#288)], output=[count(DISTINCT b)#303, sum(c)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#123]
      +- *(2) HashAggregate(keys=[a#287], functions=[merge_sum(c#289, None), partial_count(distinct b#288)], output=[a#287, sum#308L, count#311L])
         +- *(2) HashAggregate(keys=[a#287, b#288], functions=[merge_sum(c#289, None)], output=[a#287, b#288, sum#308L])
            +- Exchange hashpartitioning(a#287, b#288, 5), ENSURE_REQUIREMENTS, [id=#118]
               +- *(1) HashAggregate(keys=[a#287, b#288], functions=[partial_sum(c#289, None)], output=[a#287, b#288, sum#308L])
                  +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
                     +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]

Notes on the four HashAggregates:

  • Agg1 and Agg2 (the bottom two, counting from the leaves of the plan): group by the GROUP BY column plus the distinct column, while computing the regular (non-distinct) aggregate expressions
  • Agg3 and Agg4 (the top two): group by only the GROUP BY column, computing both the regular aggregate expressions and the distinct aggregate expression

In terms of AggUtils.planAggregateWithOneDistinct, the four stages are constructed as follows:

partialAggregate (
  groupingExpressions = groupingExpressions ++ distinctExpressions,
  aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
)
          |
          |
         \|/
partialMergeAggregate (
  requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes),
  groupingExpressions = groupingAttributes ++ distinctAttributes,
  aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
)
          |
          |
         \|/
partialDistinctAggregate (
  groupingExpressions = groupingAttributes,  // only the global grouping attributes
  aggregateExpressions =
    functionsWithoutDistinct.map(_.copy(mode = PartialMerge))  // non-distinct functions in PartialMerge mode
    ++ distinctAggregateExpressions,  // the earlier stages deduplicated on groupingAttributes ++ distinctAttributes,
                                      // which does not deduplicate within groupingAttributes alone, so the rewritten
                                      // functionsWithDistinct keep DISTINCT, here in Partial mode
  aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes
)
          |
          |
         \|/
finalAggregate (
  requiredChildDistributionExpressions = Some(groupingAttributes),  // shuffle by the global grouping attributes
  groupingExpressions = groupingAttributes,
  aggregateExpressions =
    functionsWithoutDistinct.map(_.copy(mode = Final))  // non-distinct functions in Final mode
    ++ distinctAggregateExpressions,  // as above, but now in Final mode
  aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes
)
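
Semantically, this four-stage plan computes the same result as the hand-written two-level aggregation below (an illustrative rewrite with made-up aliases, not something Spark generates):

-- illustrative rewrite of: SELECT count(distinct b), sum(c) FROM v GROUP BY a
SELECT count(b) AS cnt_distinct_b,  -- Agg3/Agg4: each b now occurs once per a, so this equals count(DISTINCT b)
       sum(s)   AS sum_c            -- merge the per-(a, b) partial sums
FROM (
  SELECT a, b, sum(c) AS s          -- Agg1/Agg2: group by (a, b) and pre-aggregate c
  FROM v
  GROUP BY a, b
) t
GROUP BY a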

Multiple DISTINCT aggregates over the same expression

Query: SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i
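
This query references columns i and j rather than a, b, c, d, so it runs against a different test view than the one defined at the top. The three-column LocalRelation and the normalizenanandzero call in the plan suggest a setup along the following lines (a guess for illustration, not taken from the original source; the column name k is made up):

// hypothetical data prep: j is a floating-point column, which is why the plan
// wraps it in knownfloatingpointnormalized(normalizenanandzero(...))
Seq((1, 1.0, 1), (1, 2.0, 1)).toDF("i", "j", "k").createOrReplaceTempView("v")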

== Optimized Logical Plan ==
Aggregate [i#284], [sum(distinct j#285, None) AS sum(DISTINCT j)#292, max(j#285) AS max(DISTINCT j)#293]
+- Project [_1#277 AS i#284, _2#278 AS j#285]
   +- LocalRelation [_1#277, _2#278, _3#279]

== Physical Plan ==
*(3) HashAggregate(keys=[i#284], functions=[max(j#285), sum(distinct j#296, None)], output=[sum(DISTINCT j)#292, max(DISTINCT j)#293])
+- Exchange hashpartitioning(i#284, 5), ENSURE_REQUIREMENTS, [id=#119]
   +- *(2) HashAggregate(keys=[i#284], functions=[merge_max(j#285), partial_sum(distinct j#296, None)], output=[i#284, max#298, sum#301])
      +- *(2) HashAggregate(keys=[i#284, j#296], functions=[merge_max(j#285)], output=[i#284, j#296, max#298])
         +- Exchange hashpartitioning(i#284, j#296, 5), ENSURE_REQUIREMENTS, [id=#114]
            +- *(1) HashAggregate(keys=[i#284, knownfloatingpointnormalized(normalizenanandzero(j#285)) AS j#296], functions=[partial_max(j#285)], output=[i#284, j#296, max#298])
               +- *(1) Project [_1#277 AS i#284, _2#278 AS j#285]
                  +- *(1) LocalTableScan [_1#277, _2#278, _3#279]


In the optimized logical plan the DISTINCT has already been dropped from max(DISTINCT j) (the EliminateDistinct rule: max over distinct values equals max over all values), so only one distinct group, on j, remains and the query is still planned with planAggregateWithOneDistinct.

The first pair of HashAggregates groups by (i, j); within each group, j is by construction a distinct value, and it is renamed to j#296 (wrapped in knownfloatingpointnormalized(normalizenanandzero(...)) because j is a floating-point column).

The second pair of HashAggregates groups by i and computes max(j) and sum(distinct j).

Multiple DISTINCT aggregates over different expressions

When generating the physical plan, Spark only supports a single distinct group (all DISTINCT aggregates must be over the same set of columns). If there is more than one distinct group, the RewriteDistinctAggregates rule rewrites the distinct aggregates in terms of an Expand operator, after which the query is planned like an aggregate without DISTINCT.

Query: SELECT count(DISTINCT b, c), count(DISTINCT c, d) FROM v GROUP BY a

How RewriteDistinctAggregates works:

  • Find every distinct group and the aggregate expressions that belong to it:

    • Group Set(b#288, c#289) maps to List(count(distinct b#288, c#289))
    • Group Set(c#289, d#290) maps to List(count(distinct c#289, d#290))
  • Collect all key columns that appear in any distinct group and remap them: distinctAggChildren = (b, c, d). From this point on, deduplicating on groupByAttrs + these columns + gid is enough to guarantee the DISTINCT semantics.

  • Then rewrite each distinct group found in the first step; note that the new expressions read their inputs from the projected (Expand output) attributes:

    • For count(distinct b#288, c#289), the projection is (b, c, null, gid(1)) and the new expression is count(if ((gid#307 = 1)) v.b#308 else null, if ((gid#307 = 1)) v.c#309 else null)
    • For count(distinct c#289, d#290), the projection is (null, c, d, gid(2)) and the new expression is count(if ((gid#307 = 2)) v.c#309 else null, if ((gid#307 = 2)) v.d#310 else null)
  • Rewrite the regular (non-distinct) aggregate expressions, if any (this query has none)

  • Produce the new plan: Expand -> firstAggregate -> Aggregate (a hand-written SQL equivalent is sketched below)
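
The net effect of the rewrite is equivalent to the hand-written query below (an illustration of its semantics with made-up aliases, not the literal plan Spark builds; note that Spark's count accepts several arguments and counts the rows where all of them are non-null):

SELECT count(IF(gid = 1, b, NULL), IF(gid = 1, c, NULL)) AS cnt_distinct_b_c,
       count(IF(gid = 2, c, NULL), IF(gid = 2, d, NULL)) AS cnt_distinct_c_d
FROM (
  -- firstAggregate: deduplicate on (a, b, c, d, gid)
  SELECT a, b, c, d, gid
  FROM (
    -- Expand: emit one copy of every input row per distinct group, tagged with a group id
    SELECT a, b, c, CAST(NULL AS INT) AS d, 1 AS gid FROM v
    UNION ALL
    SELECT a, CAST(NULL AS INT) AS b, c, d, 2 AS gid FROM v
  ) expanded
  GROUP BY a, b, c, d, gid
) deduped
GROUP BY a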

DAG


== Analyzed Logical Plan ==
count(DISTINCT b, c): string, count(DISTINCT c, d): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b, c)#297L as string) AS count(DISTINCT b, c)#303, cast(count(DISTINCT c, d)#298L as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287], [count(distinct b#288, c#289) AS count(DISTINCT b, c)#297L, count(distinct c#289, d#290) AS count(DISTINCT c, d)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null) as string) AS count(DISTINCT b, c)#303, cast(count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null) as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
         +- Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]

== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[count(DISTINCT b, c)#303, count(DISTINCT c, d)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#128]
      +- *(2) HashAggregate(keys=[a#287], functions=[partial_count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), partial_count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[a#287, count#313L, count#314L])
         +- *(2) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
            +- Exchange hashpartitioning(a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307, 5), ENSURE_REQUIREMENTS, [id=#123]
               +- *(1) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
                  +- *(1) Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
                     +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
                        +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]
