Flink 的 AggregateFunction 是一个基于中间计算结果（累加器状态）进行增量计算的函数。由于采用增量（迭代）计算方式，在窗口处理过程中不需要缓存整个窗口的数据，因此执行效率比较高。
WindowedStream 的 aggregate 方法会将给定的聚合函数应用于每个键的每个窗口：每到达一个元素就调用一次聚合函数，以增量方式聚合值，并将每个键、每个窗口的状态保存在一个累加器中。
/**
 * Applies the given incremental aggregation to every key and window of this windowed stream.
 *
 * Scala-side wrapper: the function is first passed through `clean`, then handed, together with
 * the accumulator and result `TypeInformation` taken from the context bounds, to the underlying
 * Java stream's `aggregate`; the Java result is wrapped back into a Scala `DataStream`.
 *
 * @tparam ACC accumulator (intermediate aggregate state) type
 * @tparam R   type of the value produced for each key/window
 * @param aggregateFunction the incremental aggregation logic
 * @return the stream of aggregated results
 */
def aggregate[ACC: TypeInformation, R: TypeInformation](
    aggregateFunction: AggregateFunction[T, ACC, R]): DataStream[R] = {
  val cleanedAggregate = clean(aggregateFunction)
  val javaAggregated = javaStream.aggregate(
    cleanedAggregate,
    implicitly[TypeInformation[ACC]],
    implicitly[TypeInformation[R]])
  asScalaStream(javaAggregated)
}
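参数类型：AggregateFunction 接口，其声明为 `AggregateFunction<IN, ACC, OUT> extends Function, Serializable`。其中 IN 为输入元素类型，ACC 为累加器（中间聚合状态）类型，OUT 为聚合结果类型。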
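AggregateFunction 需要重写的方法有：`createAccumulator()`（创建一个新的累加器）、`add(value, accumulator)`（将一个输入元素累加到累加器中并返回累加器）、`getResult(accumulator)`（根据累加器计算最终结果）、`merge(a, b)`（合并两个累加器）。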
示例：从 Socket Source 接收格式为 `userId,spend` 的数据，时间语义采用 ProcessingTime，通过 Flink 的 24 小时滚动时间窗口以及 aggregate 方法，统计每个用户在 24 小时内的平均消费金额。
package org.ourhome.streamapi
import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.scala._
import org.apache.flink.streaming.api.windowing.time.Time
/**
* @Author Do
* @Date 2020/4/24 22:51
*/
object WindowFunctionAggrectionTest {
  import scala.util.Try

  /**
   * Runs the demo job: reads `userId,spend` lines from a local socket, keys them by user id and
   * emits each user's average spend per 24-hour processing-time window.
   *
   * Start a listener first, e.g. `nc -lk 9999`, then type lines such as `123,666`.
   *
   * @param args unused
   */
  def main(args: Array[String]): Unit = {
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment
    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime)
    env.setParallelism(1)

    // Connect to the `nc -lk 9999` listener running on this machine.
    val socketData: DataStream[String] = env.socketTextStream("localhost", 9999)
    socketData.print("input ")

    socketData
      // Each line is parsed once; lines that cannot be parsed are dropped rather than
      // throwing a NumberFormatException / ArrayIndexOutOfBoundsException that fails the job.
      .flatMap(line => parseConsumerMess(line).toList)
      .keyBy(_.userId)
      .timeWindow(Time.hours(24))
      .aggregate(new MyAggregrateFunction)
      .print("output ")

    env.execute()
  }

  /**
   * Parses one input line of the form `userId,spend`.
   *
   * Surrounding whitespace of each field is trimmed; fields beyond the second are ignored.
   *
   * @param line raw text line received from the socket
   * @return the parsed record, or `None` if the line has fewer than two fields or either
   *         field is not a valid number
   */
  private def parseConsumerMess(line: String): Option[ConsumerMess] = {
    val fields = line.split(",")
    if (fields.length < 2) None
    else Try(ConsumerMess(fields(0).trim.toInt, fields(1).trim.toDouble)).toOption
  }

  /** One consumption record: the user's id and the amount spent. */
  case class ConsumerMess(userId: Int, spend: Double)

  /**
   * Incrementally computes the average `spend` of the records in a window.
   *
   * Type parameters of [[AggregateFunction]]:
   *  - input: [[ConsumerMess]], the values being aggregated
   *  - accumulator: `(Int, Double)` = (number of records, running total of `spend`)
   *  - result: `Double`, the average spend
   */
  class MyAggregrateFunction extends AggregateFunction[ConsumerMess, (Int, Double), Double] {

    /** Starts with zero records and a zero total. */
    override def createAccumulator(): (Int, Double) = (0, 0.0)

    /** Counts one more record and adds its spend to the running total. */
    override def add(value: ConsumerMess, accumulator: (Int, Double)): (Int, Double) = {
      val (count, total) = accumulator
      (count + 1, total + value.spend)
    }

    /**
     * Returns `total / count`.
     *
     * The division is a Double division, so an accumulator with a count of 0 yields NaN.
     */
    override def getResult(accumulator: (Int, Double)): Double = {
      val (count, total) = accumulator
      total / count
    }

    /** Combines two partial accumulators by adding their counts and their totals. */
    override def merge(a: (Int, Double), b: (Int, Double)): (Int, Double) = {
      val (countA, totalA) = a
      val (countB, totalB) = b
      (countA + countB, totalA + totalB)
    }
  }
}
nc -lk 9999
123,666
123,456
123,12
123,3
123,46
123,666
input > 123,666
input > 123,456
output > 561.0
input > 123,12
input > 123,3
input > 123,46
input > 123,666
output > 181.75
根据输出可见: