Using a custom accumulator (AccumulatorV2) in Spark 2.1.0


Extend AccumulatorV2:
class MyAccumulatorV2 extends AccumulatorV2[String, String]
Then override its abstract methods:

/**
 * @author lcjasas
 * @version 1.0
 * @since 2017-01-14 10:19 AM.
 */
class MyAccumulatorV2 extends AccumulatorV2[String, String] {

  override def isZero: Boolean = ???

  override def copy(): AccumulatorV2[String, String] = ???

  override def reset(): Unit = ???

  override def add(v: String): Unit = ???

  override def merge(other: AccumulatorV2[String, String]): Unit = ???

  override def value: String = ???

}

isZero: returns whether the accumulator still holds its zero (initial) value; Spark asserts that the reset copy it sends to the executors is zero.
copy: create a new copy of this accumulator.
reset: reset the accumulator's data back to its zero value.
add: the accumulation logic, applied to every value added on the executors.
merge: merge another accumulator of the same type into this one; called on the driver as task results come back.
value: the accumulated result exposed to the caller.
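
For reference, here is the same contract in its simplest form: a plain Long counter. This is a minimal sketch added for illustration only (the class name LongCounter is not part of the original post's code):

import org.apache.spark.util.AccumulatorV2

class LongCounter extends AccumulatorV2[Long, Long] {

  private var sum = 0L

  // zero means "still at the initial value"
  override def isZero: Boolean = sum == 0L

  // an independent copy carrying the current sum
  override def copy(): AccumulatorV2[Long, Long] = {
    val c = new LongCounter()
    c.sum = this.sum
    c
  }

  // back to the zero value
  override def reset(): Unit = sum = 0L

  // called on the executors for every value that is added
  override def add(v: Long): Unit = sum += v

  // called on the driver to fold one task's partial sum into the total
  override def merge(other: AccumulatorV2[Long, Long]): Unit = sum += other.value

  // the result read on the driver
  override def value: Long = sum
}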

Below, a custom accumulator that keeps per-key counts inside a single concatenated string:

import cn.lcj.project1.utils.StringUtils
import org.apache.spark.util.AccumulatorV2
import org.slf4j.LoggerFactory

/**
 * @author lcjasas
 * @version 1.0
 * @since 2017-01-14 10:19 AM.
 */
class MyAccumulatorV2 extends AccumulatorV2[String, String] {

  private val log = LoggerFactory.getLogger("MyAccumulatorV2")

  var result = "user0=0|user1=0|user2=0|user3=0" // initial value

  override def isZero: Boolean = {
    // true only while result still holds the initial value
    result == "user0=0|user1=0|user2=0|user3=0"
  }

  override def copy(): AccumulatorV2[String, String] = {
    val myAccumulator = new MyAccumulatorV2()
    myAccumulator.result = this.result
    myAccumulator
  }

  override def reset(): Unit = {
    result = "user0=0|user1=0|user2=0|user3=0"
  }

  override def add(v: String): Unit = {
    val v1 = result
    val v2 = v
    // log.warn("v1 : " + v1 + " v2 : " + v2)
    if (StringUtils.isNotEmpty(v1) && StringUtils.isNotEmpty(v2)) {
      // look up the current count of key v2 inside v1 and increment it
      val oldValue = StringUtils.getFieldFromConcatString(v1, "\\|", v2)
      if (oldValue != null) {
        val newValue = oldValue.toInt + 1
        result = StringUtils.setFieldInConcatString(v1, "\\|", v2, String.valueOf(newValue))
      }
    }
  }

  override def merge(other: AccumulatorV2[String, String]): Unit = other match {
    case o: MyAccumulatorV2 =>
      // combine rather than overwrite: add the matching count from o.value to each field of result
      result = result.split("\\|").map { item =>
        val Array(field, count) = item.split("=")
        val otherCount = StringUtils.getFieldFromConcatString(o.value, "\\|", field)
        field + "=" + (count.toInt + otherCount.toInt)
      }.mkString("|")
    case _ =>
      throw new UnsupportedOperationException(
        s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: String = {
    result
  }

}

The helper methods in StringUtils (cn.lcj.project1.utils; isNotEmpty, also used above, is omitted here):

object StringUtils {

  /**
   * Extract a field's value from a concatenated string.
   *
   * @param str       the concatenated string
   * @param delimiter the delimiter
   * @param field     the field name
   * @return the field value, or "0" if the field is not present
   */
  def getFieldFromConcatString(str: String, delimiter: String, field: String): String = {
    val fields = str.split(delimiter)
    var result = "0"
    for (concatField <- fields) {
      if (concatField.split("=").length == 2) {
        val fieldName = concatField.split("=")(0)
        val fieldValue = concatField.split("=")(1)
        if (fieldName == field) {
          result = fieldValue
        }
      }
    }
    result
  }

  /**
   * Set a field's value inside a concatenated string.
   *
   * @param str           the concatenated string
   * @param delimiter     the delimiter
   * @param field         the field name
   * @param newFieldValue the new value for the field
   * @return the rebuilt concatenated string
   */
  def setFieldInConcatString(str: String, delimiter: String, field: String, newFieldValue: String): String = {
    val fields = str.split(delimiter)

    val buffer = new StringBuffer("")
    for (item <- fields) {
      val fieldName = item.split("=")(0)
      if (fieldName == field) {
        val concatField = fieldName + "=" + newFieldValue
        buffer.append(concatField).append("|")
      } else {
        buffer.append(item).append("|")
      }
    }
    buffer.substring(0, buffer.length() - 1)
  }

}
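
A quick illustration of what these helpers return (the sample values are hypothetical, chosen to match the accumulator's format):

    StringUtils.getFieldFromConcatString("user0=3|user1=5", "\\|", "user1")       // "5"
    StringUtils.setFieldInConcatString("user0=3|user1=5", "\\|", "user1", "6")    // "user0=3|user1=6"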

Usage
Using it in Spark:

    val accumulator = new MyAccumulatorV2()
    sc.register(accumulator)

The accumulator has to be registered with the SparkContext; otherwise an exception is thrown when the task that uses it is serialized and sent to the executors.
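
Putting it all together, a sketch of end-to-end use in local mode (the sample input and the accumulator name "userCounter" are illustrative assumptions, not from the original post):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf().setAppName("AccumulatorV2Demo").setMaster("local[2]")
    val sc = new SparkContext(conf)

    val accumulator = new MyAccumulatorV2()
    sc.register(accumulator, "userCounter") // an optional name makes it visible in the web UI

    // add() runs on the executors, once per record
    sc.parallelize(Seq("user0", "user1", "user1", "user3"))
      .foreach(user => accumulator.add(user))

    // value is read on the driver once the action has completed
    println(accumulator.value) // user0=1|user1=2|user2=0|user3=1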
