Spark共享变量之累加器

来看一个简单的例子,需求是:统计单词的个数。

    val data: RDD[String] = sc.makeRDD(Seq("hadoop map reduce", "spark mllib"))
    // 方式1
    val count1: Int = data.flatMap(line => line.split(" ")).map(word => 1).reduce(_ + _)
    println(count1)
    // 方式2
    var acc = 0
    data.flatMap(line => line.split("")).map(word => {acc += 1; word})
    println(acc)

方式1使用spark提供的RDD算子实现需求,而方式2,我们在驱动程序中定义了一个变量acc,在map算子中每次加1来实现单词统计,最终的结果如下,
输出

5
0

可以发现,方式1正确实现了需求,而方式2却不行。这是因为,在驱动器程序中定义的变量,集群中运行的每个Task都会得到这些变量的一份新的副本,在Task中更新这些副本的值不会影响驱动器中的对应变量。
在处理分片时,如果想要实现更新共享变量的功能,就需要用到“累加器”。

系统累加器

Spark内置了三种类型的累加器,分别是

  1. LongAccumulator用来累加整数型;
  2. DoubleAccumulator用来累加浮点型;
  3. CollectionAccumulator用来累加集合元素
val totalNum1: LongAccumulator = sc.longAccumulator("totalNum1")
val totalNum2: DoubleAccumulator = sc.doubleAccumulator("totalNum2")
val allWords: CollectionAccumulator[String] = sc.collectionAccumulator[String]("allWords")
data.foreach(
  line => {
    val words: Array[String] = line.split(" ")
    totalNum1.add(words.length)
    totalNum2.add(words.length)
    words.foreach(allWords.add(_))
  }
)
println(totalNum1.value)
println(totalNum2.value)
println(allWords.value)
5
5.0
[hadoop, map, reduce, spark, mllib]

累加器的add(v)方法将v添加进累加器(LongAccumulator和DoubleAccumulator为对值累加,CollectionAccumulator为将v添加进_list: java.util.List[T]),累加器的value用于获取累加器的值。

自定义累加器

有时候,Spark内置的累加器无法满足需求,可以自定义累加器。

  1. 继承抽象类AccumulatorV2[IN, OUT],重写相关方法;
  2. 创建自定义Accumulator的实例,然后通过SparkContext.register(acc: AccumulatorV2[_, _], name: String)注册累加器。

AccumulatorV2 can accumulate inputs of type IN, and produce output of type OUT. OUT should be a type that can be read atomically (e.g., Int, Long), or thread-safely (e.g., synchronized collections) because it will be read from other threads.
自定义一个实现WordCount功能的累加器。

class MyAccumulator extends AccumulatorV2[String, util.Map[String, Int]] {
  private val _map: util.Map[String, Int] = Collections.synchronizedMap(new util.HashMap[String, Int]())

  override def isZero: Boolean = _map.isEmpty

  override def copyAndReset(): MyAccumulator = new MyAccumulator

  override def copy(): MyAccumulator = {
    val newAcc = new MyAccumulator
    _map.synchronized {
      newAcc._map.putAll(_map)
    }
    newAcc
  }

  override def reset(): Unit = _map.clear()

  override def add(v: String): Unit = {
    val i = _map.getOrDefault(v, 0)
    _map.put(v, i+1)
  }

  override def merge(other: AccumulatorV2[String, java.util.Map[String, Int]]): Unit = other match {
    case o: MyAccumulator => {
      val iter: util.Iterator[Map.Entry[String, Int]] = other.value.entrySet().iterator()
      while (iter.hasNext) {
        val entry: Map.Entry[String, Int] = iter.next()
        _map.put(entry.getKey, entry.getValue+_map.getOrDefault(entry.getKey, 0))
      }
    }
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def value: java.util.Map[String, Int] = _map.synchronized {
    java.util.Collections.unmodifiableMap(new util.HashMap[String, Int](_map))
  }
}

 val myAcc: MyAccumulator = new MyAccumulator
    sc.register(myAcc, "myAcc")
    data.foreach(
      line => {
        val words = line.split(" ")
        words.foreach(myAcc.add(_))
      }
    )
 println(myAcc.value)
{reduce=1, hadoop=2, mllib=1, spark=1, map=1}

注意事项

  1. 工作节点上的任务不能访问累加器的值。从这些任务的角度来看,累加器是一个只写变量;
  2. 累加器的最终结果应该不受累加顺序的影响(CollectionAccumulator可以将结果集看做是一个可以有重复元素的无序Set);
  3. 如果累加器在spark的transform算子中调用add,可能会导致重复更新,最好将累加器的add操作放在 foreach() 这样的action算子中。

你可能感兴趣的:(spark)