Spark mapWithState Implementation

mapWithState() preserves state across a stream's batches, making it possible to compare or aggregate the current batch's RDD against data from earlier intervals.
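
For orientation, here is a minimal usage sketch before diving into the implementation. The stream name wordDstream and the 30-minute timeout are illustrative, not taken from the source below; the pattern follows the standard mapWithState running word count.

import org.apache.spark.streaming.{Minutes, State, StateSpec}

// wordDstream is assumed to be a DStream[(String, Int)] of (word, 1) pairs
val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
  val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
  state.update(sum)    // persist the running total for this key
  (word, sum)          // the mapped record emitted for this batch
}

val stateDstream = wordDstream.mapWithState(
  StateSpec.function(mappingFunc).timeout(Minutes(30)))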

When mapWithState() is called on a key-value DStream, it returns a MapWithStateDStreamImpl:

// In PairDStreamFunctions:
@Experimental
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self,
    spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}

// In InternalMapWithStateDStream: checkpointing of the state stream is mandatory
override val mustCheckpoint = true

/** Override the default checkpoint duration */
override def initialize(time: Time): Unit = {
  if (checkpointDuration == null) {
    checkpointDuration = slideDuration * DEFAULT_CHECKPOINT_DURATION_MULTIPLIER
  }
  super.initialize(time)
}

// In MapWithStateDStreamImpl: the real work is delegated to an internal state stream
private val internalStream =
  new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
override def compute(validTime: Time): Option[RDD[MappedType]] = {
  internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}

The state stream forces checkpointing (mustCheckpoint = true) and fixes the checkpoint interval: DEFAULT_CHECKPOINT_DURATION_MULTIPLIER is 10 in Spark's source, so by default the state is checkpointed every 10 batches.
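
In practice, this is why a checkpoint directory must be set on the StreamingContext before a mapWithState job starts; a one-line sketch (the path is hypothetical):

ssc.checkpoint("hdfs:///tmp/mapWithState-checkpoints")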

When compute() is called on the MapWithStateDStreamImpl, the real work is done by the compute() method of its internal stream, InternalMapWithStateDStream:

override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }

  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}

First, compute() looks up the state RDD of the previous batch. If none exists, it builds a fresh MapWithStateRDD from the user-supplied initial state RDD (or an empty RDD if none was given). If a previous state RDD does exist, its partitioner is compared with the current one: when they match, the RDD is used as-is; otherwise its state entries are extracted and repartitioned into a new MapWithStateRDD, ensuring the state RDD is always correctly partitioned before the next state RDD is derived from it.

It then fetches the current batch's data RDD (substituting an empty RDD if no data arrived), partitions it with the same partitioner, and constructs a new MapWithStateRDD from the previous state RDD, the partitioned data RDD, and the optional timeout threshold.
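
The partitioner, initial state RDD, and timeout interval consumed above all come from the user's StateSpec: getPartitioner() reflects numPartitions()/partitioner(), getInitialStateRDD() reflects initialState(), and getTimeoutInterval() reflects timeout(). A sketch of these knobs (mappingFunc, initialRDD, and the concrete values are illustrative):

val spec = StateSpec.function(mappingFunc)
  .numPartitions(64)          // becomes the HashPartitioner used for the state RDD
  .initialState(initialRDD)   // seeds prevStateRDD on the first batch
  .timeout(Minutes(30))       // drives timeoutThresholdTime above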


Finally, look at the implementation of MapWithStateRDD's compute() method:

override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {

  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)

  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timed-out data only when full scan is enabled
  )
  Iterator(newRecord)
}

For the given partition, compute() obtains iterators over the corresponding partitions of the previous state RDD and the current data RDD, then calls MapWithStateRDDRecord.updateRecordWithData() to produce a new record by applying the user-defined function:
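
For reference, MapWithStateRDDRecord has roughly the following shape (simplified from Spark's private[streaming] definition): each state-RDD partition carries exactly one record, pairing that partition's key-to-state map with the output the mapping function produced for the batch.

case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S],    // all keys' states in this partition
    var mappedData: Seq[E])          // records returned by mappingFunction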

def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
  prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
  dataIterator: Iterator[(K, V)],
  mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
  batchTime: Time,
  timeoutThresholdTime: Option[Long],
  removeTimedoutData: Boolean
): MapWithStateRDDRecord[K, S, E] = {
  // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
  val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }

  val mappedData = new ArrayBuffer[E]
  val wrappedState = new StateImpl[S]()

  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  dataIterator.foreach { case (key, value) =>
    wrappedState.wrap(newStateMap.get(key))
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated
        || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }

  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }

  MapWithStateRDDRecord(newStateMap, mappedData)
}

updateRecordWithData() first clones the previous record's StateMap (or creates an empty one if there is no previous record). It then iterates over the current batch's key-value pairs: for each key, the existing state, if any, is wrapped and passed together with the new value into the user-defined mappingFunction.

Depending on what the mapping function did with the State handle, the key's state is updated in or removed from the state map, while the function's return values are collected into mappedData; timed-out keys then get a final callback and are removed. The new record is returned, which is how the current batch's RDD is compared and merged against historical state.
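
A sketch of how a mapping function typically drives those update and remove decisions through the State handle (the function below is illustrative, not from Spark's source). When a key times out, the function is invoked one last time with value = None and isTimingOut() == true, at which point update() and remove() are no longer allowed:

def trackTotal(time: Time, key: String, value: Option[Long],
               state: State[Long]): Option[(String, Long)] = {
  if (state.isTimingOut()) {
    Some((key, state.get()))    // final chance to emit before Spark drops the state
  } else {
    val total = state.getOption.getOrElse(0L) + value.getOrElse(0L)
    state.update(total)         // marks this key's state as updated
    Some((key, total))
  }
}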
