A summary of the stateful operations in Structured Streaming: `mapGroupsWithState` and `flatMapGroupsWithState`.
After each trigger, the given function is applied to every group that has data in that trigger, while the state of each group is maintained.
Let's look at the `mapGroupsWithState` operator first:
// S: state type  U: return type
// func: the function applied to each group.
//   K: key of the current group
//   Iterator[V]: the data for this group in the current batch (V: type of each record in the group)
//   GroupState[S]: the state of this group (it holds the previous aggregation result, whether the state has timed out, and so on)
// timeoutConf: timeout configuration. Three options: GroupStateTimeout.ProcessingTimeTimeout(), GroupStateTimeout.EventTimeTimeout(), GroupStateTimeout.NoTimeout().

// No timeout
def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

// Curried variant with a configurable timeout mode; this is the one most commonly used
def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

// No timeout (Java-friendly variant)
def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U]

// Configurable timeout mode (Java-friendly variant)
def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout): Dataset[U]
Note: looking at the implementation of the `mapGroupsWithState` operator, `mapGroupsWithState` is just a special case of the `flatMapGroupsWithState` operator:

def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
  val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
  Dataset[U](
    sparkSession,
    FlatMapGroupsWithState[K, V, S, U](
      flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
      groupingAttributes,
      dataAttributes,
      OutputMode.Update,
      isMapGroupsWithState = true,
      timeoutConf,
      child = logicalPlan))
}
From this implementation we can see:
A. `mapGroupsWithState` is internally implemented on top of `FlatMapGroupsWithState`, the same logical plan node that backs the `flatMapGroupsWithState` operator.
B. The `func` passed to `mapGroupsWithState` is wrapped into a function that returns an `Iterator` with a single element.
C. The output mode of `mapGroupsWithState` is `Update`.
Accordingly, the only output mode currently supported with `mapGroupsWithState` is `Update`, i.e. `...writeStream.outputMode("update")...`; it cannot be `Complete` or `Append`.
When the `func` of `mapGroupsWithState` is invoked: on every trigger, `func` is called once for each group that has data in that trigger; in addition, if a group has timed out, `func` is called once for that group. A group that has data in the current trigger is never in the timed-out state.
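A minimal sketch of these two invocation paths (the `String` key, `Row` values, and `Long` running count below are assumptions for illustration, not part of the API):

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.GroupState

// Sketch only: a state function for mapGroupsWithState, keyed by String,
// with Row values and a running count (Long) as state.
def pvFunc(key: String, rows: Iterator[Row], state: GroupState[Long]): (String, Long) = {
  if (state.hasTimedOut) {
    // Invoked once for a timed-out group; `rows` is always empty on this path.
    state.remove()
    (key, -1L)                      // sentinel the caller can filter out
  } else {
    // Invoked once per trigger for every group that received data in that trigger.
    val count = state.getOption.getOrElse(0L) + rows.size
    state.update(count)
    (key, count)
  }
}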
Three timeout modes are supported in the timeout configuration: no timeout, processing-time timeout, and event-time timeout.
A. No timeout: `GroupStateTimeout.NoTimeout()`. In this mode state grows without bound.
B. Event-time timeout: `GroupStateTimeout.EventTimeTimeout()`. Timeouts are based on event time; a watermark must have been defined beforehand with `withWatermark`.
C. Processing-time timeout: `GroupStateTimeout.ProcessingTimeTimeout()`. Timeouts are based on processing time; no watermark is required in this mode.
On which data takes part in the state computation and how state gets cleaned up (a short sketch follows this list):
A. Event time: a watermark is defined with `withWatermark` and the timeout is configured as `GroupStateTimeout.EventTimeTimeout()`. The watermark decides whether a record takes part in the state computation (records later than the lateness threshold no longer update the state). Group timeouts are driven by the watermark: set a timeout timestamp with `GroupState.setTimeoutTimestamp`, and once the watermark passes it the function is invoked with `hasTimedOut = true`, where the state can be cleared with `GroupState.remove()`.
B. Processing time: no watermark is set and the timeout is configured as `GroupStateTimeout.ProcessingTimeTimeout()`. All records take part in the state computation, and the group state has to be cleaned up manually: set a timeout with `GroupState.setTimeoutDuration`, and once the group has timed out, clear the state with `GroupState.remove()`.
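A minimal sketch of the two timeout styles inside the state function, under the same illustrative key/row/state types as the `pvFunc` sketch above (the timeout values are arbitrary):

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.GroupState

// Sketch only: setting and handling timeouts inside the state function.
def pvFuncWithTimeout(key: String, rows: Iterator[Row], state: GroupState[Long]): (String, Long) = {
  if (state.hasTimedOut) {
    // In both modes the timed-out state still has to be cleared explicitly.
    state.remove()
    (key, -1L)
  } else {
    state.update(state.getOption.getOrElse(0L) + rows.size)

    // Event-time mode (EventTimeTimeout + withWatermark upstream): the group
    // times out once the watermark passes the timestamp set here.
    state.setTimeoutTimestamp(state.getCurrentWatermarkMs() + 2 * 60 * 1000)

    // Processing-time mode (ProcessingTimeTimeout, no watermark): set a duration instead;
    // the group times out roughly this long after the last trigger in which it received data.
    // state.setTimeoutDuration(2 * 60 * 1000)

    (key, state.get)
  }
}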
Example: count PV per group with `mapGroupsWithState`, maintaining the state manually. Test data:
// eventTime: Beijing time
{"eventTime": "2016-01-01 10:02:00" ,"eventType": "click" ,"userID":"1"}
package com.bigdata.structured.streaming.state

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime, ZoneId}

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, Trigger}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.slf4j.LoggerFactory

/**
  * Author: Wang Pei
  * Summary:
  *   Event-time based: count PV per group with `mapGroupsWithState`, maintaining the state manually.
  */
object MapGroupsWithState {

  lazy val logger = LoggerFactory.getLogger(MapGroupsWithState.getClass)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[3]").appName(this.getClass.getSimpleName.replace("$", "")).getOrCreate()

    import spark.implicits._

    // Register the UDF
    spark.udf.register("timezoneToTimestamp", timezoneToTimestamp _)

    // JSON schema of the Kafka messages
    val jsonSchema =
      """{"type":"struct","fields":[{"name":"eventTime","type":"string","nullable":true},{"name":"eventType","type":"string","nullable":true},{"name":"userID","type":"string","nullable":true}]}"""

    // InputTable
    val inputTable = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "kafka01:9092,kafka02:9092,kafka03:9092")
      .option("subscribe", "test_1")
      .load()

    // ResultTable
    val resultTable = inputTable
      .select(from_json(col("value").cast("string"), DataType.fromJson(jsonSchema)).as("value"))
      .select($"value.*")
      // Add an event-time column
      .withColumn("timestamp", functions.callUDF("timezoneToTimestamp", functions.col("eventTime"), lit("yyyy-MM-dd HH:mm:ss"), lit("GMT+8")))
      .filter($"timestamp".isNotNull && $"eventType".isNotNull && $"userID".isNotNull)
      // Define the watermark with a lateness threshold of 2 minutes
      .withWatermark("timestamp", "2 minutes")
      // groupByKey; key: `minute,userID`
      .groupByKey((row: Row) => {
        val timestamp = row.getAs[Timestamp]("timestamp")
        val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm")
        val currentEventTimeMinute = sdf.format(new Date(timestamp.getTime))
        currentEventTimeMinute + "," + row.getAs[String]("userID")
      })
      // mapGroupsWithState
      .mapGroupsWithState[(String, Long), (String, String, Long)](GroupStateTimeout.EventTimeTimeout())((groupKey: String, currentBatchRows: Iterator[Row], groupState: GroupState[(String, Long)]) => {

        println("Key of the current group: " + groupKey)
        println("Current watermark: " + groupState.getCurrentWatermarkMs())
        println("State exists for the current group: " + groupState.exists)
        println("State has timed out for the current group: " + groupState.hasTimedOut)

        var totalValue = 0L

        // The current group has timed out: clear its state
        if (groupState.hasTimedOut) {
          println("Removing state...")
          groupState.remove()

        // State already exists for the current group: aggregate incrementally
        } else if (groupState.exists) {
          println("Incremental aggregation...")
          // History value: read from the state
          val historyValue = groupState.get._2
          // Current value: computed from the new data of this group
          val currentValue = currentBatchRows.size
          // Total = history + current
          totalValue = historyValue + currentValue
          // Update the state
          val newState = (groupKey, totalValue)
          groupState.update(newState)
          // In event-time mode the timeout must be set explicitly with setTimeoutTimestamp;
          // the group times out once the watermark passes that timestamp, e.g.:
          // groupState.setTimeoutTimestamp(groupState.getCurrentWatermarkMs() + 10 * 1000)
          // In processing-time mode, set a timeout duration instead, so that state does not grow without bound:
          // groupState.setTimeoutDuration(1 * 10 * 1000)

        // No state yet for the current group: initialize it
        } else {
          println("Initializing state...")
          totalValue = currentBatchRows.size
          val initialState = (groupKey, totalValue * 1L)
          groupState.update(initialState)
        }

        if (totalValue != 0) {
          val groupKeyArray = groupKey.split(",")
          (groupKeyArray(0), groupKeyArray(1), totalValue)
        } else {
          null
        }
      }).filter(_ != null).toDF("minute", "userID", "pv")

    // Start the query
    val query = resultTable
      .writeStream
      .format("console")
      .option("truncate", "false")
      .outputMode("update")
      .trigger(Trigger.ProcessingTime("2 seconds"))
      .start()

    query.awaitTermination()
  }

  /**
    * Convert a date-time string in a given time zone to a Timestamp.
    *
    * @param dateTime
    * @param dataTimeFormat
    * @param dataTimeZone
    * @return
    */
  def timezoneToTimestamp(dateTime: String, dataTimeFormat: String, dataTimeZone: String): Timestamp = {
    var output: Timestamp = null
    try {
      if (dateTime != null) {
        val format = DateTimeFormatter.ofPattern(dataTimeFormat)
        val eventTime = LocalDateTime.parse(dateTime, format).atZone(ZoneId.of(dataTimeZone))
        output = new Timestamp(eventTime.toInstant.toEpochMilli)
      }
    } catch {
      case ex: Exception => logger.error("Time conversion failed: " + dateTime, ex)
    }
    output
  }
}
`flatMapGroupsWithState` works the same way: after each trigger, the given function is applied to every group that has data in that trigger, while the state of each group is maintained.
The `flatMapGroupsWithState` operator looks like this:
// S: state type  U: return type
// outputMode: output mode of this operator
// timeoutConf: timeout configuration. Three options: GroupStateTimeout.ProcessingTimeTimeout(), GroupStateTimeout.EventTimeTimeout(), GroupStateTimeout.NoTimeout().
// func: the function applied to each group.
//   K: key of the current group
//   Iterator[V]: the data for this group in the current batch (V: type of each record in the group)
//   GroupState[S]: the state of this group (it holds the previous aggregation result, whether the state has timed out, and so on)

def flatMapGroupsWithState[S: Encoder, U: Encoder](
    outputMode: OutputMode,
    timeoutConf: GroupStateTimeout)(
    func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

// Java-friendly variant
def flatMapGroupsWithState[S, U](
    func: FlatMapGroupsWithStateFunction[K, V, S, U],
    outputMode: OutputMode,
    stateEncoder: Encoder[S],
    outputEncoder: Encoder[U],
    timeoutConf: GroupStateTimeout): Dataset[U]
Note:
The timing of `func` invocations for `flatMapGroupsWithState` is the same as for `mapGroupsWithState`.
`flatMapGroupsWithState` returns an `Iterator` (possibly multiple records per group), whereas `mapGroupsWithState` returns a single record; see the sketch after these notes.
The output mode passed to `flatMapGroupsWithState` must match the mode set via `...writeStream...outputMode()...`, and it can only be `Append` or `Update`, never `Complete`.
The timeout configuration of `flatMapGroupsWithState` is the same as for `mapGroupsWithState`.
The rules for which data takes part in the state computation and for state cleanup are the same as for `mapGroupsWithState`.
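A minimal sketch of that difference in return type, under the same illustrative key/row/state types as before:

import org.apache.spark.sql.Row
import org.apache.spark.sql.streaming.GroupState

// Sketch only: a flatMapGroupsWithState function may emit zero, one, or many records per group.
def flatPvFunc(key: String, rows: Iterator[Row], state: GroupState[Long]): Iterator[(String, Long)] = {
  if (state.hasTimedOut) {
    state.remove()
    Iterator.empty             // emit nothing for a timed-out group
  } else {
    val count = state.getOption.getOrElse(0L) + rows.size
    state.update(count)
    Iterator((key, count))     // could just as well emit several records here
  }
}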
Example: count PV per group with `flatMapGroupsWithState`, maintaining the state manually. Test data:
// eventTime: Beijing time
{"eventTime": "2016-01-01 10:02:00" ,"eventType": "click" ,"userID":"1"}
package com.bigdata.structured.streaming.state

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime, ZoneId}

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, Trigger}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.slf4j.LoggerFactory

import scala.collection.mutable.ArrayBuffer

/**
  * Author: Wang Pei
  * Summary:
  *   Processing-time based: count PV per group with `flatMapGroupsWithState`, maintaining the state manually.
  */
object FlatMapGroupsWithState {

  lazy val logger = LoggerFactory.getLogger(FlatMapGroupsWithState.getClass)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[3]").appName(this.getClass.getSimpleName.replace("$", "")).getOrCreate()

    import spark.implicits._

    // Register the UDF
    spark.udf.register("timezoneToTimestamp", timezoneToTimestamp _)

    // JSON schema of the Kafka messages
    val jsonSchema =
      """{"type":"struct","fields":[{"name":"eventTime","type":"string","nullable":true},{"name":"eventType","type":"string","nullable":true},{"name":"userID","type":"string","nullable":true}]}"""

    // InputTable
    val inputTable = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "kafka01:9092,kafka02:9092,kafka03:9092")
      .option("subscribe", "test_1")
      .load()

    // ResultTable
    val resultTable = inputTable
      .select(from_json(col("value").cast("string"), DataType.fromJson(jsonSchema)).as("value"))
      .select($"value.*")
      // Add an event-time column
      .withColumn("timestamp", functions.callUDF("timezoneToTimestamp", functions.col("eventTime"), lit("yyyy-MM-dd HH:mm:ss"), lit("GMT+8")))
      .filter($"timestamp".isNotNull && $"eventType".isNotNull && $"userID".isNotNull)
      // Processing-time based: no watermark is needed
      // .withWatermark("timestamp", "2 minutes")
      // groupByKey; key: `minute,userID`
      .groupByKey((row: Row) => {
        val timestamp = row.getAs[Timestamp]("timestamp")
        val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm")
        val currentEventTimeMinute = sdf.format(new Date(timestamp.getTime))
        currentEventTimeMinute + "," + row.getAs[String]("userID")
      })
      // flatMapGroupsWithState
      .flatMapGroupsWithState[(String, Long), (String, String, Long)](OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout())((groupKey: String, currentBatchRows: Iterator[Row], groupState: GroupState[(String, Long)]) => {

        println("Key of the current group: " + groupKey)
        println("Current processing time: " + groupState.getCurrentProcessingTimeMs())
        println("State exists for the current group: " + groupState.exists)
        println("State has timed out for the current group: " + groupState.hasTimedOut)

        var totalValue = 0L

        // The current group has timed out: clear its state
        if (groupState.hasTimedOut) {
          println("Removing state...")
          groupState.remove()

        // State already exists for the current group: aggregate incrementally
        } else if (groupState.exists) {
          println("Incremental aggregation...")
          // History value: read from the state
          val historyValue = groupState.get._2
          // Current value: computed from the new data of this group
          val currentValue = currentBatchRows.size
          // Total = history + current
          totalValue = historyValue + currentValue
          // Update the state
          val newState = (groupKey, totalValue)
          groupState.update(newState)
          // Set the state timeout to 10 seconds
          groupState.setTimeoutDuration(10 * 1000)

        // No state yet for the current group: initialize it
        } else {
          println("Initializing state...")
          totalValue = currentBatchRows.size
          val initialState = (groupKey, totalValue * 1L)
          groupState.update(initialState)
          groupState.setTimeoutDuration(10 * 1000)
        }

        val output = ArrayBuffer[(String, String, Long)]()
        if (totalValue != 0) {
          val groupKeyArray = groupKey.split(",")
          output.append((groupKeyArray(0), groupKeyArray(1), totalValue))
        }
        output.iterator
      }).toDF("minute", "userID", "pv")

    // Start the query
    val query = resultTable
      .writeStream
      .format("console")
      .option("truncate", "false")
      .outputMode("update")
      .trigger(Trigger.ProcessingTime("10 seconds"))
      .start()

    query.awaitTermination()
  }

  /**
    * Convert a date-time string in a given time zone to a Timestamp.
    *
    * @param dateTime
    * @param dataTimeFormat
    * @param dataTimeZone
    * @return
    */
  def timezoneToTimestamp(dateTime: String, dataTimeFormat: String, dataTimeZone: String): Timestamp = {
    var output: Timestamp = null
    try {
      if (dateTime != null) {
        val format = DateTimeFormatter.ofPattern(dataTimeFormat)
        val eventTime = LocalDateTime.parse(dateTime, format).atZone(ZoneId.of(dataTimeZone))
        output = new Timestamp(eventTime.toInstant.toEpochMilli)
      }
    } catch {
      case ex: Exception => logger.error("Time conversion failed: " + dateTime, ex)
    }
    output
  }
}