【源码追踪】SparkStreaming 中用 Direct 方式每次从 Kafka 拉取多少条数据(offset取值范围)

我们知道 SparkStreaming 用 Direct 的方式拉取 Kafka 数据时,是根据 kafka 中的 fromOffsets 和 untilOffsets 来进行获取数据的,而 fromOffsets 一般都是需要我们自己管理的,而每批次的 untilOffsets 是由 Driver 程序自动帮我们算出来的。
于是产生了一个疑问:untilOffsets 是怎么算出来的?
接下来就通过查看源码的方式来找出答案~

首先我们写一个最简单的 wordcount 程序,代码如下:

/**
  * Created by Lin_wj1995 on 2018/4/19.
  * 来源:https://blog.csdn.net/Lin_wj1995
  */
object DirectKafkaWordCount {
  def main(args: Array[String]) {
    val Array(brokers, topics) = args
    val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
    val ssc = new StreamingContext(sparkConf, Seconds(2))

    val topicsSet = topics.split(",").toSet
    val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
    val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](ssc, kafkaParams, topicsSet)

    //拿到数据
    val lines = messages.map(_._2)
    val words = lines.flatMap(_.split(" "))
    val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
    wordCounts.print()

    // 启动
    ssc.start()
    ssc.awaitTermination()
  }
}

我们可以看出, createDirectStream 是获得数据的关键方法的,我们点击进去

  def createDirectStream[
    K: ClassTag,
    V: ClassTag,
    KD <: Decoder[K]: ClassTag,
    VD <: Decoder[V]: ClassTag] (
      ssc: StreamingContext,
      kafkaParams: Map[String, String],
      topics: Set[String]
  ): InputDStream[(K, V)] = {
    val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
    //kafka cluster 连接对象
    val kc = new KafkaCluster(kafkaParams)
    //读取数据的开始位置
    val fromOffsets = getFromOffsets(kc, kafkaParams, topics)
    //该方法返回了一个DirectKafkaInputDStream的对象
    new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
      ssc, kafkaParams, fromOffsets, messageHandler)
  }

ok,重点来了,点击 DirectKafkaInputDStream ,看一下该类内部是如何的,由于该类内部的方法都是重点,所有我把该类重点的属性和方法有选择性的贴出来:
建议从下往上读!~

private[streaming]
class DirectKafkaInputDStream[
  K: ClassTag,
  V: ClassTag,
  U <: Decoder[K]: ClassTag,
  T <: Decoder[V]: ClassTag,
  R: ClassTag](
    ssc_ : StreamingContext,
    val kafkaParams: Map[String, String],
    val fromOffsets: Map[TopicAndPartition, Long],
    messageHandler: MessageAndMetadata[K, V] => R
  ) extends InputDStream[R](ssc_) with Logging {
  /**
    * 为了拿到每个分区leader上的最新偏移量(默认值为1),Driver发出请求的最大的连续重试次数
    * 默认值为1,也就是说最多请求 2 次
    */
  val maxRetries = context.sparkContext.getConf.getInt(
    "spark.streaming.kafka.maxRetries", 1)

  /**
    * 通过 receiver tracker 异步地维持和发送新的 rate limits 给 receiver
    * 注意:如果参数 spark.streaming.backpressure.enabled 没有设置,那么返回为None
   */
  override protected[streaming] val rateController: Option[RateController] = {
    /**
      * isBackPressureEnabled方法对应着“spark.streaming.backpressure.enabled”参数
      * 参数说明:简单来讲就是自动推测程序的执行情况并控制接收数据的条数,为了防止处理数据的时间大于批次时间而导致的数据堆积
      *           默认是没有开启的
      */
    if (RateController.isBackPressureEnabled(ssc.conf)) {
      Some(new DirectKafkaRateController(id,
        RateEstimator.create(ssc.conf, context.graph.batchDuration)))
    } else {
      None
    }
  }

  //拿到与Kafka集群的连接
  protected val kc = new KafkaCluster(kafkaParams)

  //每个partition每次最多获取多少条数据,默认是0
  private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
      "spark.streaming.kafka.maxRatePerPartition", 0)

  /**
    * 真实算出每个partition获取数据的最大条数
    */
  protected def maxMessagesPerPartition: Option[Long] = {
    val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) //每批都根据rateContoller预估获取多少条数据
    val numPartitions = currentOffsets.keys.size

    val effectiveRateLimitPerPartition = estimatedRateLimit
      .filter(_ > 0)
      .map { limit =>
        if (maxRateLimitPerPartition > 0) {
          /*
          如果 spark.streaming.kafka.maxRatePerPartition 该参数有设置值且大于0
          那么就取 maxRateLimitPerPartition 和 rateController 算出来的值 之间的最小值(为什么取最小值,因为这样是最保险的)
           */
          Math.min(maxRateLimitPerPartition, (limit / numPartitions))
        } else {
          /*
          如果 spark.streaming.kafka.maxRatePerPartition 该参数没有设置
          那么就直接用 rateController 算出来的值
           */
          limit / numPartitions
        }
      }.getOrElse(maxRateLimitPerPartition) //如果没有设置自动推测的话,则返回参数设定的接收速率

    if (effectiveRateLimitPerPartition > 0) {
      val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
      Some((secsPerBatch * effectiveRateLimitPerPartition).toLong)
    } else {
      /*
      如果没有设置 spark.streaming.kafka.maxRatePerPartition 参数,则返回None
       */
      None
    }
  }

  //拿到每批的起始 offset
  protected var currentOffsets = fromOffsets

  /**
    * 获取此时此刻topic中每个partition 最大的(最新的)offset
    */
  @tailrec
  protected final def latestLeaderOffsets(retries: Int): Map[TopicAndPartition, LeaderOffset] = {
    val o = kc.getLatestLeaderOffsets(currentOffsets.keySet)
    // Either.fold would confuse @tailrec, do it manually
    if (o.isLeft) {
      val err = o.left.get.toString
      if (retries <= 0) {
        throw new SparkException(err)
      } else {
        log.error(err)
        Thread.sleep(kc.config.refreshLeaderBackoffMs)
        latestLeaderOffsets(retries - 1)//如果获取失败,则重试,且重试次数 -1
      }
    } else {
      o.right.get //如果没有问题,则拿到最新的 offset
    }
  }

  // limits the maximum number of messages per partition
  /**
    * ★★★★★重要方法,答案就在这里
    * @param leaderOffsets 该参数的offset是当前最新的offset
    * @return 包含untilOffsets的信息
    */
  protected def clamp(
    leaderOffsets: Map[TopicAndPartition, LeaderOffset]): Map[TopicAndPartition, LeaderOffset] = {
    maxMessagesPerPartition.map { mmp =>
      leaderOffsets.map { case (tp, lo) =>
        /**
          * 如果有设定自动推测,那么就将值设定为: min(自动推测出来的offset,此时此刻最新的offset)
          */
        tp -> lo.copy(offset = Math.min(currentOffsets(tp) + mmp, lo.offset))
      }
    }.getOrElse(leaderOffsets) //如果没有设定自动推测,那么untilOffsets的值就是最新的offset
  }

  override def compute(validTime: Time): Option[KafkaRDD[K, V, U, T, R]] = {
    //====》★★★★★从这里作为入口尽心查看
    val untilOffsets = clamp(latestLeaderOffsets(maxRetries))
    //根据offset去拉取数据,完!
    val rdd = KafkaRDD[K, V, U, T, R](
      context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)

  。。。
  答案找到了,下面的就不写了
  。。。。

你可能感兴趣的:(Spark)