Spark Structured Streaming Stateful Operations: mapGroupsWithState and flatMapGroupsWithState

A summary of the stateful operations in Structured Streaming: `mapGroupsWithState` and `flatMapGroupsWithState`.

mapGroupsWithState

After each trigger, the given function is applied to every group that has data in that trigger, while per-group state is maintained.

First, a look at the `mapGroupsWithState` operator:

// S: state type  U: return type
// func: the function applied to each group. K: the key of the current group; Iterator[V]: the data of the current group in the current batch (V: the type of each record in the group); GroupState[S]: the state of this group (it holds the previous aggregation result, whether the state has timed out, and so on)
// timeoutConf: timeout configuration. There are three options: GroupStateTimeout.ProcessingTimeTimeout(), GroupStateTimeout.EventTimeTimeout(), GroupStateTimeout.NoTimeout().

// No timeout
def mapGroupsWithState[S: Encoder, U: Encoder](func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

// Curried form with a configurable timeout mode; this is the one typically used
def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U]

// No timeout (Java-friendly overload)
def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U]): Dataset[U]

// Configurable timeout mode (Java-friendly overload)
def mapGroupsWithState[S, U](func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout): Dataset[U]
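
As a quick illustration of the curried overload (the one typically used), here is a minimal sketch that keeps a running count of events per user. It is not part of the original example: the `Event` case class, the `runningCount` method, and the per-user key are hypothetical names, and encoders are taken from the Dataset's own SparkSession implicits.

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

case class Event(userID: String)

def runningCount(events: Dataset[Event]): Dataset[(String, Long)] = {
  import events.sparkSession.implicits._
  events
    .groupByKey(_.userID)                   // K = String (the userID)
    .mapGroupsWithState[Long, (String, Long)](GroupStateTimeout.NoTimeout()) {
      (user: String, rows: Iterator[Event], state: GroupState[Long]) =>
        val total = state.getOption.getOrElse(0L) + rows.size  // old state + this batch
        state.update(total)                 // persist the new count for the next trigger
        (user, total)                       // exactly one output record per group
    }
}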

Notes:

  1. Looking at the implementation of the `mapGroupsWithState` operator, `mapGroupsWithState` is a special case of the `flatMapGroupsWithState` operator. Its implementation is as follows:
def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = {
    val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s))
    Dataset[U](
      sparkSession,
      FlatMapGroupsWithState[K, V, S, U](
        flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]],
        groupingAttributes,
        dataAttributes,
        OutputMode.Update,
        isMapGroupsWithState = true,
        timeoutConf,
        child = logicalPlan))
}

From this we can see:

A. `mapGroupsWithState` is in fact implemented by calling `FlatMapGroupsWithState`, the same implementation that backs the `flatMapGroupsWithState` operator.

B. The `func` passed to `mapGroupsWithState` is wrapped into a function that returns an `Iterator` with a single element.

C. The output mode used internally by `mapGroupsWithState` is `Update`.
  2. The output mode for `mapGroupsWithState` can currently only be `Update`, i.e. `...writeStream.outputMode("update")...`. It cannot be `Complete` or `Append`.

  3. When the `func` of `mapGroupsWithState` is called: on each trigger, `func` is called once for every group that has data in that trigger; in addition, if a group times out, `func` is called once for that group. If a group has data in the current trigger, its state will not be in the timed-out state.

  4. Timeout configuration. Three timeout modes are supported: no timeout, event-time timeout, and processing-time timeout.

    A. No timeout: `GroupStateTimeout.NoTimeout()`. With no timeout configured, state grows without bound.

    B. Event-time timeout: `GroupStateTimeout.EventTimeTimeout()`. The timeout is based on event time. A watermark must be specified beforehand with `withWatermark`.

    C. Processing-time timeout: `GroupStateTimeout.ProcessingTimeTimeout()`. The timeout is based on processing time. In processing-time mode no watermark needs to be specified.

  5. On which data participates in the state computation, and how state is cleaned up (a compact sketch of the processing-time pattern follows this list).

    A. With event time, i.e. a watermark is specified with `withWatermark` and the timeout is configured as `GroupStateTimeout.EventTimeTimeout()`: the watermark determines whether data participates in the state computation (data later than the threshold is dropped and does not participate), and expiry of a group's state is driven by the watermark (once the watermark passes the timestamp set with `GroupState.setTimeoutTimestamp`, `func` is invoked for that group with `hasTimedOut = true`).

    B. With processing time, i.e. no watermark is set and the timeout is configured as `GroupStateTimeout.ProcessingTimeTimeout()`: all data participates in the state computation, and group state has to be cleaned up manually (set a timeout with `GroupState.setTimeoutDuration`; once the group has timed out, clear the state with `GroupState.remove()`).
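
A compact sketch of the cleanup pattern from item 5B above, written against a hypothetical `Event(userID: String)` type (as in the earlier sketch) with made-up names: when a group times out its state is removed, otherwise the running count is updated and the TTL is re-armed with `setTimeoutDuration`. The `flatMapGroupsWithState` example later in the article uses the same pattern against a Kafka source.

import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}

val countWithTtl: (String, Iterator[Event], GroupState[Long]) => (String, Long) =
  (user, rows, state) => {
    if (state.hasTimedOut) {
      state.remove()                          // no data arrived within the TTL: drop the state
      (user, 0L)                              // marker record; filter it out downstream if unwanted
    } else {
      val total = state.getOption.getOrElse(0L) + rows.size
      state.update(total)                     // carry the running count forward
      state.setTimeoutDuration("10 seconds")  // re-arm the TTL on every batch that has data
      (user, total)
    }
  }

// It would be passed through the curried overload with a processing-time timeout, e.g.
// events.groupByKey(_.userID)
//   .mapGroupsWithState[Long, (String, Long)](GroupStateTimeout.ProcessingTimeTimeout())(countWithTtl)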

Based on event time, use mapGroupsWithState to count each group's PV, maintaining state manually

Test data

// Test data, as follows:
// eventTime: Beijing time (UTC+8)
 {"eventTime": "2016-01-01 10:02:00" ,"eventType": "click" ,"userID":"1"}

Code implementation

package com.bigdata.structured.streaming.state

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime, ZoneId}

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, Trigger}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.slf4j.LoggerFactory


/**
  * Author: Wang Pei
  * Summary:
  *   Based on event time, use `mapGroupsWithState` to count each group's PV, maintaining state manually
  */
object MapGroupsWithState {

  lazy val logger = LoggerFactory.getLogger(MapGroupsWithState.getClass)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[3]").appName(this.getClass.getSimpleName.replace("$", "")).getOrCreate()
    import spark.implicits._

    // Register the UDF
    spark.udf.register("timezoneToTimestamp", timezoneToTimestamp _)
    // Define the JSON schema of the Kafka messages
    val jsonSchema =
      """{"type":"struct","fields":[{"name":"eventTime","type":"string","nullable":true},{"name":"eventType","type":"string","nullable":true},{"name":"userID","type":"string","nullable":true}]}"""

    // InputTable
    val inputTable = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "kafka01:9092,kafka02:9092,kafka03:9092")
      .option("subscribe", "test_1")
      .load()

    // ResultTable
    val resultTable = inputTable
      .select(from_json(col("value").cast("string"), DataType.fromJson(jsonSchema)).as("value"))
      .select($"value.*")
      // Add an event-time timestamp column
      .withColumn("timestamp", functions.callUDF("timezoneToTimestamp", functions.col("eventTime"), lit("yyyy-MM-dd HH:mm:ss"), lit("GMT+8")))
      .filter($"timestamp".isNotNull && $"eventType".isNotNull && $"userID".isNotNull)
      // Define the watermark, with a lateness threshold of 2 minutes
      .withWatermark("timestamp", "2 minutes")
      // Group with groupByKey; key: `minute,userID`
      .groupByKey((row: Row) => {
      val timestamp = row.getAs[Timestamp]("timestamp")
      val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm")
      val currentEventTimeMinute = sdf.format(new Date(timestamp.getTime))
      currentEventTimeMinute + "," + row.getAs[String]("userID")
    })
      // mapGroupsWithState
      .mapGroupsWithState[(String, Long), (String, String, Long)](GroupStateTimeout.EventTimeTimeout())((groupKey: String, currentBatchRows: Iterator[Row], groupState: GroupState[(String, Long)]) => {

      println("Current group key: " + groupKey)
      println("Current watermark: " + groupState.getCurrentWatermarkMs())
      println("Does state exist for this group: " + groupState.exists)
      println("Has this group's state timed out: " + groupState.hasTimedOut)

      var totalValue = 0L

      // The group's state has timed out: clear it
      if (groupState.hasTimedOut) {
        println("Clearing state...")
        groupState.remove()

      // The group's state already exists: aggregate incrementally
      } else if (groupState.exists) {
        println("Incremental aggregation...")
        // Previous value: read from the state
        val historyValue = groupState.get._2
        // Current value: computed from this group's new data in this batch
        val currentValue = currentBatchRows.size
        // Total = previous + current
        totalValue = historyValue + currentValue

        // Update the state
        val newState = (groupKey, totalValue)
        groupState.update(newState)

        // With event-time timeouts, a timeout timestamp would be set via
        // groupState.setTimeoutTimestamp(...); the group then times out once the
        // watermark passes that timestamp.
        // With processing-time timeouts, set a timeout duration instead so state is
        // cleaned up and does not grow without bound, e.g.
        // groupState.setTimeoutDuration(1 * 10 * 1000)

      // No state yet for this group: initialize it
      } else {
        println("Initializing state...")
        totalValue = currentBatchRows.size
        val initialState = (groupKey, totalValue * 1L)
        groupState.update(initialState)
      }

      if (totalValue != 0) {
        val groupKeyArray = groupKey.split(",")
        (groupKeyArray(0), groupKeyArray(1), totalValue)
      } else {
        null
      }

    }).filter(_ != null).toDF("minute", "userID", "pv")

    // Query Start
    val query = resultTable
      .writeStream
      .format("console")
      .option("truncate", "false")
      .outputMode("update")
      .trigger(Trigger.ProcessingTime("2 seconds"))
      .start()

    query.awaitTermination()

  }

  /**
    * Convert a time string in a given time zone to a Timestamp
    *
    * @param dateTime       the time string, e.g. "2016-01-01 10:02:00"
    * @param dataTimeFormat the pattern of the time string, e.g. "yyyy-MM-dd HH:mm:ss"
    * @param dataTimeZone   the time zone of the time string, e.g. "GMT+8"
    * @return the corresponding Timestamp, or null if parsing fails
    */
  def timezoneToTimestamp(dateTime: String, dataTimeFormat: String, dataTimeZone: String): Timestamp = {
    var output: Timestamp = null
    try {
      if (dateTime != null) {
        val format = DateTimeFormatter.ofPattern(dataTimeFormat)
        val eventTime = LocalDateTime.parse(dateTime, format).atZone(ZoneId.of(dataTimeZone));
        output = new Timestamp(eventTime.toInstant.toEpochMilli)
      }
    } catch {
      case ex: Exception => logger.error("Time conversion failed: " + dateTime, ex)
    }
    output
  }
}

flatMapGroupsWithState

After each trigger, the given function is applied to every group that has data in that trigger, while per-group state is maintained.

First, a look at the `flatMapGroupsWithState` operator:

// S: state type  U: return type
// outputMode: output mode
// timeoutConf: timeout configuration. There are three options: GroupStateTimeout.ProcessingTimeTimeout(), GroupStateTimeout.EventTimeTimeout(), GroupStateTimeout.NoTimeout().
// func: the function applied to each group. K: the key of the current group; Iterator[V]: the data of the current group in the current batch (V: the type of each record in the group); GroupState[S]: the state of this group (it holds the previous aggregation result, whether the state has timed out, and so on)

def flatMapGroupsWithState[S: Encoder, U: Encoder](
       outputMode: OutputMode,
       timeoutConf: GroupStateTimeout)(
      func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U]

def flatMapGroupsWithState[S, U](
      func: FlatMapGroupsWithStateFunction[K, V, S, U],
      outputMode: OutputMode,
      stateEncoder: Encoder[S],
      outputEncoder: Encoder[U],
      timeoutConf: GroupStateTimeout): Dataset[U]
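
As a minimal sketch of the zero-or-more output that `flatMapGroupsWithState` allows (again with a hypothetical `Event` type and a made-up threshold, not taken from the original example): a record is emitted for a group only once its running count has reached 100, and nothing is emitted otherwise.

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class Event(userID: String)

def alerts(events: Dataset[Event]): Dataset[(String, Long)] = {
  import events.sparkSession.implicits._
  events
    .groupByKey(_.userID)
    .flatMapGroupsWithState[Long, (String, Long)](OutputMode.Update(), GroupStateTimeout.NoTimeout()) {
      (user: String, rows: Iterator[Event], state: GroupState[Long]) =>
        val total = state.getOption.getOrElse(0L) + rows.size
        state.update(total)
        // Emit only once the running count has reached the threshold; otherwise emit
        // nothing, which mapGroupsWithState (one record per group) cannot express.
        if (total >= 100) Iterator((user, total)) else Iterator.empty
    }
}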

Notes:

  1. The `func` of `flatMapGroupsWithState` is called at the same times as for `mapGroupsWithState`.

  2. `flatMapGroupsWithState` returns an `Iterator` (zero or more records), whereas `mapGroupsWithState` returns a single record.

  3. The output mode passed to `flatMapGroupsWithState` must match the output mode set via `...writeStream...outputMode()...`, and it can only be `Append` or `Update`, not `Complete`.

  4. The timeout configuration of `flatMapGroupsWithState` is the same as for `mapGroupsWithState`.

  5. `flatMapGroupsWithState` handles data participation in the state computation and state cleanup the same way as `mapGroupsWithState`.

Based on processing time, use flatMapGroupsWithState to count each group's PV, maintaining state manually

Test data

// Test data, as follows:
// eventTime: Beijing time (UTC+8)
 {"eventTime": "2016-01-01 10:02:00" ,"eventType": "click" ,"userID":"1"}

Code implementation

package com.bigdata.structured.streaming.state

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime, ZoneId}

import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, Trigger}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SparkSession, functions}
import org.slf4j.LoggerFactory

import scala.collection.mutable.ArrayBuffer

/**
  * Author: Wang Pei
  * Summary:
  *     Based on processing time, use `flatMapGroupsWithState` to count each group's PV, maintaining state manually
  */
object FlatMapGroupsWithState {

  lazy val logger = LoggerFactory.getLogger(FlatMapGroupsWithState.getClass)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().master("local[3]").appName(this.getClass.getSimpleName.replace("$", "")).getOrCreate()
    import spark.implicits._

    // Register the UDF
    spark.udf.register("timezoneToTimestamp", timezoneToTimestamp _)
    // Define the JSON schema of the Kafka messages
    val jsonSchema =
      """{"type":"struct","fields":[{"name":"eventTime","type":"string","nullable":true},{"name":"eventType","type":"string","nullable":true},{"name":"userID","type":"string","nullable":true}]}"""

    // InputTable
    val inputTable = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "kafka01:9092,kafka02:9092,kafka03:9092")
      .option("subscribe", "test_1")
      .load()

    // ResultTable
    val resultTable = inputTable
      .select(from_json(col("value").cast("string"), DataType.fromJson(jsonSchema)).as("value"))
      .select($"value.*")
      // Add an event-time timestamp column
      .withColumn("timestamp", functions.callUDF("timezoneToTimestamp", functions.col("eventTime"), lit("yyyy-MM-dd HH:mm:ss"), lit("GMT+8")))
      .filter($"timestamp".isNotNull && $"eventType".isNotNull && $"userID".isNotNull)
      // Based on processing time, no watermark is needed
      // .withWatermark("timestamp", "2 minutes")
      // Group with groupByKey; key: `minute,userID`
      .groupByKey((row: Row) => {
      val timestamp = row.getAs[Timestamp]("timestamp")
      val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm")
      val currentEventTimeMinute = sdf.format(new Date(timestamp.getTime))
      currentEventTimeMinute + "," + row.getAs[String]("userID")
    })
      // flatMapGroupsWithState
      .flatMapGroupsWithState[(String, Long), (String, String, Long)](OutputMode.Update(), GroupStateTimeout.ProcessingTimeTimeout())((groupKey: String, currentBatchRows: Iterator[Row], groupState: GroupState[(String, Long)]) => {

      println("Current group key: " + groupKey)
      println("Current processing time: " + groupState.getCurrentProcessingTimeMs())
      println("Does state exist for this group: " + groupState.exists)
      println("Has this group's state timed out: " + groupState.hasTimedOut)

      var totalValue = 0L

      // The group's state has timed out: clear it
      if (groupState.hasTimedOut) {
        println("Clearing state...")
        groupState.remove()

      // The group's state already exists: aggregate incrementally
      } else if (groupState.exists) {
        println("Incremental aggregation...")
        // Previous value: read from the state
        val historyValue = groupState.get._2
        // Current value: computed from this group's new data in this batch
        val currentValue = currentBatchRows.size
        // Total = previous + current
        totalValue = historyValue + currentValue

        // Update the state
        val newState = (groupKey, totalValue)
        groupState.update(newState)

        // Set the state timeout to 10 seconds
        groupState.setTimeoutDuration(10 * 1000)

      // No state yet for this group: initialize it
      } else {
        println("Initializing state...")
        totalValue = currentBatchRows.size
        val initialState = (groupKey, totalValue * 1L)
        groupState.update(initialState)
        groupState.setTimeoutDuration(10 * 1000)
      }

      val output = ArrayBuffer[(String, String, Long)]()
      if (totalValue != 0) {
        val groupKeyArray = groupKey.split(",")
        output.append((groupKeyArray(0), groupKeyArray(1), totalValue))
      }
      output.iterator

    }).toDF("minute", "userID", "pv")

    // Query Start
    val query = resultTable
      .writeStream
      .format("console")
      .option("truncate", "false")
      .outputMode("update")
      .trigger(Trigger.ProcessingTime("10 seconds"))
      .start()

    query.awaitTermination()

  }

  /**
    * Convert a time string in a given time zone to a Timestamp
    *
    * @param dateTime       the time string, e.g. "2016-01-01 10:02:00"
    * @param dataTimeFormat the pattern of the time string, e.g. "yyyy-MM-dd HH:mm:ss"
    * @param dataTimeZone   the time zone of the time string, e.g. "GMT+8"
    * @return the corresponding Timestamp, or null if parsing fails
    */
  def timezoneToTimestamp(dateTime: String, dataTimeFormat: String, dataTimeZone: String): Timestamp = {
    var output: Timestamp = null
    try {
      if (dateTime != null) {
        val format = DateTimeFormatter.ofPattern(dataTimeFormat)
        val eventTime = LocalDateTime.parse(dateTime, format).atZone(ZoneId.of(dataTimeZone));
        output = new Timestamp(eventTime.toInstant.toEpochMilli)
      }
    } catch {
      case ex: Exception => logger.error("Time conversion failed: " + dateTime, ex)
    }
    output
  }
}
