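First, a brief explanation of what state management is, using word count as the example. Each batchInterval computes the word counts for the current batch only. How, then, do we compute the number of times each word has appeared from the start of the stream up to now? Spark Streaming provides two methods for this: updateStateByKey and mapWithState. mapWithState was added in version 1.6 and is still marked experimental. According to the official claim, mapWithState performs as much as 10 times better than updateStateByKey. Let us look at how each of them is implemented.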
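As a rough illustration of what "state" means here, the following sketch keeps a running word count across batches using plain Scala collections instead of Spark. The names RunningWordCount, countBatch, merge and runningTotals are made up for this example and are not part of any Spark API.

```scala
object RunningWordCount {
  // Word counts inside a single batch: what each batchInterval computes on its own.
  def countBatch(words: Seq[String]): Map[String, Int] =
    words.groupBy(identity).map { case (w, ws) => (w, ws.size) }

  // Fold one batch's counts into the running total; the running total is the "state".
  def merge(state: Map[String, Int], batch: Map[String, Int]): Map[String, Int] =
    batch.foldLeft(state) { case (acc, (w, n)) => acc.updated(w, acc.getOrElse(w, 0) + n) }

  // The state after each batch, starting from an empty state.
  def runningTotals(batches: Seq[Seq[String]]): Seq[Map[String, Int]] =
    batches.scanLeft(Map.empty[String, Int])((s, b) => merge(s, countBatch(b))).tail
}
```

Calling RunningWordCount.runningTotals on a sequence of batches returns one map per batch, each holding the cumulative counts up to and including that batch. Both updateStateByKey and mapWithState maintain this kind of per-key total; they differ in how the total is carried from one batch to the next.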
1. Analysis of updateStateByKey
1.1 An updateStateByKey usage example
First, an example of using the updateStateByKey function:
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
object UpdateStateByKeyDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("UpdateStateByKeyDemo")
    val ssc = new StreamingContext(conf, Seconds(20))
    // updateStateByKey requires a checkpoint directory to be set.
    ssc.checkpoint("/checkpoint/")
    val socketLines = ssc.socketTextStream("localhost", 9999)
    socketLines.flatMap(_.split(",")).map(word => (word, 1))
      .updateStateByKey(
        (currValues: Seq[Int], preValue: Option[Int]) => {
          val currValue = currValues.sum // sum of the values in the current batch
          Some(currValue + preValue.getOrElse(0)) // current sum plus the historical total
        }).print()
    ssc.start()
    ssc.awaitTermination()
    ssc.stop()
  }
}
The code is simple, and the key steps are commented.
1.2 Source code analysis of the updateStateByKey method
map returns a MappedDStream. MappedDStream does not define updateStateByKey, and neither does its parent class DStream.
However, the companion object of DStream contains an implicit conversion function:
implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
    (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
  PairDStreamFunctions[K, V] = {
  new PairDStreamFunctions[K, V](stream)
}
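The same mechanism can be reproduced in a few lines of plain Scala. In this sketch, Stream and PairStreamFunctions are hypothetical stand-ins for DStream and PairDStreamFunctions, and sumByKey is an invented method used only to show that the call compiles; none of these names come from Spark.

```scala
import scala.language.implicitConversions

object ImplicitSketch {
  // Hypothetical stand-in for DStream: it has no pair-specific methods of its own.
  class Stream[T](val items: Seq[T])

  // Hypothetical stand-in for PairDStreamFunctions: methods that only make sense for pairs.
  class PairStreamFunctions[K, V](self: Stream[(K, V)]) {
    def sumByKey(implicit num: Numeric[V]): Map[K, V] =
      self.items.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2).sum) }
  }

  object Stream {
    // Defined in the companion object, so the compiler finds it without any import
    // whenever a Stream[(K, V)] is asked for a method it does not have.
    implicit def toPairStreamFunctions[K, V](s: Stream[(K, V)]): PairStreamFunctions[K, V] =
      new PairStreamFunctions[K, V](s)
  }
}
```

With these definitions, calling sumByKey on a Stream of pairs compiles even though Stream does not declare it: the compiler inserts toPairStreamFunctions, just as it inserts toPairDStreamFunctions when updateStateByKey is called on a DStream of key-value pairs.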
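The source of updateStateByKey in PairDStreamFunctions is as follows:
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S]
  ): DStream[(K, S)] = ssc.withScope {
  updateStateByKey(updateFunc, defaultPartitioner())
}
Here updateFunc is the argument we pass in. It is a function: Seq[V] holds all the values for the current key in this batch, Option[S] is the previous state of that key, and the return value is the new state.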
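To make the updateFunc contract concrete, here is a minimal sketch of a function with that shape, with V and S both fixed to Int. UpdateFuncSketch and sumCounts are illustrative names only; the function is called directly here rather than through Spark.

```scala
object UpdateFuncSketch {
  // Same shape as the updateFunc parameter: (Seq[V], Option[S]) => Option[S].
  // currValues: values seen for one key in the current batch.
  // prevState: that key's state after the previous batch, if it had one.
  val sumCounts: (Seq[Int], Option[Int]) => Option[Int] =
    (currValues, prevState) => Some(currValues.sum + prevState.getOrElse(0))
}
```

Three cases are worth noting: a key seen for the first time receives None as its previous state, a key with new data receives both arguments, and a key with no new data in this batch is still passed to the function, with an empty Seq.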
Eventually the following method is called:
def updateStateByKey[S: ClassTag](
    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
    partitioner: Partitioner,
    rememberPartitioner: Boolean
  ): DStream[(K, S)] = ssc.withScope {
  new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}
This creates a new StateDStream object. Its compute method first fetches the RDD computed for the previous batch (the cumulative word counts from the start of the program up to the previous batch), then fetches the RDD that the StateDStream's parent DStream computed for the current batch (the word counts of this batch only). These are prevStateRDD and parentRDD respectively. It then calls the computeUsingPreviousRDD method:
private[this] def computeUsingPreviousRDD(
    parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
  // Define the function for the mapPartition operation on cogrouped RDD;
  // first map the cogrouped tuple to tuples of required type,
  // and then apply the update function
  val updateFuncLocal = updateFunc
  val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
    val i = iterator.map { t =>
      val itr = t._2._2.iterator
      val headOption = if (itr.hasNext) Some(itr.next()) else None
      (t._1, t._2._1.toSeq, headOption)
    }
    updateFuncLocal(i)
  }
  val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
  val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
  Some(stateRDD)
}
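The two RDDs are cogrouped, and the function passed to updateStateByKey is then applied to the result. cogroup is relatively expensive: it covers every key in the previous state RDD, so the cost of each batch grows with the total number of keys held in state, not with the number of keys that actually appear in the batch.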
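The cogroup-then-update step can be modelled with ordinary collections to show which keys the update function sees. This is a simplified sketch, not Spark code: CogroupSketch, cogroup and step are illustrative names, a Map stands in for prevStateRDD, and a Seq of pairs stands in for parentRDD.

```scala
object CogroupSketch {
  // Stand-in for parentRDD.cogroup(prevStateRDD): every key from either side appears
  // once, paired with its new values and its previous state (if any).
  def cogroup[K, V, S](batch: Seq[(K, V)], prevState: Map[K, S]): Map[K, (Seq[V], Option[S])] = {
    val grouped = batch.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2)) }
    (grouped.keySet ++ prevState.keySet).map { k =>
      (k, (grouped.getOrElse(k, Seq.empty[V]), prevState.get(k)))
    }.toMap
  }

  // One batch step: apply updateFunc to every cogrouped key; keys mapped to None are dropped.
  def step[K, V, S](batch: Seq[(K, V)], prevState: Map[K, S])(
      updateFunc: (Seq[V], Option[S]) => Option[S]): Map[K, S] =
    cogroup(batch, prevState).toSeq.flatMap { case (k, (vs, s)) =>
      updateFunc(vs, s).map(ns => (k, ns))
    }.toMap
}
```

If the previous state holds three keys and the batch contains only one of them, step still invokes updateFunc three times. That per-key visit on every batch is the overhead that the mapWithState design avoids.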
2. Analysis of the mapWithState method
2.1 A mapWithState usage example:
import org.apache.spark.SparkConf
import org.apache.spark.streaming._
object StatefulNetworkWordCount {
  def main(args: Array[String]): Unit = {
    if (args.length < 2) {
      System.err.println("Usage: StatefulNetworkWordCount <hostname> <port>")
      System.exit(1)
    }
    StreamingExamples.setStreamingLogLevels()
    val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
    // Create the context with a 1 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(1))
    ssc.checkpoint(".")
    // Initial state RDD for mapWithState operation
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
    // Create a ReceiverInputDStream on target ip:port and count the
    // words in input stream of \n delimited text (e.g. generated by 'nc')
    val lines = ssc.socketTextStream(args(0), args(1).toInt)
    val words = lines.flatMap(_.split(" "))
    val wordDstream = words.map(x => (x, 1))
    // Update the cumulative count using mapWithState
    // This will give a DStream made of state (which is the cumulative count of the words)
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      val output = (word, sum)
      state.update(sum)
      output
    }
    val stateDstream = wordDstream.mapWithState(
      StateSpec.function(mappingFunc).initialState(initialRDD))
    stateDstream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}
The source of mapWithState in PairDStreamFunctions is as follows:
def mapWithState[StateType: ClassTag, MappedType: ClassTag](
    spec: StateSpec[K, V, StateType, MappedType]
  ): MapWithStateDStream[K, V, StateType, MappedType] = {
  new MapWithStateDStreamImpl[K, V, StateType, MappedType](
    self,
    spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
  )
}
MapWithStateDStreamImpl creates an object of type InternalMapWithStateDStream named internalStream, and the compute method of MapWithStateDStreamImpl calls internalStream's getOrCompute method.
/** Internal implementation of the [[MapWithStateDStream]] */
private[streaming] class MapWithStateDStreamImpl[
    KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
    dataStream: DStream[(KeyType, ValueType)],
    spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
  extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
  private val internalStream =
    new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
  override def slideDuration: Duration = internalStream.slideDuration
  override def dependencies: List[DStream[_]] = List(internalStream)
  override def compute(validTime: Time): Option[RDD[MappedType]] = {
    internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
  }
  // ... (remaining members omitted)
}
InternalMapWithStateDStream does not define getOrCompute itself; the call resolves to the getOrCompute method of its parent class DStream, which eventually calls the compute method of InternalMapWithStateDStream:
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
  // Get the previous state or create a new empty state RDD
  val prevStateRDD = getOrCompute(validTime - slideDuration) match {
    case Some(rdd) =>
      if (rdd.partitioner != Some(partitioner)) {
        // If the RDD is not partitioned the right way, let us repartition it using the
        // partition index as the key. This is to ensure that state RDD is always partitioned
        // before creating another state RDD using it
        MapWithStateRDD.createFromRDD[K, V, S, E](
          rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
      } else {
        rdd
      }
    case None =>
      MapWithStateRDD.createFromPairRDD[K, V, S, E](
        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
        partitioner,
        validTime
      )
  }
  // Compute the new state RDD with previous state RDD and partitioned data RDD
  // Even if there is no data RDD, use an empty one to create a new state RDD
  val dataRDD = parent.getOrCompute(validTime).getOrElse {
    context.sparkContext.emptyRDD[(K, V)]
  }
  val partitionedDataRDD = dataRDD.partitionBy(partitioner)
  val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
    (validTime - interval).milliseconds
  }
  Some(new MapWithStateRDD(
    prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
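This generates a MapWithStateRDD for the given time. It first obtains the previous state RDD, prevStateRDD, and the RDD for the current time, dataRDD. It then repartitions dataRDD with the same partitioner used for the state RDD, which yields partitionedDataRDD. Finally, prevStateRDD, partitionedDataRDD, and the user-defined mappingFunction are passed to a newly created MapWithStateRDD, which is returned.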
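The timeoutThresholdTime computed above is plain arithmetic on millisecond timestamps. The sketch below restates it with Long values; TimeoutThresholdSketch, thresholdMs and timedOutKeys are illustrative names, and the rule that a key times out when its last update is strictly older than the threshold is an assumption made for this example, not something shown in the code above.

```scala
object TimeoutThresholdSketch {
  // Mirrors (validTime - interval).milliseconds using plain Long millisecond values.
  // No timeout interval configured means no threshold.
  def thresholdMs(validTimeMs: Long, timeoutIntervalMs: Option[Long]): Option[Long] =
    timeoutIntervalMs.map(interval => validTimeMs - interval)

  // Assumption: a key is timed out when its last update is strictly older than the threshold.
  def timedOutKeys[K](lastUpdatedMs: Map[K, Long], threshold: Option[Long]): Set[K] =
    threshold.fold(Set.empty[K]) { t =>
      lastUpdatedMs.collect { case (k, ts) if ts < t => k }.toSet
    }
}
```

For a batch at 60000 ms with a 30000 ms timeout interval, the threshold is 30000 ms, so under the stated assumption only keys last updated before 30000 ms would be candidates for timeout.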
Next, the compute method of MapWithStateRDD:
override def compute(
    partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
  val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
  val prevStateRDDIterator = prevStateRDD.iterator(
    stateRDDPartition.previousSessionRDDPartition, context)
  val dataIterator = partitionedDataRDD.iterator(
    stateRDDPartition.partitionedDataRDDPartition, context)
  // prevRecord holds the data of one partition
  val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
  val newRecord = MapWithStateRDDRecord.updateRecordWithData(
    prevRecord,
    dataIterator,
    mappingFunction,
    batchTime,
    timeoutThresholdTime,
    removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
  )
  Iterator(newRecord)
}
A MapWithStateRDDRecord corresponds to one partition of a MapWithStateRDD:
private[streaming] case class MapWithStateRDDRecord[K, S, E](
    var stateMap: StateMap[K, S], var mappedData: Seq[E])
Here stateMap stores the state of each key, and mappedData stores the values returned by the mapping function.
Now the updateRecordWithData method of MapWithStateRDDRecord:
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
    prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
    dataIterator: Iterator[(K, V)],
    mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long],
    removeTimedoutData: Boolean
  ): MapWithStateRDDRecord[K, S, E] = {
  // Create a new state map by copying the previous record's state map (if it exists);
  // otherwise create an empty StateMap
  val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
  val mappedData = new ArrayBuffer[E]
  // Reusable wrapper around the state of the key being processed
  val wrappedState = new StateImpl[S]()
  // Call the mapping function on each record in the data iterator, and accordingly
  // update the states touched, and collect the data returned by the mapping function
  dataIterator.foreach { case (key, value) =>
    // Look up the state for this key
    wrappedState.wrap(newStateMap.get(key))
    // Call mappingFunction to get the returned value
    val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
    // Keep newStateMap in sync with what the mapping function did to the state
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)
    } else if (wrappedState.isUpdated
        || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    mappedData ++= returned
  }
  // Get the timed out state records, call the mapping function on each and collect the
  // data returned
  if (removeTimedoutData && timeoutThresholdTime.isDefined) {
    newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
      wrappedState.wrapTimingOutState(state)
      val returned = mappingFunction(batchTime, key, None, wrappedState)
      mappedData ++= returned
      newStateMap.remove(key)
    }
  }
  MapWithStateRDDRecord(newStateMap, mappedData)
}
Finally, the resulting MapWithStateRDDRecord is handed back to the compute function of MapWithStateRDD, which wraps it in an Iterator and returns it.
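The data-processing loop of updateRecordWithData can be modelled without Spark to show that only the keys present in the current batch are passed to the mapping function. This is a simplified sketch under several assumptions: SimpleState is a hypothetical, cut-down stand-in for Spark's State[S], a plain Map replaces StateMap, the batch time is dropped from the mapping function, and timeouts are left out entirely.

```scala
import scala.collection.mutable

object MapWithStateSketch {
  // Hypothetical, simplified stand-in for Spark's State[S]; only the calls used below.
  final class SimpleState[S](initial: Option[S]) {
    private var value: Option[S] = initial
    private var updated = false
    private var removed = false
    def getOption: Option[S] = value
    def update(s: S): Unit = { value = Some(s); updated = true; removed = false }
    def remove(): Unit = { value = None; removed = true; updated = false }
    def isUpdated: Boolean = updated
    def isRemoved: Boolean = removed
  }

  // One partition's work: copy the previous state, then touch only keys present in `data`.
  def updateRecord[K, V, S, E](prevState: Map[K, S], data: Seq[(K, V)])(
      mappingFunc: (K, Option[V], SimpleState[S]) => Option[E]): (Map[K, S], Seq[E]) = {
    val newState = mutable.Map[K, S]() ++= prevState
    val mapped = mutable.ArrayBuffer.empty[E]
    data.foreach { case (key, value) =>
      val state = new SimpleState[S](newState.get(key))
      val returned = mappingFunc(key, Some(value), state)
      // Apply whatever the mapping function did to the state wrapper.
      if (state.isRemoved) {
        newState.remove(key)
      } else if (state.isUpdated) {
        state.getOption.foreach(s => newState.put(key, s))
      }
      mapped ++= returned
    }
    (newState.toMap, mapped.toSeq)
  }
}
```

With four keys in state and a batch that mentions two of them, the mapping function runs twice and the other keys are carried over untouched. In the updateStateByKey path, by contrast, the update function is applied to every key of the cogrouped RDD on every batch.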