Spark: Custom AccumulatorV2

This article shows how to implement a custom accumulator in Spark 2 to collect statistics.

Since Spark 2.x, the old Accumulator API is deprecated and has been replaced by AccumulatorV2.
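For reference, org.apache.spark.util.AccumulatorV2[IN, OUT] takes two type parameters: IN is the element type added on the executors, OUT is the result type read back on the driver. A custom accumulator must override six methods; the skeleton below is only a sketch of that contract:

import org.apache.spark.util.AccumulatorV2

class SkeletonAccumulator[IN, OUT] extends AccumulatorV2[IN, OUT] {
  override def isZero: Boolean = ???                            // is this the zero/initial value?
  override def copy(): AccumulatorV2[IN, OUT] = ???             // independent copy sent to each task
  override def reset(): Unit = ???                              // restore the zero value
  override def add(v: IN): Unit = ???                           // fold in one element (executor side)
  override def merge(other: AccumulatorV2[IN, OUT]): Unit = ??? // combine partial results (driver side)
  override def value: OUT = ???                                 // current accumulated result
}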

1. Defining the custom accumulator

import org.apache.spark.util.AccumulatorV2

class StrAccumulator extends AccumulatorV2[String, String] {

  // Accumulated state: a concatenated count string, e.g. "a=10|b=20"
  private var v = ""

  override def isZero: Boolean = v == ""

  override def copy(): AccumulatorV2[String, String] = {
    val newAcc = new StrAccumulator
    newAcc.v = this.v
    newAcc
  }

  override def reset(): Unit = v = ""

  override def add(v: String): Unit = {
    // Ignore null/empty keys; add() returns Unit, so there is no value to return
    if (v == null || v == "") {
      return
    }

    // If the key already exists, increment its count; otherwise append "key=1"
    val oldValue = getFieldFromConcatString(this.v, "\\|", v)
    if (oldValue != null) {
      val newValue = (oldValue.toInt + 1).toString
      this.v = setFieldInConcatString(this.v, "\\|", v, newValue)
    } else {
      if (isZero) {
        this.v = v + "=" + 1
      } else {
        this.v = this.v + "|" + v + "=" + 1
      }
    }
  }

  override def merge(other: AccumulatorV2[String, String]): Unit = other match {
    case o: StrAccumulator =>
      // Sum the per-key counts of the other partial result into this one.
      // Plain string concatenation (v += o.v) would corrupt the "k=v|k=v" format.
      for (field <- o.v.split("\\|")) {
        val parts = field.split("=")
        if (parts.length == 2) {
          val name = parts(0)
          val count = parts(1).toInt
          val oldValue = getFieldFromConcatString(this.v, "\\|", name)
          if (oldValue != null) {
            this.v = setFieldInConcatString(this.v, "\\|", name, (oldValue.toInt + count).toString)
          } else if (isZero) {
            this.v = field
          } else {
            this.v = this.v + "|" + field
          }
        }
      }
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: String = v

  // Look up the value of `field` in a "k1=v1|k2=v2" string; returns null if absent
  def getFieldFromConcatString(str: String, delimiter: String, field: String): String = {
    val fields = str.split(delimiter)
    for (concatField <- fields) {
      if (concatField.split("=").length == 2) {
        val fieldName = concatField.split("=")(0)
        val fieldValue = concatField.split("=")(1)
        if (fieldName == field)
          return fieldValue
      }
    }
    null
  }

  // Replace the value of `field` in a "k1=v1|k2=v2" string with newValue
  def setFieldInConcatString(str: String, delimiter: String, field: String, newValue: String): String = {
    val fields = str.split(delimiter)

    var break = false
    for (i <- 0 until fields.length if !break) {
      if (fields(i).split("=")(0) == field) {
        val concatField = field + "=" + newValue
        fields(i) = concatField
        break = true
      }
    }

    fields.mkString("|")
  }
}
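Before wiring the accumulator into a job, a quick local check (no SparkContext required) illustrates the intended add/merge semantics. The test object below is an illustrative sketch of our own, not part of the Spark API:

object StrAccumulatorCheck {
  def main(args: Array[String]): Unit = {
    val a = new StrAccumulator
    a.add("session-1")
    a.add("session-1")
    a.add("session-2")
    println(a.value) // session-1=2|session-2=1

    // merge() combines partial results, summing the counts per key
    val b = new StrAccumulator
    b.add("session-1")
    a.merge(b)
    println(a.value) // session-1=3|session-2=1
  }
}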

2. Usage

Requirement: while counting the total number of sessions, also compute each session's step length (the number of pages it visited).

Test data (tab-separated sessionId and pageId):

session-1   1
session-1   2
session-1   3
session-2   1
session-2   2

Test code:

import org.apache.spark.{SparkConf, SparkContext}

object AccumulatorTest {
  def main(args: Array[String]): Unit = {
    // Create the Spark configuration
    val conf = new SparkConf()
      .setAppName("AccumulatorTest")
      .setMaster("local")

    // Create the SparkContext, the entry point for the job
    val sc = new SparkContext(conf)

    // Create and register the accumulator
    val strAccumulator = new StrAccumulator
    sc.register(strAccumulator)

    // Count records per session
    sc.textFile("D:\\workspaces\\idea\\hadoop\\spark\\data\\session.txt")
      .map(line => {
        val lines = line.split("\t")
        val sessionId = lines(0)
        val pageId = lines(1) // parsed but unused in this example

        // Side effect: increment this session's step count in the accumulator
        strAccumulator.add(sessionId)

        (sessionId, 1L)
      })
      .reduceByKey(_ + _)
      .sortBy(_._2, false)
      .foreach(println)

    // foreach above is an action, so the job has run and the value is populated
    println(strAccumulator.value)

    // Stop the SparkContext
    sc.stop()
  }
}
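One caveat: Spark guarantees exactly-once application of accumulator updates only for updates performed inside actions. Here add() runs inside a map transformation, so a retried task or re-executed stage could apply its updates more than once; treat values collected this way as best-effort, or move the update into an action such as foreach.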

Output:

(session-1,3)
(session-2,2)

session-1=3|session-2=2

In this way, a single pass both counts the sessions and computes each session's step length. The same pattern can accumulate other per-session attributes as well, such as session duration, session-duration range distributions, or step-length range distributions, as sketched below.
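For example, step-length range statistics reuse StrAccumulator unchanged: feed it a bucket key instead of the raw sessionId. A minimal sketch, with a hypothetical helper stepRangeKey and bucket names chosen here purely for illustration:

// Hypothetical helper: map a session's step length to a range bucket
def stepRangeKey(steps: Int): String =
  if (steps <= 3) "step_1_3"
  else if (steps <= 6) "step_4_6"
  else "step_gt_6"

// After computing each session's step length, accumulate the bucket key:
//   rangeAccumulator.add(stepRangeKey(steps))
// For the test data above this would yield "step_1_3=2".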
