自定义累加器类步骤:MyAccumulator
1.需继承AccumulatorV2
2.定义成员累加值并初始化 res
3.重写isZero方法 对累加值进行状态检查,系统内部调用
4.重写copy方法 拷贝一个新的AccumulatorV2,系统内部调用
5.重写reset方法 重置AccumulatorV2中的累加值res,系统内部调用
6.重写add方法 实现对res进行累加的逻辑,手动调用
7.重写merge方法 合并每条数据累加结果,系统内部调用
8.重写value方法 取得累加值,手动调用
使用步骤:
1.实例化一个自定义累加器对象 val myAcc = new MyAccumulator
2.sparkcontext注册累加器对象 sc.register(myAcc,“累加器名称自定义”)
3.遍历rdd每一条数据,按需执行累加器对象的add方法
4.在driver端通过累加器对象的value方法取得累加结果
原理步骤:
1.每一个partition调用copy方法把res初始值拷贝到每一个partition,copy前会调用isZero检查
2.每一个partition调用reset方法把刚拷贝到每个partition的res进行重置,reset前会调用isZero检查
3.每一个partition调用isZero方法检查每个partition的res的值是否重置成功,检查失败抛出AssertionError异常
4.手动进行累加时,每一个partition调用add方法在各自partition中进行累加操作
5.各个partition累加完成后依次调用merge方法对所有partition的累加结果逐个合并,实现merge前会先调用isZero检查即将合并进来的累加结果,若状态检查失败,则不进行合并
6.合并结果为累加最终结果
注意:在streaming中,每个batch开始前都会调用reset方法
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable.ArrayBuffer
class TestAccumulator extends AccumulatorV2[String, ArrayBuffer[String]] {
private var res = ArrayBuffer[String]()
override def isZero: Boolean = {
println(s"check res is empty:$res")
res.isEmpty
}
override def reset = {
println(s"resetting $res")
res.clear()
}
override def add(v: String): Unit = {
println(s"add $v to $res")
res += v
}
override def merge(other: AccumulatorV2[String, ArrayBuffer[String]]): Unit = {
other match {
case any: TestAccumulator =>
println(s"merge ${this.res} with $any")
res ++= any.value.diff(this.res)
case _ => throw new UnsupportedOperationException(s"cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
}
override def value: ArrayBuffer[String] = {
println(s"getting res value:$res")
res
}
override def copy(): AccumulatorV2[String, ArrayBuffer[String]] = {
println(s"copying $res")
val newTestAcc = new TestAccumulator
newTestAcc.res = this.res
newTestAcc
}
}
object test_accumulator {
def main(args: Array[String]): Unit = {
Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
val conf = new SparkConf().setMaster("local[2]").setAppName("test accumulator")
val sc = new SparkContext(conf)
val rdd = sc.makeRDD(Seq("a", "b", "c", "d", "b", "a", "b", "e", "f", "z", "a"))
val testAcc = new TestAccumulator
sc.register(testAcc, "test acc")
rdd.foreachPartition(part => {
part.foreach(record => {
if (testAcc.value.contains(record)) {
println(s"already esist:$record")
} else {
testAcc.add(record)
}
})
})
println(testAcc.value)
sc.stop()
}
}
打印为:ArrayBuffer(a, b, c, d, e, f, z)