Seq((1, 1, 1, 1), (1, 2, 1, 1)).toDF("a", "b", "c", "d").createTempView("v")
Query: SELECT count(b) FROM v GROUP BY a
== Analyzed Logical Plan ==
count(b): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(b)#296L as string) AS count(b)#299]
      +- Aggregate [a#287], [count(b#288) AS count(b)#296L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(1) as string) AS count(b)#299]
      +- Project [_1#278 AS a#287]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Physical Plan ==
CollectLimit 21
+- *(2) HashAggregate(keys=[a#287], functions=[count(1)], output=[count(b)#299])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#108]
      +- *(1) HashAggregate(keys=[a#287], functions=[partial_count(1)], output=[a#287, count#302L])
         +- *(1) Project [_1#278 AS a#287]
            +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]
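Generating the physical Aggregate
Convert every original groupingExpression into a NamedExpression, which yields the new namedGroupingExpressions.
Collect the AggregateExpressions from the result expressions to obtain aggregateExpressions (for the query above, count(b#288)).
Rewrite all result expressions so that they refer to the new namedGroupingExpressions and to the AggregateExpressions.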
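The three steps above can be sketched with a toy expression tree. This is a simplified model and not Spark's own code: the Expr case classes, the resultAttr naming scheme, the generated group_N names and the plan function are all hypothetical stand-ins for Catalyst's NamedExpression and AggregateExpression handling.

```scala
object PhysicalAggregationSketch {
  // Toy stand-ins for Catalyst expressions (hypothetical, not Spark classes).
  sealed trait Expr
  final case class Attr(name: String) extends Expr
  final case class Add(left: Expr, right: Expr) extends Expr
  final case class Cast(child: Expr, dataType: String) extends Expr
  final case class Alias(child: Expr, name: String) extends Expr
  final case class AggCall(fn: String, child: Expr) extends Expr

  // Step 1: give every grouping expression a name. Attributes keep their own
  // name; anything else is wrapped in an Alias with a generated name.
  def nameGrouping(grouping: Seq[Expr]): Seq[(Expr, Alias)] =
    grouping.zipWithIndex.map {
      case (a: Attr, _) => (a, Alias(a, a.name))
      case (e, i)       => (e, Alias(e, s"group_$i"))
    }

  // Step 2: collect the aggregate calls that appear in a result expression.
  def collectAggs(e: Expr): Seq[AggCall] = e match {
    case a: AggCall  => Seq(a)
    case Add(l, r)   => collectAggs(l) ++ collectAggs(r)
    case Cast(c, _)  => collectAggs(c)
    case Alias(c, _) => collectAggs(c)
    case _: Attr     => Nil
  }

  // Assumed naming scheme for the attribute that carries an aggregate's result.
  def resultAttr(a: AggCall): Attr = Attr(s"result_of_$a")

  // Step 3: rewrite a result expression so that it only refers to the named
  // grouping expressions and to the aggregate result attributes.
  def rewrite(e: Expr, groupingRefs: Map[Expr, Attr]): Expr = e match {
    case a: AggCall                    => resultAttr(a)
    case g if groupingRefs.contains(g) => groupingRefs(g)
    case Add(l, r)                     => Add(rewrite(l, groupingRefs), rewrite(r, groupingRefs))
    case Cast(c, t)                    => Cast(rewrite(c, groupingRefs), t)
    case Alias(c, n)                   => Alias(rewrite(c, groupingRefs), n)
    case other                         => other
  }

  // Runs the three steps and returns
  // (namedGroupingExpressions, aggregateExpressions, rewritten result expressions).
  def plan(grouping: Seq[Expr], results: Seq[Expr]): (Seq[Alias], Seq[AggCall], Seq[Expr]) = {
    val named = nameGrouping(grouping)
    val refs: Map[Expr, Attr] =
      named.map { case (orig, alias) => (orig, Attr(alias.name)) }.toMap
    val aggs = results.flatMap(collectAggs).distinct
    (named.map(_._2), aggs, results.map(rewrite(_, refs)))
  }
}
```

For SELECT count(b) FROM v GROUP BY a, the grouping list would be Seq(Attr("a")) and the single result expression Alias(Cast(AggCall("count", Attr("b")), "string"), "count(b)"); after the rewrite the Cast wraps the aggregate's result attribute instead of the aggregate call itself.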
Query: SELECT count(distinct b),sum(c) FROM v GROUP BY a
Plan
== Analyzed Logical Plan ==
count(DISTINCT b): string, sum(c): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b)#297L as string) AS count(DISTINCT b)#303, cast(sum(c)#298L as string) AS sum(c)#304]
      +- Aggregate [a#287], [count(distinct b#288) AS count(DISTINCT b)#297L, sum(c#289, None) AS sum(c)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(distinct b#288) as string) AS count(DISTINCT b)#303, cast(sum(c#289, None) as string) AS sum(c)#304]
      +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
         +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[sum(c#289, None), count(distinct b#288)], output=[count(DISTINCT b)#303, sum(c)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#123]
      +- *(2) HashAggregate(keys=[a#287], functions=[merge_sum(c#289, None), partial_count(distinct b#288)], output=[a#287, sum#308L, count#311L])
         +- *(2) HashAggregate(keys=[a#287, b#288], functions=[merge_sum(c#289, None)], output=[a#287, b#288, sum#308L])
            +- Exchange hashpartitioning(a#287, b#288, 5), ENSURE_REQUIREMENTS, [id=#118]
               +- *(1) HashAggregate(keys=[a#287, b#288], functions=[partial_sum(c#289, None)], output=[a#287, b#288, sum#308L])
                  +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289]
                     +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]
The four HashAggregate nodes correspond to the following stages:
partialAggregate (
groupingExpressions = groupingExpressions ++ distinctExpressions,
aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial))
)
|
|
\|/
partialAggregate (
requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes),
groupingExpressions = groupingAttributes ++ distinctAttributes,
aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
)
|
|
\|/
partialAggregate (
groupingExpressions = groupingAttributes, // the global grouping attributes
aggregateExpressions =
functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) // non-distinct functions, in PartialMerge mode
++ distinctAggregateExpressions, // the earlier stages deduplicated on groupingAttributes ++ distinctAttributes, which does not guarantee deduplication on groupingAttributes alone, so the rewritten functionsWithDistinct keep their distinct flag; mode is Partial
aggregateAttributes = mergeAggregateAttributes ++ distinctAggregateAttributes
)
|
|
\|/
finalAggregate (
requiredChildDistributionExpressions = Some(groupingAttributes), // shuffle by the global grouping attributes
groupingExpressions = groupingAttributes,
aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // non-distinct functions, in Final mode
++ distinctAggregateExpressions, // same as above, with the mode changed to Final
aggregateAttributes = finalAggregateAttributes ++ distinctAggregateAttributes
)
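The four stages can be imitated on plain Scala collections for SELECT count(DISTINCT b), sum(c) FROM v GROUP BY a. This is only a sketch of the data flow under simplifying assumptions: the Row class, the shuffle helper (a stand-in for Exchange hashpartitioning) and the stage1 to stage4 names are hypothetical, b is assumed to be non-null, and the aggregation buffers are reduced to plain Long values.

```scala
object SingleDistinctSketch {
  final case class Row(a: Int, b: Int, c: Long)

  // Simulated Exchange: route every (key, buffer) pair to one of n partitions
  // according to the hash of its key.
  def shuffle[K, V](parts: Seq[Seq[(K, V)]], n: Int): Seq[Seq[(K, V)]] = {
    val all = parts.flatten
    (0 until n).map(i => all.filter { case (k, _) => Math.floorMod(k.hashCode, n) == i })
  }

  // Stage 1, mode = Partial: group by (a, b) inside each input partition and
  // keep a partial sum(c) buffer.
  def stage1(part: Seq[Row]): Seq[((Int, Int), Long)] =
    part.groupBy(r => (r.a, r.b)).toSeq.map { case (k, rs) => (k, rs.map(_.c).sum) }

  // Stage 2, mode = PartialMerge: after shuffling by (a, b), merge the sum
  // buffers. Every (a, b) pair now occurs exactly once across all partitions.
  def stage2(part: Seq[((Int, Int), Long)]): Seq[((Int, Int), Long)] =
    part.groupBy(_._1).toSeq.map { case (k, kvs) => (k, kvs.map(_._2).sum) }

  // Stage 3: group by a only. Merge the sum buffers again and start a partial
  // count over b; counting rows is enough because (a, b) is already unique.
  def stage3(part: Seq[((Int, Int), Long)]): Seq[(Int, (Long, Long))] =
    part.groupBy(_._1._1).toSeq.map { case (a, kvs) =>
      (a, (kvs.map(_._2).sum, kvs.size.toLong))
    }

  // Stage 4, mode = Final: after shuffling by a, merge both buffers.
  def stage4(part: Seq[(Int, (Long, Long))]): Seq[(Int, (Long, Long))] =
    part.groupBy(_._1).toSeq.map { case (a, kvs) =>
      (a, (kvs.map(_._2._1).sum, kvs.map(_._2._2).sum))
    }

  // Returns a -> (sum(c), count(DISTINCT b)).
  def run(input: Seq[Seq[Row]], numPartitions: Int): Map[Int, (Long, Long)] = {
    val s1 = input.map(stage1)
    val s2 = shuffle(s1, numPartitions).map(stage2)
    val s3 = s2.map(stage3)
    val s4 = shuffle(s3, numPartitions).map(stage4)
    s4.flatten.toMap
  }
}
```

Stages 2 and 3 run on the same partitions without a shuffle in between, which matches the two HashAggregate nodes that share the codegen stage *(2) in the physical plan above.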
Query: SELECT sum(DISTINCT j), max(DISTINCT j) FROM v GROUP BY i
== Optimized Logical Plan ==
Aggregate [i#284], [sum(distinct j#285, None) AS sum(DISTINCT j)#292, max(j#285) AS max(DISTINCT j)#293]
+- Project [_1#277 AS i#284, _2#278 AS j#285]
   +- LocalRelation [_1#277, _2#278, _3#279]
== Physical Plan ==
*(3) HashAggregate(keys=[i#284], functions=[max(j#285), sum(distinct j#296, None)], output=[sum(DISTINCT j)#292, max(DISTINCT j)#293])
+- Exchange hashpartitioning(i#284, 5), ENSURE_REQUIREMENTS, [id=#119]
   +- *(2) HashAggregate(keys=[i#284], functions=[merge_max(j#285), partial_sum(distinct j#296, None)], output=[i#284, max#298, sum#301])
      +- *(2) HashAggregate(keys=[i#284, j#296], functions=[merge_max(j#285)], output=[i#284, j#296, max#298])
         +- Exchange hashpartitioning(i#284, j#296, 5), ENSURE_REQUIREMENTS, [id=#114]
            +- *(1) HashAggregate(keys=[i#284, knownfloatingpointnormalized(normalizenanandzero(j#285)) AS j#296], functions=[partial_max(j#285)], output=[i#284, j#296, max#298])
               +- *(1) Project [_1#277 AS i#284, _2#278 AS j#285]
                  +- *(1) LocalTableScan [_1#277, _2#278, _3#279]
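The first pair of Aggregates effectively groups by (i, j). At that point j holds distinct values within each i, and it is renamed so that it appears as distinct j#296.
The second pair of Aggregates groups by i and computes max(j) and sum(distinct j).
When generating the physical plan, Spark can only handle a single distinct group. If there is more than one distinct group, RewriteDistinctAggregates rewrites the distinct aggregates into an Expand, and the result is then processed as an Aggregate without DISTINCT.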
Query: SELECT count(DISTINCT b, c), count(DISTINCT c, d) FROM v GROUP BY a
RewriteDistinctAggregates
Processing logic: find every distinct group and the expressions that belong to it:
Set(b#288, c#289)
corresponds to List(count(distinct b#288, c#289))
Set(c#289, d#290)
corresponds to List(count(distinct c#289, d#290))
Extract all key columns from the distinct groups and remap them: distinctAggChildren = (b, c, d). From here on, deduplicating on groupByAttrs + these columns + gid is enough to guarantee that the data is distinct.
Next, process the distinct groups found in the first step. Note that the newly generated expressions read their inputs from the projection output.
count(distinct b#288, c#289)
corresponding projection = (b, c, null, gid(1))
new expression = count(if ((gid#307 = 1)) v.b#308 else null, if ((gid#307 = 1)) v.c#309 else null)
count(distinct c#289, d#290)
corresponding projection = (null, c, d, gid(2))
new expression = count(if ((gid#307 = 2)) v.c#309 else null, if ((gid#307 = 2)) v.d#310 else null)
Process the regular (non-distinct) aggregate expressions (regularAggExpression).
Generate the new plan: Expand -> firstAggregate -> Aggregate.
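The Expand -> firstAggregate -> Aggregate pipeline can be imitated on plain Scala collections for the query above. This is a sketch under assumptions, not Spark's implementation: Row, Expanded, ifGid and the function names are hypothetical, None plays the role of SQL null, and count over several arguments is modelled as counting the rows in which every argument is non-null.

```scala
object ExpandRewriteSketch {
  final case class Row(a: Int, b: Int, c: Int, d: Int)
  // One output row of Expand; None plays the role of SQL null.
  final case class Expanded(a: Int, b: Option[Int], c: Option[Int], d: Option[Int], gid: Int)

  // Expand: emit one projection per distinct group. Columns that do not
  // belong to the group are nulled out and gid records the group.
  def expand(r: Row): Seq[Expanded] = Seq(
    Expanded(r.a, Some(r.b), Some(r.c), None, 1), // group 1: (b, c)
    Expanded(r.a, None, Some(r.c), Some(r.d), 2)  // group 2: (c, d)
  )

  // First Aggregate: grouping by (a, b, c, d, gid) with no aggregate
  // functions is simply a deduplication of the expanded rows.
  def firstAggregate(rows: Seq[Expanded]): Seq[Expanded] = rows.distinct

  // Models `if (gid = g) col else null`.
  def ifGid(r: Expanded, g: Int, col: Option[Int]): Option[Int] =
    if (r.gid == g) col else None

  // Second Aggregate: group by a and run plain, non-distinct counts over the
  // gid-guarded columns. Returns a -> (count(DISTINCT b, c), count(DISTINCT c, d)).
  def secondAggregate(rows: Seq[Expanded]): Map[Int, (Long, Long)] =
    rows.groupBy(_.a).map { case (a, rs) =>
      val bc = rs.count(r => ifGid(r, 1, r.b).isDefined && ifGid(r, 1, r.c).isDefined).toLong
      val cd = rs.count(r => ifGid(r, 2, r.c).isDefined && ifGid(r, 2, r.d).isDefined).toLong
      (a, (bc, cd))
    }

  def run(rows: Seq[Row]): Map[Int, (Long, Long)] =
    secondAggregate(firstAggregate(rows.flatMap(expand)))
}
```

With the two rows of view v, (1, 1, 1, 1) and (1, 2, 1, 1), Expand produces four rows, the first Aggregate keeps three of them, and the second Aggregate counts two distinct (b, c) pairs and one distinct (c, d) pair for a = 1.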
== Analyzed Logical Plan ==
count(DISTINCT b, c): string, count(DISTINCT c, d): string
GlobalLimit 21
+- LocalLimit 21
   +- Project [cast(count(DISTINCT b, c)#297L as string) AS count(DISTINCT b, c)#303, cast(count(DISTINCT c, d)#298L as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287], [count(distinct b#288, c#289) AS count(DISTINCT b, c)#297L, count(distinct c#289, d#290) AS count(DISTINCT c, d)#298L]
         +- SubqueryAlias v
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Optimized Logical Plan ==
GlobalLimit 21
+- LocalLimit 21
   +- Aggregate [a#287], [cast(count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null) as string) AS count(DISTINCT b, c)#303, cast(count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null) as string) AS count(DISTINCT c, d)#304]
      +- Aggregate [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
         +- Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
            +- Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
               +- LocalRelation [_1#278, _2#279, _3#280, _4#281]
== Physical Plan ==
CollectLimit 21
+- *(3) HashAggregate(keys=[a#287], functions=[count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[count(DISTINCT b, c)#303, count(DISTINCT c, d)#304])
   +- Exchange hashpartitioning(a#287, 5), ENSURE_REQUIREMENTS, [id=#128]
      +- *(2) HashAggregate(keys=[a#287], functions=[partial_count(if ((gid#307 = 1)) v.`b`#308 else null, if ((gid#307 = 1)) v.`c`#309 else null), partial_count(if ((gid#307 = 2)) v.`c`#309 else null, if ((gid#307 = 2)) v.`d`#310 else null)], output=[a#287, count#313L, count#314L])
         +- *(2) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
            +- Exchange hashpartitioning(a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307, 5), ENSURE_REQUIREMENTS, [id=#123]
               +- *(1) HashAggregate(keys=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307], functions=[], output=[a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307])
                  +- *(1) Expand [ArrayBuffer(a#287, b#288, c#289, null, 1), ArrayBuffer(a#287, null, c#289, d#290, 2)], [a#287, v.`b`#308, v.`c`#309, v.`d`#310, gid#307]
                     +- *(1) Project [_1#278 AS a#287, _2#279 AS b#288, _3#280 AS c#289, _4#281 AS d#290]
                        +- *(1) LocalTableScan [_1#278, _2#279, _3#280, _4#281]