Custom Accumulator in Spark 2.1
Accumulators can sum or count numbers across Spark tasks on all nodes and then return the final result to the driver; the built-in LongAccumulator is the typical example.
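For comparison, this is roughly how the built-in LongAccumulator is used (a minimal sketch; sc is assumed to be an existing SparkContext):

// Built-in LongAccumulator, for comparison.
val errorCount = sc.longAccumulator("errorCount")
sc.parallelize(1 to 100).foreach { i =>
  if (i % 10 == 0) errorCount.add(1)   // incremented on the executors
}
println(errorCount.value)              // 10, read back on the driver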
But when you want to accumulate values in a custom way, you can do so by extending AccumulatorV2 (used here with Spark 2.1).
Below, a MapLongAccumulator is implemented to count the occurrences of several distinct values separately.
import java.util
import java.util.Collections

import org.apache.spark.util.AccumulatorV2

import scala.collection.JavaConversions._

class MapLongAccumulator[T] extends AccumulatorV2[T, java.util.Map[T, java.lang.Long]] {

  // OUT is read from other threads (see the AccumulatorV2 contract below),
  // so the backing map is a synchronized map.
  private val _map: java.util.Map[T, java.lang.Long] =
    Collections.synchronizedMap(new util.HashMap[T, java.lang.Long]())

  override def isZero: Boolean = _map.isEmpty

  override def copyAndReset(): MapLongAccumulator[T] = new MapLongAccumulator

  override def copy(): MapLongAccumulator[T] = {
    val newAcc = new MapLongAccumulator[T]
    _map.synchronized {
      newAcc._map.putAll(_map)
    }
    newAcc
  }

  override def reset(): Unit = _map.clear()

  // Increment the counter for value v by one.
  override def add(v: T): Unit = _map.synchronized {
    val old = _map.put(v, 1L)
    if (old != null) {
      _map.put(v, 1 + old)
    }
  }

  // Add the counts of another MapLongAccumulator into this one (called on the driver).
  override def merge(other: AccumulatorV2[T, java.util.Map[T, java.lang.Long]]): Unit = other match {
    case o: MapLongAccumulator[T] =>
      for ((k, v) <- o.value) {
        val old = _map.put(k, v)
        if (old != null) {
          _map.put(k, old.longValue() + v)
        }
      }
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  // Return a defensive, unmodifiable snapshot of the current counts.
  override def value: java.util.Map[T, java.lang.Long] = _map.synchronized {
    java.util.Collections.unmodifiableMap(new util.HashMap[T, java.lang.Long](_map))
  }

  def setValue(newValue: java.util.Map[T, java.lang.Long]): Unit = {
    _map.clear()
    _map.putAll(newValue)
  }
}
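Before wiring it into a job, the accumulator's contract can be exercised locally without Spark (a minimal sketch; the printed HashMap order is not guaranteed):

val a = new MapLongAccumulator[String]()
a.add("x"); a.add("x"); a.add("y")

val b = a.copyAndReset()   // an empty copy of the same type
b.add("x")

a.merge(b)                 // driver-side merge of task-local values
println(a.value)           // e.g. {x=3, y=1} (map order not guaranteed)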
Use case:

val accumulator = new MapLongAccumulator[String]()
sc.register(accumulator, "CustomAccumulator")

someRdd.map(a => {
  ...
  val convertResult = convert(a)
  accumulator.add(convertResult.toString)
  ...
}).repartition(1).saveAsTextFile(outputPath)

System.out.println(accumulator.value) // each distinct convertResult is counted separately; the totals are read here on the driver.
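For a runnable end-to-end check, here is a minimal local-mode sketch; the object name, the toUpperCase convert, and the sample data are illustrative, not from the original job:

import org.apache.spark.{SparkConf, SparkContext}

object MapLongAccumulatorDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("acc-demo").setMaster("local[2]"))

    val accumulator = new MapLongAccumulator[String]()
    sc.register(accumulator, "CustomAccumulator")

    // "convert" stands in for whatever per-record transformation the real job does.
    def convert(s: String): String = s.toUpperCase

    sc.parallelize(Seq("a", "b", "a", "c", "b", "a"), numSlices = 3)
      .map { a =>
        val convertResult = convert(a)
        accumulator.add(convertResult)
        convertResult
      }
      .count()                   // any action forces the map stage to run

    println(accumulator.value)   // e.g. {A=3, B=2, C=1} (order not guaranteed)
    sc.stop()
  }
}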
AccumulatorV2 is the new accumulator base class since Spark 2.0.0. From the API docs:
public abstract class AccumulatorV2
extends Object
implements scala.Serializable
The base class for accumulators, that can accumulate inputs of type IN, and produce output of type OUT.
OUT should be a type that can be read atomically (e.g., Int, Long), or thread-safely (e.g., synchronized collections) because it will be read from other threads.
Methods of AccumulatorV2:
The merge method does not need to be thread-safe: it is called from only one thread, on task completion (handleTaskCompletion in DAGScheduler).
org.apache.spark.scheduler.DAGScheduler.scala
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
*/
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
  val task = event.task
  val taskId = event.taskInfo.id
  val stageId = task.stageId
  val taskType = Utils.getFormattedClassName(task)

  outputCommitCoordinator.taskCompleted(
    stageId,
    task.partitionId,
    event.taskInfo.attemptNumber, // this is a task attempt number
    event.reason)

  // Reconstruct task metrics. Note: this may be null if the task has failed.
  val taskMetrics: TaskMetrics =
    if (event.accumUpdates.nonEmpty) {
      try {
        TaskMetrics.fromAccumulators(event.accumUpdates)
      } catch {
        case NonFatal(e) =>
          logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
          null
      }
    } else {
      null
    }

  ......

  val stage = stageIdToStage(task.stageId)
  event.reason match {
    case Success =>
      stage.pendingPartitions -= task.partitionId
      task match {
        case rt: ResultTask[_, _] =>
          // Cast to ResultStage here because it's part of the ResultTask
          // TODO Refactor this out to a function that accepts a ResultStage
          val resultStage = stage.asInstanceOf[ResultStage]
          resultStage.activeJob match {
            case Some(job) =>
              if (!job.finished(rt.outputId)) {
                updateAccumulators(event)
                ......

        case smt: ShuffleMapTask =>
          val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
          updateAccumulators(event)
          val status = event.result.asInstanceOf[MapStatus]
          val execId = status.location.executorId
          logDebug("ShuffleMapTask finished on " + execId)
          if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
            logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
          } else {
            shuffleStage.addOutputLoc(smt.partitionId, status)
          }
          ......

    case exceptionFailure: ExceptionFailure =>
      // Tasks failed with exceptions might still have accumulator updates.
      updateAccumulators(event)
      ......
}
/**
* Merge local values from a task into the corresponding accumulators previously registered
* here on the driver.
*
* Although accumulators themselves are not thread-safe, this method is called only from one
* thread, the one that runs the scheduling loop. This means we only handle one task
* completion event at a time so we don't need to worry about locking the accumulators.
* This still doesn't stop the caller from updating the accumulator outside the scheduler,
* but that's not our problem since there's nothing we can do about that.
*/
private def updateAccumulators(event: CompletionEvent): Unit = {
  val task = event.task
  val stage = stageIdToStage(task.stageId)
  try {
    event.accumUpdates.foreach { updates =>
      val id = updates.id
      // Find the corresponding accumulator on the driver and update it
      val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
        case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
        case None =>
          throw new SparkException(s"attempted to access non-existent accumulator $id")
      }
      acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])
      // To avoid UI cruft, ignore cases where value wasn't updated
      if (acc.name.isDefined && !updates.isZero) {
        stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
        event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
      }
    }
  } catch {
    case NonFatal(e) =>
      logError(s"Failed to update accumulators for task ${task.partitionId}", e)
  }
}
org.apache.spark.executor.TaskMetrics.scala
/**
* Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
*/
def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
  val tm = new TaskMetrics
  val (internalAccums, externalAccums) =
    accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))

  internalAccums.foreach { acc =>
    val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
    tmAcc.metadata = acc.metadata
    tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
  }

  tm.externalAccums ++= externalAccums
  tm
}
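As a closing note, the per-stage values that updateAccumulators writes into stage.latestInfo.accumulables also surface to driver-side listeners once the stage completes. A hedged sketch of reading them (the listener class name and the name filter are illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}

class AccumulatorReportListener extends SparkListener {
  override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
    // stageInfo.accumulables is the map filled in by updateAccumulators above.
    stageCompleted.stageInfo.accumulables.values
      .filter(_.name.contains("CustomAccumulator"))
      .foreach(info => println(s"stage ${stageCompleted.stageInfo.stageId}: ${info.value.orNull}"))
  }
}

// sc.addSparkListener(new AccumulatorReportListener())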