How to use updateStateByKey in Spark Streaming

  1. What updateStateByKey does: 
    It reduces the data in each batch of a DStream by key, then accumulates the per-key results across batches. 
    As new data arrives, it lets you maintain arbitrary state for each key. Using it takes two steps: 
    1) Define the state: it can be of any data type. 
    2) Define the state update function: a function that specifies how to combine the previous state with the new values from the input stream to produce the updated state. 
    Because this is a stateful operation, the RDDs of the current and all historical time slices are accumulated continuously, so the amount of state being computed grows over time (see the sketch below for one way to keep it bounded).
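
    As a minimal sketch (the name pruneFunc is illustrative, not from any API), an update function can both accumulate counts and bound the state: updateStateByKey also invokes the function for keys that received no new values in the current batch, so returning None in that case eliminates the key's state:

      // Keep a running count per key; return None for keys that saw no new
      // values this batch, which removes their state entirely.
      val pruneFunc = (newValues: Seq[Int], state: Option[Int]) =>
        if (newValues.isEmpty) None
        else Some(state.getOrElse(0) + newValues.sum)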


  2. updateStateByKey source code:

    /**
     * Return a new "state" DStream where the state for each key is updated by applying
     * the given function on the previous state of the key and the new values of the key.
     * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
     * @param updateFunc State update function. If this function returns None, then
     *                   corresponding state key-value pair will be eliminated.
     * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
     *                    DStream.
     * @param initialRDD initial state value of each key.
     * @tparam S State type
     */
    def updateStateByKey[S: ClassTag](
        updateFunc: (Seq[V], Option[S]) => Option[S],
        partitioner: Partitioner,
        initialRDD: RDD[(K, S)]
      ): DStream[(K, S)] = {
      val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
        iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
      }
      updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
    }
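
    Note that this overload simply wraps the per-key updateFunc into a function over an Iterator[(K, Seq[V], Option[S])] and delegates to the four-argument overload, passing true for rememberPartitioner and the given initialRDD as the starting state. Simpler overloads exist as well: the NetworkWordCount example below passes only the update function, and Spark falls back to a default HashPartitioner.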
  3. Example code

    • StatefulNetworkWordCount

      import org.apache.log4j.{Level, Logger}
      import org.apache.spark.{HashPartitioner, SparkConf}
      import org.apache.spark.streaming.{Seconds, StreamingContext}
      import org.apache.spark.streaming.StreamingContext._

      object StatefulNetworkWordCount {
        def main(args: Array[String]) {
          if (args.length < 2) {
            System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
            System.exit(1)
          }

          Logger.getLogger("org.apache.spark").setLevel(Level.WARN)

          // Per-key update function: add this batch's counts to the running total
          val updateFunc = (values: Seq[Int], state: Option[Int]) => {
            val currentCount = values.sum
            val previousCount = state.getOrElse(0)
            Some(currentCount + previousCount)
          }

          // Iterator-based wrapper required by the (updateFunc, partitioner,
          // rememberPartitioner, initialRDD) overload used below
          val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
            iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
          }

          // use local[2]: the socket receiver occupies one thread, leaving one for processing
          val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount").setMaster("local[2]")
          // Create the context with a 1 second batch size
          val ssc = new StreamingContext(sparkConf, Seconds(1))
          ssc.checkpoint(".")

          // Initial RDD input to updateStateByKey
          val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

          // Create a ReceiverInputDStream on the target ip:port and count the
          // words in the input stream of \n delimited text (e.g. generated by 'nc')
          val lines = ssc.socketTextStream(args(0), args(1).toInt)
          val words = lines.flatMap(_.split(" "))
          val wordDstream = words.map(x => (x, 1))

          // Update the cumulative count using updateStateByKey;
          // this gives a DStream of state (the cumulative count of each word)
          val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
            new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
          stateDstream.print()
          ssc.start()
          ssc.awaitTermination()
        }
      }
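
      Because of initialRDD, the keys "hello" and "world" start with a count of 1 before any input arrives, so their printed totals are one higher than the number of times they actually appear in the stream. To try the example, start a text source first, e.g. nc -lk 9999 in another terminal, and pass localhost 9999 as the program arguments (host and port here are illustrative).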

• NetworkWordCount

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object NetworkWordCount {
  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: NetworkWordCount  ")
      System.exit(1)
    }


    val sparkConf = new SparkConf().setAppName("NetworkWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(10))
    // a checkpoint directory must be set before updateStateByKey can be used
    ssc.checkpoint("hdfs://master:8020/spark/checkpoint")

    val addFunc = (currValues: Seq[Int], prevValueState: Option[Int]) => {
      // Spark has already grouped this batch's values by key; currValues is the
      // Seq of one key's values, so summing it gives the current batch's count
      val currentCount = currValues.sum
      // the count accumulated so far for this key
      val previousCount = prevValueState.getOrElse(0)
      // return the new accumulated result as an Option[Int]
      Some(currentCount + previousCount)
    }

    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val pairs = words.map(word => (word, 1))

    //val currWordCounts = pairs.reduceByKey(_ + _)
    //currWordCounts.print()

    val totalWordCounts = pairs.updateStateByKey[Int](addFunc)
    totalWordCounts.print()

    ssc.start()
    ssc.awaitTermination()
  }
}
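
Compare the commented-out reduceByKey version with the stateful one: reduceByKey would print only each 10-second batch's word counts, while updateStateByKey prints running totals over all batches since the application started. The checkpoint directory (an HDFS path in this example) is required because Spark periodically persists the state DStream there, keeping its lineage from growing without bound.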
