spark累加器v2

自定义累加器类步骤:MyAccumulator
1.需继承AccumulatorV2
2.定义成员累加值并初始化 res
3.重写isZero方法 对累加值进行状态检查,系统内部调用
4.重写copy方法 拷贝一个新的AccumulatorV2,系统内部调用
5.重写reset方法 重置AccumulatorV2中的累加值res,系统内部调用
6.重写add方法 实现对res进行累加的逻辑,手动调用
7.重写merge方法 合并每条数据累加结果,系统内部调用
8.重写value方法 取得累加值,手动调用

使用步骤:
1.实例化一个自定义累加器对象 val myAcc = new MyAccumulator
2.sparkcontext注册累加器对象 sc.register(myAcc,“累加器名称自定义”)
3.遍历rdd每一条数据,按需执行累加器对象的add方法
4.在driver端通过累加器对象的value方法取得累加结果

原理步骤:
1.每一个partition调用copy方法把res初始值拷贝到每一个partition,copy前会调用isZero检查
2.每一个partition调用reset方法把刚拷贝到每个partition的res进行重置,reset前会调用isZero检查
3.每一个partition调用isZero方法检查每个partition的res的值是否重置成功,检查失败抛出AssertionError异常
4.手动进行累加时,每一个partition调用add方法在各自partition中进行累加操作
5.各个partition累加完成后依次调用merge方法对所有partition的累加结果逐个合并,实现merge前会先调用isZero检查即将合并进来的累加结果,若状态检查失败,则不进行合并
6.合并结果为累加最终结果
注意:在streaming中,每个batch开始前都会调用reset方法

import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.util.AccumulatorV2
import scala.collection.mutable.ArrayBuffer

class TestAccumulator extends AccumulatorV2[String, ArrayBuffer[String]] {
  private var res = ArrayBuffer[String]()

  override def isZero: Boolean = {
    println(s"check res is empty:$res")
    res.isEmpty
  }

  override def reset = {
    println(s"resetting $res")
    res.clear()
  }

  override def add(v: String): Unit = {
    println(s"add $v to $res")
    res += v
  }

  override def merge(other: AccumulatorV2[String, ArrayBuffer[String]]): Unit = {
    other match {
      case any: TestAccumulator =>
        println(s"merge ${this.res} with $any")
        res ++= any.value.diff(this.res)
      case _ => throw new UnsupportedOperationException(s"cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }
  }

  override def value: ArrayBuffer[String] = {
    println(s"getting res value:$res")
    res
  }

  override def copy(): AccumulatorV2[String, ArrayBuffer[String]] = {
    println(s"copying $res")
    val newTestAcc = new TestAccumulator
    newTestAcc.res = this.res
    newTestAcc
  }
}

object test_accumulator {
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    val conf = new SparkConf().setMaster("local[2]").setAppName("test accumulator")
    val sc = new SparkContext(conf)
    val rdd = sc.makeRDD(Seq("a", "b", "c", "d", "b", "a", "b", "e", "f", "z", "a"))
    val testAcc = new TestAccumulator
    sc.register(testAcc, "test acc")
    rdd.foreachPartition(part => {
      part.foreach(record => {
        if (testAcc.value.contains(record)) {
          println(s"already esist:$record")
        } else {
          testAcc.add(record)
        }
      })
    })
    println(testAcc.value)
    sc.stop()
  }
}

打印为:ArrayBuffer(a, b, c, d, e, f, z)

你可能感兴趣的:(spark,spark累加器)