Spark mapWithState
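
The example below reads words from a Kafka topic via the direct stream API and keeps a running per-word count with mapWithState: each batch prints the counts updated in that batch, followed by a snapshot of the full state.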

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}


object WordCount {
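  // Build a direct Kafka stream for the given topic and return a DStream of the message values.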
  def getKafkaData(ssc: StreamingContext, topicName: String, brokerUrl: String, groupId: String = "defaultAlarm", readFrom: String = "latest"): DStream[String] = {
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> brokerUrl,
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> groupId,
      "auto.offset.reset" -> readFrom,
      "enable.auto.commit" -> (false: java.lang.Boolean)
    )
    val dataSource = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](Array(topicName), kafkaParams)
    ).map(_.value)
    dataSource
  }
  // key = word, value = count: the key and value types of the paired DStream;
  // State[Long] is the user-maintained state; the return value is the mapped output for each record.
  def updateCounterState(word: String, count: Option[Int], state: State[Long]): Option[(String, Long)] = {
    // With a StateSpec timeout set, this function is also called for keys that are timing out
    // (count is empty, isTimingOut is true); a timing-out state must not be updated.
    if (state.isTimingOut()) {
      None
    } else {
      val sum = count.getOrElse(0).toLong + state.getOption.getOrElse(0L)
      state.update(sum)
      Some((word, sum))
    }
  }

  def main(args: Array[String]) {
    val interval = args(0).toInt
    val conf = new SparkConf().setAppName("wordcount")
    val ssc = new StreamingContext(conf, Seconds(interval))
    // mapWithState keeps state across batches, so a checkpoint directory is required;
    // the path below is a placeholder, use a durable location (e.g. HDFS) in production.
    ssc.checkpoint("/tmp/wordcount-checkpoint")
    val topicName = "wordcount"
    val brokerUrl = "your.kafka.host.ip:9092"
    //val dataSource = ssc.socketTextStream("your.server.host.ip", 5560)
    val dataSource = getKafkaData(ssc, topicName, brokerUrl, "wordcount", "earliest")
    val wordStream = dataSource.flatMap(_.split(" ")).map(word => (word, 1))
    // Initial state: (key, state value) pairs used to seed the state store
    val initialRDD = ssc.sparkContext.parallelize(List(("dummy", 100L), ("source", 32L)))
    val stateSpec = StateSpec.function(updateCounterState _)
      .initialState(initialRDD) // start from the seed state above
      .numPartitions(2)         // partitions for the internal state RDDs
      .timeout(Seconds(60))     // evict keys that receive no data for 60 seconds

    val wordCountStateStream = wordStream.mapWithState(stateSpec)
    wordCountStateStream.print()
    // stateSnapshots() exposes the full (key, state) table, not just the keys updated in this batch
    val stateSnapshotStream = wordCountStateStream.stateSnapshots()
    // note: println runs on the executors; in local mode the output appears in the driver console
    stateSnapshotStream.foreachRDD(_.foreach(println))
    ssc.start()
    ssc.awaitTermination()
  }

}
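
For comparison, below is a minimal sketch of the same running count using the older updateStateByKey API (the helper name updateFunc is an assumption, not part of the example above). updateStateByKey recomputes and returns the state of every key on every batch, whereas mapWithState only processes keys that appear in the current batch and needs stateSnapshots() to expose the full table.

// Sketch only: the same running word count with updateStateByKey, reusing the
// wordStream DStream[(String, Int)] built in main; this also requires ssc.checkpoint(...).
def updateFunc(newCounts: Seq[Int], runningCount: Option[Long]): Option[Long] =
  Some(runningCount.getOrElse(0L) + newCounts.sum)

val wordCountStream = wordStream.updateStateByKey[Long](updateFunc _)
wordCountStream.print()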


