malePPL.agg(Map("height" -> "max", "sex" -> "count")).show
The data has two columns, height and sex, with a few million rows in total.
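At first I computed this with reduceByKey. Later I saw people online getting such values (like the maximum) directly with agg, so I looked into why passing in a Map(column name -> function name) produces the expected result.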
First, let's step straight into DataFrame.agg:
/**
 * (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
 * {{{
 *   // df.agg(...) is a shorthand for df.groupBy().agg(...)
 *   df.agg(Map("age" -> "max", "salary" -> "avg"))
 *   df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
 * }}}
 * @group dfops
 * @since 1.3.0
 */
def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
It simply calls agg on the object returned by groupBy(). groupBy returns a GroupedData:
@scala.annotation.varargs
def groupBy(cols: Column*): GroupedData = {
  GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
}
GroupedData.agg:
def agg(exprs: Map[String, String]): DataFrame = {
  toDF(exprs.map { case (colName, expr) =>
    strToExpr(expr)(df(colName).expr)
  }.toSeq)
}
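A rough way to see why agg on the whole DataFrame works: groupBy() with no columns gives every row the same (empty) key, so there is exactly one group and the aggregates are computed once over the entire table. The sketch below is a toy model in plain Scala collections, not Spark code; `GlobalAgg`, `Person` and the use of `()` as the empty key are hypothetical choices for illustration.

```scala
object GlobalAgg {
  // Hypothetical row type standing in for one (height, sex) record.
  final case class Person(height: Double, sex: String)

  // Toy groupBy(keys): rows with the same key land in the same group.
  def groupBy[K](rows: Seq[Person])(key: Person => K): Map[K, Seq[Person]] =
    rows.groupBy(key)

  // Toy df.agg(Map("height" -> "max", "sex" -> "count")).
  // groupBy() with no columns: every row gets the empty key (), so one group.
  def agg(rows: Seq[Person]): (Option[Double], Long) = {
    val all = groupBy(rows)(_ => ()).getOrElse((), Seq.empty)
    // max over an empty group is NULL in SQL, modelled as None;
    // count(col) skips NULLs, modelled by ignoring null sex values
    (all.map(_.height).maxOption, all.count(_.sex != null).toLong)
  }
}
```

Even for an empty input the toy returns one result, `(None, 0)`, which mirrors how a SQL aggregate without GROUP BY still produces a single row.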
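For each (column name, function name) pair it builds an expression with strToExpr, then builds a DataFrame with toDF. Inside strToExpr, most function names simply become an UnresolvedFunction; count (and its alias size) is the exception and is built directly as a Count aggregate: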
private[this] def strToExpr(expr: String): (Expression => Expression) = {
  val exprToFunc: (Expression => Expression) = {
    (inputExpr: Expression) => expr.toLowerCase match {
      // We special handle a few cases that have alias that are not in function registry.
      case "avg" | "average" | "mean" =>
        UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false)
      case "stddev" | "std" =>
        UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false)
      // Also special handle count because we need to take care count(*).
      case "count" | "size" =>
        // Turn count(*) into count(1)
        inputExpr match {
          case s: Star => Count(Literal(1)).toAggregateExpression()
          case _ => Count(inputExpr).toAggregateExpression()
        }
      case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false)
    }
  }
  (inputExpr: Expression) => exprToFunc(inputExpr)
}
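To make the name handling concrete, here is a toy re-implementation of the dispatch in strToExpr over a tiny hypothetical expression type (`Col`, `Star`, `Lit`, `CountAgg` are stand-ins, not Catalyst classes). Aliases such as mean and average collapse to avg, count on `*` becomes count(1), and every other name is passed through, lower-cased, as an UnresolvedFunction for the Analyzer to resolve later.

```scala
object StrToExprModel {
  // Minimal hypothetical stand-ins for Catalyst expressions.
  sealed trait Expr
  final case class Col(name: String) extends Expr
  case object Star extends Expr
  final case class Lit(value: Int) extends Expr
  final case class UnresolvedFunction(name: String, children: Seq[Expr], isDistinct: Boolean) extends Expr
  // Plays the role of Count(...).toAggregateExpression(): built directly, no registry lookup.
  final case class CountAgg(child: Expr) extends Expr

  def strToExpr(fn: String): Expr => Expr = { input =>
    fn.toLowerCase match {
      // aliases handled before the registry is ever consulted
      case "avg" | "average" | "mean" => UnresolvedFunction("avg", input :: Nil, isDistinct = false)
      case "stddev" | "std"           => UnresolvedFunction("stddev", input :: Nil, isDistinct = false)
      // count is built directly; count(*) is rewritten to count(1)
      case "count" | "size" =>
        input match {
          case Star  => CountAgg(Lit(1))
          case other => CountAgg(other)
        }
      // everything else (max, min, sum, ...) is left for the Analyzer
      case name => UnresolvedFunction(name, input :: Nil, isDistinct = false)
    }
  }

  // Toy GroupedData.agg: one expression per (column name, function name) pair.
  // In Spark df("*") yields a star expression; the toy models that with Star.
  def agg(exprs: Map[String, String]): Seq[Expr] =
    exprs.map { case (colName, fn) =>
      strToExpr(fn)(if (colName == "*") Star else Col(colName))
    }.toSeq
}
```

One consequence of taking a Scala Map: keys are unique, so Map("height" -> "max", "height" -> "min") keeps only the last entry. Asking for two aggregates on the same column needs a different agg overload.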
Now let's look at toDF:
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
  val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
    groupingExprs ++ aggExprs
  } else {
    aggExprs
  }

  val aliasedAgg = aggregates.map(alias)

  groupType match {
    case GroupedData.GroupByType =>
      DataFrame(
        df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
    case GroupedData.RollupType =>
      DataFrame(
        df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
    case GroupedData.CubeType =>
      DataFrame(
        df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
    case GroupedData.PivotType(pivotCol, values) =>
      val aliasedGrps = groupingExprs.map(alias)
      DataFrame(
        df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
  }
}
In groupBy we saw that the group type passed in is GroupedData.GroupByType, so this branch runs:
case GroupedData.GroupByType =>
DataFrame(
df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
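The shape of toDF can be sketched with a toy plan type (hypothetical classes with strings as expressions, not Catalyst's; the Pivot branch and aliasing are left out). Grouping columns are prepended to the aggregate list when retainGroupColumns is on (Spark's `spark.sql.retainGroupColumns`, true by default as far as I know), and the group type picks the plan node. For df.agg(...) the grouping list is empty, so the result is just an Aggregate over the child plan holding our aggregate expressions.

```scala
object ToDFModel {
  // Hypothetical stand-ins; expressions are plain strings here.
  sealed trait GroupType
  case object GroupByType extends GroupType
  case object RollupType extends GroupType
  case object CubeType extends GroupType

  sealed trait Plan
  final case class Relation(name: String) extends Plan
  final case class Aggregate(grouping: Seq[String], aggregates: Seq[String], child: Plan) extends Plan
  final case class Rollup(grouping: Seq[String], child: Plan, aggregates: Seq[String]) extends Plan
  final case class Cube(grouping: Seq[String], child: Plan, aggregates: Seq[String]) extends Plan

  def toDF(
      grouping: Seq[String],
      aggExprs: Seq[String],
      groupType: GroupType,
      child: Plan,
      retainGroupColumns: Boolean = true): Plan = {
    // keep the grouping columns in the output, like dataFrameRetainGroupColumns
    val aggregates = if (retainGroupColumns) grouping ++ aggExprs else aggExprs
    // the group type decides which logical plan node wraps the child
    groupType match {
      case GroupByType => Aggregate(grouping, aggregates, child)
      case RollupType  => Rollup(grouping, child, aggregates)
      case CubeType    => Cube(grouping, child, aggregates)
    }
  }
}
```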
Aggregate is not a method but a case class extending UnaryNode, i.e. a LogicalPlan node with a single child:
case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalPlan)
  extends UnaryNode {

  override lazy val resolved: Boolean = {
    val hasWindowExpressions = aggregateExpressions.exists ( _.collect {
        case window: WindowExpression => window
      }.nonEmpty
    )

    !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions
  }

  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
}
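The resolved flag on Aggregate boils down to three checks: every expression (grouping and aggregate) is resolved, the child plan is resolved, and no aggregate expression contains a window expression. A toy version over a hypothetical three-node expression type (`Attr`, `Unresolved`, `WindowExpr` are stand-ins):

```scala
object AggregateResolvedModel {
  // Hypothetical mini expression tree with only what the check needs.
  sealed trait Expr {
    def children: Seq[Expr]
    def resolved: Boolean
  }
  final case class Attr(name: String) extends Expr {
    def children: Seq[Expr] = Nil
    def resolved: Boolean = true
  }
  final case class Unresolved(name: String, children: Seq[Expr]) extends Expr {
    def resolved: Boolean = false
  }
  final case class WindowExpr(child: Expr) extends Expr {
    def children: Seq[Expr] = Seq(child)
    def resolved: Boolean = child.resolved
  }

  // Like collect(...).nonEmpty: does any node in the tree match p?
  private def containsNode(e: Expr)(p: Expr => Boolean): Boolean =
    p(e) || e.children.exists(c => containsNode(c)(p))

  def aggregateResolved(grouping: Seq[Expr], aggregates: Seq[Expr], childResolved: Boolean): Boolean = {
    val hasWindowExpressions = aggregates.exists(a => containsNode(a)(_.isInstanceOf[WindowExpr]))
    (grouping ++ aggregates).forall(_.resolved) && childResolved && !hasWindowExpressions
  }
}
```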
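This logical plan holds the expressions we passed in, such as height -> max. After these steps a DataFrame has been created. As described in my previous article, a DataFrame's logical plan then goes through the following steps until it becomes an executable physical plan (resolving the UnresolvedFunction happens along the way):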
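1. SqlParser turns SQL text into an unresolved logical plan (the DataFrame API skips this step, since it builds the logical plan directly, as we saw above).
2. The Analyzer turns it into a resolved logical plan.
3. The Optimizer turns it into an optimized logical plan.
4. The SparkPlanner turns it into a physical plan (SparkPlan).
5. prepareForExecution turns that into the executable physical plan (executedPlan).
6. toRdd and similar methods call execute on the executedPlan, which invokes doExecute down the tree.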
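In Spark's QueryExecution these stages are held as lazy vals such as analyzed, optimizedPlan, sparkPlan, executedPlan and toRdd, each defined in terms of the previous one, so a stage only runs once something needs it and its result is then cached. A toy model of that chaining, with strings standing in for plans; `ToyQueryExecution`, `stage` and the `trace` buffer are hypothetical names used only to record the order:

```scala
import scala.collection.mutable.ListBuffer

// Toy stand-in for QueryExecution: each stage is a lazy val built on the previous one.
final class ToyQueryExecution(logical: String) {
  // Records which stages actually ran, in order.
  val trace: ListBuffer[String] = ListBuffer.empty[String]

  // The argument (the previous stage) is evaluated before this stage is recorded.
  private def stage(name: String, input: String): String = {
    trace += name
    s"$name($input)"
  }

  lazy val analyzed: String      = stage("analyze", logical)
  lazy val optimizedPlan: String = stage("optimize", analyzed)
  lazy val sparkPlan: String     = stage("plan", optimizedPlan)
  lazy val executedPlan: String  = stage("prepare", sparkPlan)

  // Like executedPlan.execute(): the action that pulls the whole chain.
  def execute(): String = stage("execute", executedPlan)
}
```

Calling execute() a second time only re-runs the last step; the earlier lazy vals are already computed.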
So how does an UnresolvedFunction get tied to expressions like max, min and avg? That happens in the Analyzer. When SQLContext creates its Analyzer, it passes in a function registry:
protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()

protected[sql] lazy val analyzer: Analyzer =
  new Analyzer(catalog, functionRegistry, conf) {
    override val extendedResolutionRules =
      ExtractPythonUDFs ::
      PreInsertCastAndRename ::
      (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)

    override val extendedCheckRules = Seq(
      datasources.PreWriteCheck(catalog)
    )
  }
This FunctionRegistry contains all the built-in expressions:
object FunctionRegistry {

  type FunctionBuilder = Seq[Expression] => Expression

  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    // misc non-aggregate functions
    expression[Abs]("abs"), expression[CreateArray]("array"), expression[Coalesce]("coalesce"),
    expression[Explode]("explode"), expression[Greatest]("greatest"), expression[If]("if"),
    expression[IsNaN]("isnan"), expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"),
    expression[Least]("least"), expression[Coalesce]("nvl"), expression[Rand]("rand"),
    expression[Randn]("randn"), expression[CreateStruct]("struct"),
    expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), expression[NaNvl]("nanvl"),

    // math functions
    expression[Acos]("acos"), expression[Asin]("asin"), expression[Atan]("atan"),
    expression[Atan2]("atan2"), expression[Bin]("bin"), expression[Cbrt]("cbrt"),
    expression[Ceil]("ceil"), expression[Ceil]("ceiling"), expression[Cos]("cos"),
    expression[Cosh]("cosh"), expression[Conv]("conv"), expression[EulerNumber]("e"),
    expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"),
    expression[Factorial]("factorial"), expression[Hypot]("hypot"), expression[Hex]("hex"),
    expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"),
    expression[Log1p]("log1p"), expression[Log2]("log2"), expression[UnaryMinus]("negative"),
    expression[Pi]("pi"), expression[Pow]("pow"), expression[Pow]("power"),
    expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"),
    expression[Round]("round"), expression[ShiftLeft]("shiftleft"),
    expression[ShiftRight]("shiftright"), expression[ShiftRightUnsigned]("shiftrightunsigned"),
    expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"),
    expression[Sinh]("sinh"), expression[Tan]("tan"), expression[Tanh]("tanh"),
    expression[ToDegrees]("degrees"), expression[ToRadians]("radians"),

    // aggregate functions
    expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"),
    expression[Corr]("corr"), expression[Count]("count"), expression[First]("first"),
    expression[First]("first_value"), expression[Last]("last"), expression[Last]("last_value"),
    expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"),
    expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"),
    expression[StddevSamp]("stddev_samp"), expression[Sum]("sum"),
    expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"), expression[Skewness]("skewness"),
    expression[Kurtosis]("kurtosis"),

    // string functions
    expression[Ascii]("ascii"), expression[Base64]("base64"), expression[Concat]("concat"),
    expression[ConcatWs]("concat_ws"), expression[Encode]("encode"), expression[Decode]("decode"),
    expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"),
    expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"),
    expression[JsonTuple]("json_tuple"), expression[Lower]("lcase"), expression[Lower]("lower"),
    expression[Length]("length"), expression[Levenshtein]("levenshtein"),
    expression[RegExpExtract]("regexp_extract"), expression[RegExpReplace]("regexp_replace"),
    expression[StringInstr]("instr"), expression[StringLocate]("locate"),
    expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"),
    expression[FormatString]("format_string"), expression[FormatString]("printf"),
    expression[StringRPad]("rpad"), expression[StringRepeat]("repeat"),
    expression[StringReverse]("reverse"), expression[StringTrimRight]("rtrim"),
    expression[SoundEx]("soundex"), expression[StringSpace]("space"),
    expression[StringSplit]("split"), expression[Substring]("substr"),
    expression[Substring]("substring"), expression[SubstringIndex]("substring_index"),
    expression[StringTranslate]("translate"), expression[StringTrim]("trim"),
    expression[UnBase64]("unbase64"), expression[Upper]("ucase"), expression[Unhex]("unhex"),
    expression[Upper]("upper"),

    // datetime functions
    expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"),
    expression[CurrentTimestamp]("current_timestamp"), expression[CurrentTimestamp]("now"),
    expression[DateDiff]("datediff"), expression[DateAdd]("date_add"),
    expression[DateFormatClass]("date_format"), expression[DateSub]("date_sub"),
    expression[DayOfMonth]("day"), expression[DayOfYear]("dayofyear"),
    expression[DayOfMonth]("dayofmonth"), expression[FromUnixTime]("from_unixtime"),
    expression[FromUTCTimestamp]("from_utc_timestamp"), expression[Hour]("hour"),
    expression[LastDay]("last_day"), expression[Minute]("minute"), expression[Month]("month"),
    expression[MonthsBetween]("months_between"), expression[NextDay]("next_day"),
    expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"),
    expression[ToUnixTimestamp]("to_unix_timestamp"),
    expression[ToUTCTimestamp]("to_utc_timestamp"), expression[TruncDate]("trunc"),
    expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),

    // collection functions
    expression[Size]("size"), expression[SortArray]("sort_array"),
    expression[ArrayContains]("array_contains"),

    // misc functions
    expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Sha1]("sha"),
    expression[Sha1]("sha1"), expression[Sha2]("sha2"),
    expression[SparkPartitionID]("spark_partition_id"),
    expression[InputFileName]("input_file_name"),
    expression[MonotonicallyIncreasingID]("monotonically_increasing_id")
  )
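Conceptually the registry is a map from a function name to a builder that takes the argument expressions and returns a concrete expression; several names can share one class (avg and mean both map to Average). In Spark the `expression[T]` helper builds these entries by reflection on T's constructors. The sketch below is a toy version with hypothetical stand-in classes and a hand-written `unary` helper instead of reflection; where Spark reports an AnalysisException for an unknown name, the toy simply throws NoSuchElementException.

```scala
object RegistryModel {
  // Hypothetical stand-ins for a few Catalyst expressions.
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class Max(child: Expr) extends Expr
  final case class Min(child: Expr) extends Expr
  final case class Average(child: Expr) extends Expr
  final case class Sqrt(child: Expr) extends Expr

  // Same shape as Spark's FunctionBuilder: argument expressions in, expression out.
  type FunctionBuilder = Seq[Expr] => Expr

  // Hypothetical helper: a builder for one-argument functions that rejects other arities.
  private def unary(build: Expr => Expr): FunctionBuilder = {
    case Seq(child) => build(child)
    case args => throw new IllegalArgumentException(s"expected 1 argument, got ${args.size}")
  }

  val builtin: Map[String, FunctionBuilder] = Map(
    "max"  -> unary(Max(_)),
    "min"  -> unary(Min(_)),
    "avg"  -> unary(Average(_)),
    "mean" -> unary(Average(_)), // two names, one expression class
    "sqrt" -> unary(Sqrt(_))
  )

  def lookupFunction(name: String, children: Seq[Expr]): Expr =
    builtin.get(name) match {
      case Some(builder) => builder(children)
      case None          => throw new NoSuchElementException(s"undefined function $name")
    }
}
```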
When the Analyzer runs its execute method and applies its rules to every node of the plan, one of those rules is ResolveFunctions. These are the batches defined in the Analyzer:
lazy val batches: Seq[Batch] = Seq(
  Batch("Substitution", fixedPoint,
    CTESubstitution,
    WindowsSubstitution),
  Batch("Resolution", fixedPoint,
    ResolveRelations ::
    ResolveReferences ::
    ResolveGroupingAnalytics ::
    ResolvePivot ::
    ResolveUpCast ::
    ResolveSortReferences ::
    ResolveGenerate ::
    ResolveFunctions ::
    ResolveAliases ::
    ExtractWindowExpressions ::
    GlobalAggregates ::
    ResolveAggregateFunctions ::
    HiveTypeCoercion.typeCoercionRules ++
    extendedResolutionRules : _*),
  Batch("Nondeterministic", Once,
    PullOutNondeterministic,
    ComputeCurrentTime),
  Batch("UDF", Once,
    HandleNullInputsForUDF),
  Batch("Cleanup", fixedPoint,
    CleanupAliases)
)
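Each Batch runs its rules in order, either Once or repeatedly until the plan stops changing (fixedPoint, capped at a maximum number of iterations). That is why ResolveFunctions can skip a function whose children are not resolved yet: a later iteration of the same Resolution batch picks it up. A toy rule executor over an arbitrary plan type T; the names Once, FixedPoint and Batch mirror Spark's, but this is a simplified sketch, not Spark's RuleExecutor.

```scala
object RuleExecutorModel {
  // A rule rewrites a plan of type T into a (possibly identical) plan.
  type Rule[T] = T => T

  sealed trait Strategy { def maxIterations: Int }
  case object Once extends Strategy { val maxIterations: Int = 1 }
  final case class FixedPoint(maxIterations: Int) extends Strategy

  final case class Batch[T](name: String, strategy: Strategy, rules: Rule[T]*)

  def execute[T](plan: T, batches: Seq[Batch[T]]): T =
    batches.foldLeft(plan) { (current, batch) =>
      var result = current
      var iteration = 0
      var changed = true
      // Apply all rules of the batch in order; repeat until nothing changes
      // or the strategy's iteration limit is reached.
      while (changed && iteration < batch.strategy.maxIterations) {
        val next = batch.rules.foldLeft(result)((p, rule) => rule(p))
        changed = next != result
        result = next
        iteration += 1
      }
      result
    }
}
```

With a rule that subtracts 10 from anything above 10, a fixed-point batch drives 35 down to 5, while a Once batch stops after a single pass at 25.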
ResolveFunctions is defined like this:
object ResolveFunctions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case q: LogicalPlan =>
      q transformExpressions {
        case u if !u.childrenResolved => u // Skip until children are resolved.
        case u @ UnresolvedFunction(name, children, isDistinct) =>
          withPosition(u) {
            registry.lookupFunction(name, children) match {
              // DISTINCT is not meaningful for a Max or a Min.
              case max: Max if isDistinct =>
                AggregateExpression(max, Complete, isDistinct = false)
              case min: Min if isDistinct =>
                AggregateExpression(min, Complete, isDistinct = false)
              // We get an aggregate function, we need to wrap it in an AggregateExpression.
              case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
              // This function is not an aggregate function, just return the resolved one.
              case other => other
            }
          }
      }
  }
}
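This rule walks every expression of every plan node; for each UnresolvedFunction whose children are already resolved, it looks the function up by name in the registry: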
registry.lookupFunction(name, children) match{
...
}
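The lookup returns the concrete expression. Max or Min called with DISTINCT is wrapped in an AggregateExpression with DISTINCT dropped, since it does not change the result; any other aggregate function is wrapped with its DISTINCT flag kept; a non-aggregate function is returned as it is: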
AggregateExpression(max, Complete, isDistinct = false)
AggregateExpression(min, Complete, isDistinct = false)
AggregateExpression(agg, Complete, isDistinct)
In this way the max and min we passed in are replaced by the expressions from the FunctionRegistry, and the remaining Analyzer rules then finish resolving the aggregate.
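The matching in ResolveFunctions can be replayed on a toy expression type. Everything here is a hypothetical stand-in: `lookupFunction` plays the role of the FunctionRegistry and only knows a few one-argument functions, and the Complete aggregate mode is omitted. Arguments are resolved first, which mirrors the rule's "skip until children are resolved" guard.

```scala
object ResolveFunctionsModel {
  sealed trait Expr
  final case class Col(name: String) extends Expr
  final case class UnresolvedFunction(name: String, children: Seq[Expr], isDistinct: Boolean) extends Expr

  // Aggregate functions vs. ordinary (scalar) functions.
  sealed trait AggregateFunction extends Expr
  final case class Max(child: Expr) extends AggregateFunction
  final case class Min(child: Expr) extends AggregateFunction
  final case class Sum(child: Expr) extends AggregateFunction
  final case class Upper(child: Expr) extends Expr

  // Wrapper put around every resolved aggregate function.
  final case class AggregateExpression(fn: AggregateFunction, isDistinct: Boolean) extends Expr

  // Stand-in for registry.lookupFunction; unknown name or wrong arity fails.
  def lookupFunction(name: String, children: Seq[Expr]): Expr = (name, children) match {
    case ("max", Seq(c))   => Max(c)
    case ("min", Seq(c))   => Min(c)
    case ("sum", Seq(c))   => Sum(c)
    case ("upper", Seq(c)) => Upper(c)
    case _ => throw new NoSuchElementException(s"undefined function $name")
  }

  def resolve(e: Expr): Expr = e match {
    case UnresolvedFunction(name, children, isDistinct) =>
      // resolve arguments first (bottom-up), like the childrenResolved guard
      lookupFunction(name, children.map(resolve)) match {
        // DISTINCT does not change MAX or MIN, so it is dropped
        case max: Max if isDistinct => AggregateExpression(max, isDistinct = false)
        case min: Min if isDistinct => AggregateExpression(min, isDistinct = false)
        // any other aggregate keeps its DISTINCT flag
        case agg: AggregateFunction => AggregateExpression(agg, isDistinct)
        // scalar functions are returned as they are
        case other => other
      }
    case other => other
  }
}
```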
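So the max, min or avg we specify are already part of the DataFrame's logical plan from the moment it is built, and they carry through to its final execution plan. That explains why passing parameters this way gives us these results.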
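In my tests, passing expressions to agg was much faster than computing the same result with RDD operations, presumably because the agg path is planned by Catalyst and executed by Spark SQL's built-in aggregate operators. For now my takeaway is: if agg can produce the result you want, don't compute it with RDDs.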
If anything here is wrong, corrections are welcome.
PS: try passing a function name that is not in the FunctionRegistry: the problem is caught during analysis (see checkAnalysis) and an exception is thrown.