Preface
This article takes a look at Flink Table's AggregateFunction.
Example
import java.util.Iterator;
import org.apache.flink.table.functions.AggregateFunction;

/**
* Accumulator for WeightedAvg.
*/
public static class WeightedAvgAccum {
public long sum = 0;
public int count = 0;
}
/**
* Weighted Average user-defined aggregate function.
*/
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
@Override
public WeightedAvgAccum createAccumulator() {
return new WeightedAvgAccum();
}
@Override
public Long getValue(WeightedAvgAccum acc) {
if (acc.count == 0) {
return 0L;
} else {
return acc.sum / acc.count;
}
}
public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
acc.sum += iValue * iWeight;
acc.count += iWeight;
}
public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
acc.sum -= iValue * iWeight;
acc.count -= iWeight;
}
public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
Iterator<WeightedAvgAccum> iter = it.iterator();
while (iter.hasNext()) {
WeightedAvgAccum a = iter.next();
acc.count += a.count;
acc.sum += a.sum;
}
}
public void resetAccumulator(WeightedAvgAccum acc) {
acc.count = 0;
acc.sum = 0L;
}
}
// register function
BatchTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());
// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
- WeightedAvg extends AggregateFunction and implements the createAccumulator, getValue, accumulate, retract, merge and resetAccumulator methods
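The arithmetic of each method can be checked by hand without a Flink runtime. The sketch below is a plain-Java copy of the example's logic; the class name WeightedAvgSketch is hypothetical, and its static methods only mirror the UDF's instance methods. It does not show how Flink invokes the function, only what each call does to the accumulator.

```java
// Flink-free copy of WeightedAvgAccum / WeightedAvg (hypothetical sketch).
// In a real job the Flink runtime, not user code, makes these calls.
class WeightedAvgSketch {
    static class Accum {
        long sum = 0;
        int count = 0;
    }

    static Accum createAccumulator() {
        return new Accum();
    }

    static void accumulate(Accum acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }

    // Undoes a previous accumulate with the same arguments.
    static void retract(Accum acc, long value, int weight) {
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    // Folds other partial accumulators into acc.
    static void merge(Accum acc, Iterable<Accum> others) {
        for (Accum a : others) {
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }

    static void resetAccumulator(Accum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }

    // Integer division, matching getValue in the example.
    static long getValue(Accum acc) {
        return acc.count == 0 ? 0L : acc.sum / acc.count;
    }

    public static void main(String[] args) {
        Accum acc = createAccumulator();
        accumulate(acc, 10, 1);
        accumulate(acc, 20, 3);
        System.out.println(getValue(acc)); // 17 (70 / 4, integer division)
        retract(acc, 20, 3);
        System.out.println(getValue(acc)); // 10
    }
}
```

For inputs (10, weight 1) and (20, weight 3), sum is 10 + 60 = 70 and count is 4, so getValue returns 17; the fractional part is dropped because both fields are integral.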
AggregateFunction
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala
abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
/**
* Creates and init the Accumulator for this [[AggregateFunction]].
*
* @return the accumulator with the initial value
*/
def createAccumulator(): ACC
/**
* Called every time when an aggregation result should be materialized.
* The returned value could be either an early and incomplete result
* (periodically emitted as data arrive) or the final result of the
* aggregation.
*
* @param accumulator the accumulator which contains the current
* aggregated results
* @return the aggregation result
*/
def getValue(accumulator: ACC): T
/**
* Returns true if this AggregateFunction can only be applied in an OVER window.
*
* @return true if the AggregateFunction requires an OVER window, false otherwise.
*/
def requiresOver: Boolean = false
/**
* Returns the TypeInformation of the AggregateFunction's result.
*
* @return The TypeInformation of the AggregateFunction's result or null if the result type
* should be automatically inferred.
*/
def getResultType: TypeInformation[T] = null
/**
* Returns the TypeInformation of the AggregateFunction's accumulator.
*
* @return The TypeInformation of the AggregateFunction's accumulator or null if the
* accumulator type should be automatically inferred.
*/
def getAccumulatorType: TypeInformation[ACC] = null
}
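- AggregateFunction extends UserDefinedFunction and has two type parameters: T, the type of the aggregation result, and ACC, the type of the accumulator. It declares createAccumulator, getValue, requiresOver, getResultType and getAccumulatorType; subclasses must implement createAccumulator and getValue, while the others have default implementations
- The accumulate method is not declared here, but every subclass must define and implement it: it takes the accumulator (ACC) followed by the input values and returns void. retract, merge and resetAccumulator are not declared either and are optional; a subclass defines them when the operations it is used in require them
- Bounded OVER window aggregations on DataStream require retract, which takes the accumulator followed by the input values and returns void. Session window grouping aggregations on DataStream and grouping aggregations on DataSet require merge, which takes two parameters, ACC and java.lang.Iterable<ACC>, and returns void. Grouping aggregations on DataSet also require resetAccumulator, which takes an ACC parameter and returns void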
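Because accumulate and the optional methods are not declared on AggregateFunction, they can only be found by name on the subclass. The hypothetical ContractCheck below is not Flink's validation code; it uses reflection to report which of the methods listed above a class lacks for a given kind of aggregation. The operation labels ("bounded-over", "session-window", "dataset-grouping") and the SumOnly class are made up for this sketch.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Hypothetical check, NOT Flink's code. The contract methods are not declared
// on the base class, so they are looked up by name via reflection.
class ContractCheck {
    // True if the class has a public method with this name returning void.
    static boolean hasVoidMethod(Class<?> clazz, String name) {
        for (Method m : clazz.getMethods()) {
            if (m.getName().equals(name) && m.getReturnType() == void.class) {
                return true;
            }
        }
        return false;
    }

    // Methods the operation needs (per the rules above) that the class lacks.
    static List<String> missingFor(Class<?> clazz, String operation) {
        List<String> required = new ArrayList<>();
        required.add("accumulate"); // needed by every aggregation
        switch (operation) {
            case "bounded-over":
                required.add("retract");
                break;
            case "session-window":
                required.add("merge");
                break;
            case "dataset-grouping":
                required.add("merge");
                required.add("resetAccumulator");
                break;
            default:
                break;
        }
        List<String> missing = new ArrayList<>();
        for (String name : required) {
            if (!hasVoidMethod(clazz, name)) {
                missing.add(name);
            }
        }
        return missing;
    }
}

// Hypothetical UDF that only implements the mandatory accumulate method.
class SumOnly {
    public void accumulate(long[] acc, long value) {
        acc[0] += value;
    }
}
```

With these definitions, missingFor(SumOnly.class, "bounded-over") yields [retract], because only accumulate is present.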
DataSetPreAggFunction
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
extends AbstractRichFunction
with GroupCombineFunction[Row, Row]
with MapPartitionFunction[Row, Row]
with Compiler[GeneratedAggregations]
with Logging {
private var output: Row = _
private var accumulators: Row = _
private var function: GeneratedAggregations = _
override def open(config: Configuration) {
LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
s"Code:\n$genAggregations.code")
val clazz = compile(
getRuntimeContext.getUserCodeClassLoader,
genAggregations.name,
genAggregations.code)
LOG.debug("Instantiating AggregateHelper.")
function = clazz.newInstance()
output = function.createOutputRow()
accumulators = function.createAccumulators()
}
override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {
// reset accumulators
function.resetAccumulator(accumulators)
val iterator = values.iterator()
var record: Row = null
while (iterator.hasNext) {
record = iterator.next()
// accumulate
function.accumulate(accumulators, record)
}
// set group keys and accumulators to output
function.setAggregationResults(accumulators, output)
function.setForwardedFields(record, output)
out.collect(output)
}
override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
combine(values, out)
}
}
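- DataSetPreAggFunction's combine method resets the accumulators and then calls function.accumulate(accumulators, record) for every record of the group. Here accumulators is a Row whose fields hold the accumulators (a WeightedAvgAccum in the example), and record is the input Row. function is an instance of a generated class that extends GeneratedAggregations; its source code is carried in genAggregations and compiled in open. genAggregations is produced by AggregateUtil.createDataSetAggregateFunctions, and the generated accumulate calls WeightedAvg's accumulate method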
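The combine loop can be mimicked in plain Java to make the data flow concrete. In this hypothetical sketch (not Flink code) an Object[] stands in for Row, input rows are assumed to be {user, points, level}, and the WeightedAvg logic is inlined instead of being dispatched through a generated class.

```java
import java.util.List;

// Hand-written stand-in for DataSetPreAggFunction.combine, NOT Flink code.
// Object[] plays the role of Row; input rows are assumed to be {user, points, level}.
class PreAggSketch {
    // Same fields as WeightedAvgAccum.
    static class Acc {
        long sum;
        int count;

        Acc copy() {
            Acc c = new Acc();
            c.sum = sum;
            c.count = count;
            return c;
        }
    }

    // Called once per group with a reused accumulator: reset it, accumulate
    // every record, then emit {groupKey, accumulator}. The accumulator itself
    // (a partial result) is emitted, not the final average.
    static Object[] combine(List<Object[]> group, Acc acc) {
        acc.sum = 0;                       // ~ function.resetAccumulator
        acc.count = 0;
        Object[] record = null;
        for (Object[] r : group) {
            record = r;
            long points = (Long) r[1];
            int level = (Integer) r[2];
            acc.sum += points * level;     // ~ function.accumulate -> WeightedAvg.accumulate
            acc.count += level;
        }
        Object key = record == null ? null : record[0]; // ~ setForwardedFields
        // Copy so that reusing acc for the next group cannot change this row.
        return new Object[] {key, acc.copy()};
    }
}
```

Emitting the accumulator rather than the average is what makes a later merge possible: two partial {sum, count} pairs can be added, whereas two averages cannot be combined without their weights.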
GeneratedAggregations
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala
abstract class GeneratedAggregations extends Function {
/**
* Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
* It can be used for initialization work. By default, this method does nothing.
*
* @param ctx The runtime context.
*/
def open(ctx: RuntimeContext)
/**
* Sets the results of the aggregations (partial or final) to the output row.
* Final results are computed with the aggregation function.
* Partial results are the accumulators themselves.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
* @param output output results collected in a row
*/
def setAggregationResults(accumulators: Row, output: Row)
/**
* Copies forwarded fields, such as grouping keys, from input row to output row.
*
* @param input input values bundled in a row
* @param output output results collected in a row
*/
def setForwardedFields(input: Row, output: Row)
/**
* Accumulates the input values to the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
* @param input input values bundled in a row
*/
def accumulate(accumulators: Row, input: Row)
/**
* Retracts the input values from the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
* @param input input values bundled in a row
*/
def retract(accumulators: Row, input: Row)
/**
* Initializes the accumulators and save them to a accumulators row.
*
* @return a row of accumulators which contains the aggregated results
*/
def createAccumulators(): Row
/**
* Creates an output row object with the correct arity.
*
* @return an output row object with the correct arity.
*/
def createOutputRow(): Row
/**
* Merges two rows of accumulators into one row.
*
* @param a First row of accumulators
* @param b The other row of accumulators
* @return A row with the merged accumulators of both input rows.
*/
def mergeAccumulatorsPair(a: Row, b: Row): Row
/**
* Resets all the accumulators.
*
* @param accumulators the accumulators (saved in a row) which contains the current
* aggregated results
*/
def resetAccumulator(accumulators: Row)
/**
* Cleanup for the accumulators.
*/
def cleanup()
/**
* Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
* It can be used for clean up work. By default, this method does nothing.
*/
def close()
}
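- GeneratedAggregations declares methods such as accumulate(accumulators: Row, input: Row), retract, mergeAccumulatorsPair and resetAccumulator(accumulators: Row); all of them operate on Rows rather than on the user's accumulator and argument types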
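The difference between partial and final results in setAggregationResults, and the role of mergeAccumulatorsPair, can be illustrated with a hand-written class of the same shape as GeneratedAggregations. WeightedAvgAggregations is a hypothetical name, Object[] replaces Row, only a subset of the methods is kept, and the field layout (input {user, points, level}, accumulators {Acc}, output {user, result}) is an assumption; generated code would derive these positions from the plan.

```java
// Hand-written stand-in for a subset of GeneratedAggregations, NOT Flink code.
// Object[] replaces Row; the field layout below is a hypothetical example.
class WeightedAvgAggregations {
    static class Acc {
        long sum;
        int count;
    }

    // true: emit accumulators (pre-aggregation); false: emit final values.
    private final boolean partialResults;

    WeightedAvgAggregations(boolean partialResults) {
        this.partialResults = partialResults;
    }

    Object[] createAccumulators() {
        return new Object[] {new Acc()};
    }

    // Input row layout assumed to be {user, points, level}.
    void accumulate(Object[] accumulators, Object[] input) {
        Acc acc = (Acc) accumulators[0];
        long points = (Long) input[1];
        int level = (Integer) input[2];
        acc.sum += points * level;
        acc.count += level;
    }

    // Merges b's accumulators into a (as WeightedAvg.merge would) and returns a.
    Object[] mergeAccumulatorsPair(Object[] a, Object[] b) {
        Acc x = (Acc) a[0];
        Acc y = (Acc) b[0];
        x.sum += y.sum;
        x.count += y.count;
        return a;
    }

    void resetAccumulator(Object[] accumulators) {
        Acc acc = (Acc) accumulators[0];
        acc.sum = 0;
        acc.count = 0;
    }

    // Copies the grouping key (index 0) from input to output.
    void setForwardedFields(Object[] input, Object[] output) {
        output[0] = input[0];
    }

    // Partial: write the accumulator itself. Final: write getValue(acc).
    void setAggregationResults(Object[] accumulators, Object[] output) {
        Acc acc = (Acc) accumulators[0];
        if (partialResults) {
            output[1] = acc;
        } else {
            output[1] = acc.count == 0 ? 0L : acc.sum / acc.count;
        }
    }
}
```

Merging the accumulators of (10, weight 1) and (20, weight 3) and finalizing writes 17 to the result field, the same value a single accumulator fed both rows would give.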
AggregateUtil
flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
object AggregateUtil {
type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
type JavaList[T] = java.util.List[T]
//......
/**
* Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
* If all aggregation functions support pre-aggregation, a pre-aggregation function and the
* respective output type are generated as well.
*/
private[flink] def createDataSetAggregateFunctions(
generator: AggregationCodeGenerator,
namedAggregates: Seq[CalcitePair[AggregateCall, String]],
inputType: RelDataType,
inputFieldTypeInfo: Seq[TypeInformation[_]],
outputType: RelDataType,
groupings: Array[Int],
tableConfig: TableConfig): (
Option[DataSetPreAggFunction],
Option[TypeInformation[Row]],
Either[DataSetAggFunction, DataSetFinalAggFunction]) = {
val needRetract = false
val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
namedAggregates.map(_.getKey),
inputType,
needRetract,
tableConfig)
val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
namedAggregates,
groupings,
inputType,
outputType
)
val aggOutFields = aggOutMapping.map(_._1)
if (doAllSupportPartialMerge(aggregates)) {
// compute preaggregation type
val preAggFieldTypes = gkeyOutMapping.map(_._2)
.map(inputType.getFieldList.get(_).getType)
.map(FlinkTypeFactory.toTypeInfo) ++ accTypes
val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)
val genPreAggFunction = generator.generateAggregations(
"DataSetAggregatePrepareMapHelper",
inputFieldTypeInfo,
aggregates,
aggInFields,
aggregates.indices.map(_ + groupings.length).toArray,
isDistinctAggs,
isStateBackedDataViews = false,
partialResults = true,
groupings,
None,
groupings.length + aggregates.length,
needRetract,
needMerge = false,
needReset = true,
None
)
// compute mapping of forwarded grouping keys
val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
val gkeyOutFields = gkeyOutMapping.map(_._1)
val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
mapping
} else {
new Array[Int](0)
}
val genFinalAggFunction = generator.generateAggregations(
"DataSetAggregateFinalHelper",
inputFieldTypeInfo,
aggregates,
aggInFields,
aggOutFields,
isDistinctAggs,
isStateBackedDataViews = false,
partialResults = false,
gkeyMapping,
Some(aggregates.indices.map(_ + groupings.length).toArray),
outputType.getFieldCount,
needRetract,
needMerge = true,
needReset = true,
None
)
(
Some(new DataSetPreAggFunction(genPreAggFunction)),
Some(preAggRowType),
Right(new DataSetFinalAggFunction(genFinalAggFunction))
)
}
else {
val genFunction = generator.generateAggregations(
"DataSetAggregateHelper",
inputFieldTypeInfo,
aggregates,
aggInFields,
aggOutFields,
isDistinctAggs,
isStateBackedDataViews = false,
partialResults = false,
groupings,
None,
outputType.getFieldCount,
needRetract,
needMerge = false,
needReset = true,
None
)
(
None,
None,
Left(new DataSetAggFunction(genFunction))
)
}
}
//......
}
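- AggregateUtil.createDataSetAggregateFunctions generates GeneratedAggregationsFunction instances and uses them to create either a DataSetPreAggFunction plus a DataSetFinalAggFunction (when doAllSupportPartialMerge(aggregates) is true, i.e. every aggregation function supports pre-aggregation) or a single DataSetAggFunction otherwise. Code is generated because the parameters of user-defined methods such as accumulate vary from function to function, whereas Flink's runtime calls the fixed accumulate(accumulators: Row, input: Row) declared by GeneratedAggregations. The generated code bridges the two: inside accumulate(accumulators: Row, input: Row) it converts the Row fields into the arguments the user-defined accumulate expects and then calls it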
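To show what the generated adapter has to do, the sketch below implements the fixed accumulate(accumulators, input) shape on top of a user accumulate with an arbitrary signature. It is an illustration only: Flink generates source code specialized for each function and compiles it (the compile call in DataSetPreAggFunction.open), whereas this sketch uses reflection. AccumulateAdapter and WAvg are hypothetical names, Object[] stands in for Row, and aggInFields is named after the variable in the code above but is assumed here to be simply the input-row indexes of one function's arguments.

```java
import java.lang.reflect.Method;

// Illustrative adapter, NOT the code Flink generates: Flink compiles source
// specialized to each UDF, while this sketch finds accumulate by reflection.
// Object[] stands in for Row.
class AccumulateAdapter {
    private final Object udf;
    private final Method accumulate;
    // Input-row positions of the UDF's arguments (hypothetical, one function).
    private final int[] aggInFields;

    AccumulateAdapter(Object udf, int[] aggInFields) {
        this.udf = udf;
        this.aggInFields = aggInFields;
        Method found = null;
        for (Method m : udf.getClass().getMethods()) {
            // accumulate(acc, arg1, ..., argN): one accumulator plus N inputs
            if (m.getName().equals("accumulate")
                    && m.getParameterCount() == aggInFields.length + 1) {
                found = m;
                break;
            }
        }
        if (found == null) {
            throw new IllegalArgumentException(
                "no accumulate method with " + aggInFields.length + " inputs");
        }
        this.accumulate = found;
    }

    // Fixed shape, like GeneratedAggregations.accumulate(accumulators, input).
    void accumulate(Object[] accumulators, Object[] input) {
        Object[] args = new Object[aggInFields.length + 1];
        args[0] = accumulators[0];                 // the UDF's accumulator
        for (int i = 0; i < aggInFields.length; i++) {
            args[i + 1] = input[aggInFields[i]];   // unpack Row fields
        }
        try {
            // invoke unboxes Long/Integer for long/int parameters
            accumulate.invoke(udf, args);
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("accumulate call failed", e);
        }
    }
}

// Hypothetical UDF with the same accumulate signature as WeightedAvg.
class WAvg {
    static class Acc {
        long sum;
        int count;
    }

    public void accumulate(Acc acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }
}
```

Generated source avoids the per-call reflection cost and the boxing shown here, which is one reason to compile code instead of adapting at runtime like this sketch does.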
Summary
- AggregateFunction extends UserDefinedFunction and has two type parameters: T for the aggregation result and ACC for the accumulator. It declares createAccumulator, getValue, requiresOver, getResultType and getAccumulatorType, of which subclasses must implement createAccumulator and getValue. accumulate is not declared but must be defined and implemented by every subclass; it takes the accumulator followed by the input values and returns void. retract, merge and resetAccumulator are optional and defined as needed: bounded OVER window aggregations on DataStream require retract (accumulator plus input values, returns void); session window grouping aggregations on DataStream and grouping aggregations on DataSet require merge (ACC and java.lang.Iterable<ACC>, returns void); grouping aggregations on DataSet require resetAccumulator (ACC, returns void)
- DataSetPreAggFunction's combine method calls function.accumulate(accumulators, record), where accumulators is a Row holding the accumulators (a WeightedAvgAccum in the example) and record is the input Row. function is a generated subclass of GeneratedAggregations whose code is carried in genAggregations, produced by AggregateUtil.createDataSetAggregateFunctions; it calls WeightedAvg's accumulate method. GeneratedAggregations declares methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row)
- AggregateUtil.createDataSetAggregateFunctions generates GeneratedAggregationsFunction and then creates a DataSetPreAggFunction or a DataSetAggFunction. Code generation is needed because user-defined methods such as accumulate have function-specific parameters, while Flink calls the fixed accumulate(accumulators: Row, input: Row) from GeneratedAggregations; the generated accumulate converts the Row into the user method's arguments and then calls it
doc
- Aggregation Functions
- A Look at Flink Table's Distinct Aggregation