object TaskContext
object TaskContext {
/**
* Return the currently active TaskContext. This can be called inside of
* user functions to access contextual information about running tasks.
*/
def get(): TaskContext = taskContext.get
/**
* Returns the partition id of the currently active TaskContext. It will return 0
* if there is no active TaskContext, e.g. during local execution.
*/
def getPartitionId(): Int = {
val tc = taskContext.get()
if (tc eq null) {
0
} else {
tc.partitionId()
}
}
private[this] val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
/**
* Set the thread local TaskContext. Internal to Spark.
*/
protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc)
/**
* Unset the thread local TaskContext. Internal to Spark.
*/
protected[spark] def unset(): Unit = taskContext.remove()
/**
* An empty task context that does not represent an actual task. This is only used in tests.
*/
private[spark] def empty(): TaskContextImpl = {
new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
}
}
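Because the active context is stored in a thread local, TaskContext.get() is only meaningful on a thread that is actually running a task, i.e. inside a closure shipped to executors. A minimal usage sketch, assuming an existing SparkContext named sc:

import org.apache.spark.TaskContext

// Each partition reports its own partition id and stage id via the thread-local context.
val perPartition = sc.parallelize(1 to 100, numSlices = 4).mapPartitions { iter =>
  val ctx = TaskContext.get()              // non-null here because a task is running
  val pid = TaskContext.getPartitionId()   // equivalent to ctx.partitionId() inside a task
  Iterator((pid, ctx.stageId(), iter.sum))
}
perPartition.collect().foreach(println)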
/**
* Contextual information about a task which can be read or mutated during
* execution. To access the TaskContext for a running task, use:
* {{{
* org.apache.spark.TaskContext.get()
* }}}
*/
abstract class TaskContext extends Serializable {
/**
* Returns true if the task has completed.
*/
def isCompleted(): Boolean
/**
* Returns true if the task has been killed.
*/
def isInterrupted(): Boolean
/**
* Returns true if the task is running locally in the driver program.
* @return false
*/
@deprecated("Local execution was removed, so this always returns false", "2.0.0")
def isRunningLocally(): Boolean
/**
* Adds a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
*/
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
/**
* Adds a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
*/
def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = {
addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = f(context)
})
}
/**
* Adds a listener to be executed on task failure.
* Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
*/
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
/**
* Adds a listener to be executed on task failure.
* Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
*/
def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
addTaskFailureListener(new TaskFailureListener {
override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
})
}
/**
* The ID of the stage that this task belongs to.
*/
def stageId(): Int
/**
* The ID of the RDD partition that is computed by this task.
*/
def partitionId(): Int
/**
* How many times this task has been attempted. The first task attempt will be assigned
* attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
*/
def attemptNumber(): Int
/**
* An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
* will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
*/
def taskAttemptId(): Long
/**
* Get a local property set upstream in the driver, or null if it is missing. See also
* `org.apache.spark.SparkContext.setLocalProperty`.
*/
def getLocalProperty(key: String): String
@DeveloperApi
def taskMetrics(): TaskMetrics
/**
* ::DeveloperApi::
* Returns all metrics sources with the given name that are associated with the instance
* running the task. For more information, see `org.apache.spark.metrics.MetricsSystem`.
*/
@DeveloperApi
def getMetricsSources(sourceName: String): Seq[Source]
/**
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager
/**
* Register an accumulator that belongs to this task. Accumulators must call this method when
* deserializing in executors.
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
}
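The two listener overloads above are the public hooks for per-task cleanup. A sketch of using them from inside a task to tie a resource's lifetime to the task, assuming a SparkContext named sc (the calls chain because both methods return the TaskContext):

import org.apache.spark.TaskContext

sc.parallelize(1 to 10, numSlices = 2).foreachPartition { iter =>
  val stream = new java.io.ByteArrayInputStream(Array[Byte]())  // stand-in for a real input stream
  TaskContext.get()
    .addTaskCompletionListener { (_: TaskContext) =>
      stream.close()                       // runs on success, failure, or cancellation
    }
    .addTaskFailureListener { (ctx: TaskContext, error: Throwable) =>
      // Must be idempotent: onTaskFailure can be invoked more than once.
      System.err.println(s"Task ${ctx.taskAttemptId()} failed: ${error.getMessage}")
    }
  iter.foreach(_ => ())                    // consume the partition
}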
class TaskContextImpl
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
// The default value is only used in tests.
override val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
/** List of callback functions to execute when the task completes. */
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
@volatile private var interrupted: Boolean = false
@volatile private var completed: Boolean = false
@volatile private var failed: Boolean = false
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
}
override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
onFailureCallbacks += listener
this
}
/** Marks the task as failed and triggers the failure listeners. */
private[spark] def markTaskFailed(error: Throwable): Unit = {
if (failed) return
failed = true
val errorMsgs = new ArrayBuffer[String](2)
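// Process failure callbacks in reverse order of registration (last registered runs first).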
onFailureCallbacks.reverse.foreach { listener =>
try {
listener.onTaskFailure(this, error)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
logError("Error in TaskFailureListener", e)
}
}
if (errorMsgs.nonEmpty) {
throw new TaskCompletionListenerException(errorMsgs, Option(error))
}
}
/** Marks the task as completed and triggers the completion listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
val errorMsgs = new ArrayBuffer[String](2)
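// Process completion callbacks in reverse order of registration (last registered runs first).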
onCompleteCallbacks.reverse.foreach { listener =>
try {
listener.onTaskCompletion(this)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
logError("Error in TaskCompletionListener", e)
}
}
if (errorMsgs.nonEmpty) {
throw new TaskCompletionListenerException(errorMsgs)
}
}
/** Marks the task for interruption, i.e. cancellation. */
private[spark] def markInterrupted(): Unit = {
interrupted = true
}
override def isCompleted(): Boolean = completed
override def isRunningLocally(): Boolean = false
override def isInterrupted(): Boolean = interrupted
override def getLocalProperty(key: String): String = localProperties.getProperty(key)
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
private[spark] override def registerAccumulator(a: AccumulatorV2[_, _]): Unit = {
taskMetrics.registerAccumulator(a)
}
}
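For completeness, here is a simplified sketch of the order in which an executor-side runner might drive these private[spark] hooks around a task body. This is illustrative only (the real logic lives in Spark's Task and executor code), and because the methods are private[spark] it would only compile inside the org.apache.spark package; runWithContext is a hypothetical helper:

package org.apache.spark

private[spark] object TaskLifecycleSketch {
  // Hypothetical helper: runs `body` with `context` installed as the thread-local TaskContext.
  def runWithContext[T](context: TaskContextImpl)(body: => T): T = {
    TaskContext.setTaskContext(context)    // make TaskContext.get() work inside `body`
    try {
      body
    } catch {
      case t: Throwable =>
        context.markTaskFailed(t)          // failure listeners fire first, in reverse order
        throw t
    } finally {
      context.markTaskCompleted()          // completion listeners fire in all cases
      TaskContext.unset()                  // always clear the thread-local
    }
  }
}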