Spark Streaming + Kafka + Redis状态管理 top100场景 Exactly Once

  最近面试蚂蚁金服一面的时候,和面试官聊项目问题的时候,发现我这边业务实现的top100场景好像没有实现exactly once语义,我们项目的offset是存储在zk中,然后业务处理完毕后,最后再提交offset更新到zk,这种时候就会出现一个问题就是如果业务处理完毕,数据已经更新到redis中进行了累加,然后offset更新zk没成功宕机了,再次重启的时候就会读取老的offset导致数据重复消费两次。由于我们这里是实时top100,每个批次数据来了需要累加式的更新老的数据,即业务处理不是幂等的,所以这种方式是有问题的(这里如果业务处理是幂等的,最后提交offset其实最终效果来说和exactly once是一样的)。
  对此,某天早上地铁上班时看到公众号推荐的一篇关于分布式事务的实现方案的文章,受到其中介绍的维护一个第三方表的模式来实现分布式事务的启示,我们这里可以直接用乐观锁的思想加上第三个辅助表的形式,来实现我们的Spark Streaming + Kafka +Redis 实现exactly once语义的top100。
  乐观锁的实现不了解的可以自己百度在此不再赘述,我们利用的就是其中的一种实现,给每条记录添加一个版本号,而这个版本号就是和我们的批次相关联起来的,这样来保证每条记录只被消费一次,话不多说,直接上设计图:
Spark Streaming + Kafka + Redis状态管理 top100场景 Exactly Once_第1张图片
代码实现如下:

package main.scala

import com.mmtrix.java.constant.ConfigInfo
import com.mmtrix.java.utils.RedisShardedPool
import com.mmtrix.scala.utils.{SimpleKafkaCluster, SparkStreamUtil, StreamingConfig}
import main.java.MyConstant
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.rdd.RDD

import scala.collection.JavaConverters._
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka.{HasOffsetRanges, OffsetRange}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import redis.clients.jedis.ShardedJedisPipeline

import scala.collection.mutable.ArrayBuffer

class Test {
  def processOne(rdd: RDD[ConsumerRecord[String, String]]): Unit = {
    rdd.foreachPartition(part => {
      val batchRec = ArrayBuffer.empty[ConsumerRecord[String, String]]
      while (part.hasNext) {
        val rec = part.next()
        batchRec.append(rec)
        if (batchRec.length == MyConstant.BATCH_SIZE) {
          // 批量查询更新数据
          // ...
          batchUpd(batchRec)
        }
      }
    })
    // 业务指标处理完毕,更新redis中关于该业务指标的参数
      ...
  }

  def batchUpd(batchRec: ArrayBuffer[ConsumerRecord[String, String]]): Unit = {
    // 批量更新,更新前判断版本号是否
  }

  def process(dStream: InputDStream[ConsumerRecord[String, String]], ssc: StreamingContext)(implicit streamingConfig: StreamingConfig, kc: SimpleKafkaCluster) = {
    val topic = streamingConfig.topic
    val group = streamingConfig.group
    val topicKey = group + topic
    dStream.foreachRDD(rdd => {
      val jedis = RedisShardedPool.getJedis.pipelined()
      val entireSta = jedis.hgetAll(MyConstant.BATCH_STATUS)
      jedis.sync()
      val batchSta = entireSta.get()
      val status = batchSta.get("status")
      val batchCnt = batchSta.get("batch_cnt")
      if (status == "start") { // 读取到的上个批次状态为start,说明上个批次处理异常

        val totalIndexSta = jedis.hgetAll(MyConstant.INDEX_STATUS) // 读取上个批次各个子任务处理状态
        jedis.sync()
        val indexSta = totalIndexSta.get()
        val oneSta = indexSta.get("one_status") // 指标1的状态
        val oneCnt = indexSta.get("one_cnt") // 指标1的版本号
        if (oneSta == "start") {
          //说明指标1没处理完毕任务失败
          processOne(rdd)
          // 数据处理完毕,更新one_status和one_cnt,由于用的是hmset,这里甚至都不需要用事务,因为两个指标是一起更新的。
          jedis.hmset(MyConstant.INDEX_STATUS, Map("one_status" -> "finish", "one_cnt" -> batchCnt).asJava)
        }
        val secSta = indexSta.get("sec_status")
        val secCnt = indexSta.get("sec_cnt")
        // ... 后续处理和指标一完全一致

      } else { //应用正常运行,上一轮任务执行正常结束
        //先初始化当前批次执行状态,批次号+1,状态更新
        jedis.hmset(topicKey,Map("status"->"start","batch_cnt"->(batchCnt.toInt+1).toString).asJava)
        processOne(rdd)
        //processSec(rdd)
        //processThird(rdd) ...
      }
      //执行到这里,说明以上所有业务已经处理完毕,通过事务方式更新辅助中间数据
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      updOffsetAnStatus(offsetRanges,batchCnt,jedis,topicKey)
      jedis.hmset(topicKey,Map("offset"->"").asJava)
    })
  }

  def updOffsetAnStatus(ranges: Array[OffsetRange],batchCnt:String,jedis:ShardedJedisPipeline,topicKey:String): Unit ={
    var map = Map[String,String]("status"->"finish","batch_cnt"->batchCnt)
    for (o <- ranges) {
      val field = o.partition.toString
      val value = o.untilOffset.toString
      map += (field -> value)
    }
    jedis.hmset(topicKey,map.asJava)
  }

  def start(): Unit = {
    def funcToCreateSSC(): StreamingContext = {
      val sparkConf = new org.apache.spark.SparkConf().setAppName(ConfigInfo.sparkJobNameConfig)
      //sparkConf.set(...)

      implicit val streamingConfig = new StreamingConfig
      implicit val kc = new SimpleKafkaCluster(streamingConfig.kafkaParams)

      val ssc = new StreamingContext(sparkConf, Seconds(ConfigInfo.durationConfig))
      val kafkaStream = SparkStreamUtil.createDirectStream(ssc)
      process(kafkaStream, ssc)
      ssc.checkpoint(ConfigInfo.checkpointDirectoryConfig)
      ssc
    }

    FileSystem.get(new Configuration()).deleteOnExit(new Path(ConfigInfo.checkpointDirectoryConfig))
    val ssc = StreamingContext.getOrCreate(ConfigInfo.checkpointDirectoryConfig, funcToCreateSSC)
    ssc.start()
    ssc.awaitTermination()
  }
}
package com.mmtrix.scala.utils

import _root_.kafka.message.MessageAndMetadata
import com.mmtrix.java.utils.RedisShardedPool
import kafka.common.TopicAndPartition
import kafka.serializer.DefaultDecoder
import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.spark.SparkException
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka.KafkaUtils

import collection.JavaConverters._

object SparkStreamUtil {

  implicit val streamingConfig = new StreamingConfig
  implicit val kc = new SimpleKafkaCluster(streamingConfig.kafkaParams)

  def createDirectStream(ssc: StreamingContext)(implicit streamingConfig: StreamingConfig, kc: SimpleKafkaCluster): InputDStream[ConsumerRecord[String, String]] = {
    //从 redis 上读取offset开始消费message
    val messages = {
      var hasConsumed = true
      val kafkaPartitionsE = kc.getPartitions(streamingConfig.topicSet)
      if (kafkaPartitionsE.isLeft) throw new SparkException("get kafka partition failed:")
      val kafkaPartitions = kafkaPartitionsE.right.get // 先从zk读取当前kafka最新的partition
      // 依据zk获取的topicAndPartition去redis读取数据,并且要判断是否是第一次开始消费
      val jedis = RedisShardedPool.getJedis
      val partitionCnt = kafkaPartitions.size // 记录从zk读取的一共有多少个分区,用于判断是否集群新增了partition
      val topic = streamingConfig.topic
      val group = streamingConfig.group
      val topicKey = group + topic
      val redisPartitionCnt = jedis.hgetAll(group + topic) // 读取redis中当前这个group在topic下的消费情况
      val previousNum = redisPartitionCnt.getOrDefault("partition_num", "0").toInt //之前数据库中分区信息
      if (previousNum == 0) {
        //没有被消费过,则从zk中最新的offset开始消费。
        val leaderLatestOffsets = kc.getLatestLeaderOffsets(kafkaPartitions).right.get

        // 初始化redis对应分区的offset数据
        leaderLatestOffsets.map(tp => {
          val offset = tp._2.offset
          val partition = tp._1.partition
          jedis.hset(topicKey, partition.toString, offset.toString)
        })

      } else if (previousNum < partitionCnt) { // 新增分区
        // 说明分区数改变了,需要新增分区信息到redis
        val leaderLatestOffsets = kc.getLatestLeaderOffsets(kafkaPartitions).right.get
        leaderLatestOffsets.map(tp => {
          val offset = tp._2.offset
          val partition = tp._1.partition
          val partitionInfo = jedis.hget(topicKey, partition.toString) //获取数据,判断当前分区之前是否已经存在
          if (partitionInfo.isEmpty) { // 分区数据为空,说明这个分区是新增的
            jedis.hset(topicKey, partition.toString, offset.toString)
          }
        })
      } else if (previousNum > partitionCnt) { //减少了分区
        // ...
      }

      //以上操作完毕后,redis中存储的一定就是当前需要消费的各个分区中的offset正确数据
      val infos = jedis.hgetAll(topicKey).asScala //所有分区offset数据

      val offsetRange = infos.map(info => {
        val partition = info._1
        val offset = info._2
        val tp = TopicAndPartition(topic, partition.toInt)
        (tp, offset.toLong)
      }).toMap

      KafkaUtils.createDirectStream[String, String, DefaultDecoder, DefaultDecoder, ConsumerRecord[String, String]](
        ssc, streamingConfig.kafkaParams, offsetRange, (mmd: MessageAndMetadata[String, String]) => {
          new ConsumerRecord[String,String](mmd.topic,mmd.partition,mmd.key(),mmd.offset)
        })
    }
    messages
  }
  def main(args: Array[String]): Unit = {
  }

}

  以上部分redis更新部分内容没有写,由于redis采用的是分片模式,所以就把所有状态都放到一个固定的key下了,然后通过hmset一次性进行设置,这样也可以避免用事务,算是可有可无的优化吧。另外一点需要注意的是关于SparkStreamUtil类的实现,流程控制其实理解了就很好实现,但是这个类还是有点小坑的,这个类需要兼容初次启动时Redis中没有相关kafka数据时数据的初始化,以及新增或者删除分区时的识别,此类初始化数据以及分区的感知都是依赖于kafka自身zk中的元数据信息,当然其实这里最好的自动实时感知分区变化的方式应该是自定义一个DirectKafkaInputDStream类型的InputStream,具体实现参考文章:https://blog.csdn.net/chen20111/article/details/80827226
  代码中还有一个优化的地方就是,由于业务中有多个指标的更新,每个指标更新完毕后,会维护一个对应指标的状态,这样假设有十个指标需要更新,然后更新到第五个指标应用挂了,那么再次重启时,前四个业务指标部分就可以不需要重复执行了(因为更新前判断上个批次这四个指标状态是finish),这样可以提高应用中途宕机重启时的速度。最最最后一点是,由于这里top100是把状态管理全部挪到了Redis中,所以其实是完全可以弃用Checkpoint的,因为即便宕机了,重启之后最后一个批次的执行状态其实都记录在Redis中了,所以有没有Checkpoint都无所谓了的。
  (代码写的比较急,可能会存在小部分漏洞,欢迎大家指正~)

你可能感兴趣的:(Spark,Streaming)