1.迭代器模式
在计算时,为了节省内存,不把所有的数据一次全部加载到内存中,有一种设计模式叫迭代器模式。
迭代器模式:在逻辑代码执行时,真正的逻辑并未执行,而是创建了新的迭代器,新的迭代器保存着对当前迭代器的引用从而形成链表,每个迭代器需要实现hasNext(),next()两个方法。当触发计算时,最后一个创建的迭代器会调用next方法,next方法会调用父迭代器的next方法。
例如:
val list = List("a a", "b d", "c e")
val it = list.iterator
it.flatMap(_.split(" ")).map((_, 1)).filter(_._1 != "").foreach(println)
这个例子中it是初始迭代器,后面每个方法都会生成一个新的迭代器,但并不进行迭代计算,到最后foreach方法(类似action算子),开始执行迭代计算了。
我们依次展开:
def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
// f作用在上游单条数据的结果转换成的iterator
private var cur: Iterator[B] = empty
private def nextCur() { cur = f(self.next()).toIterator }
def hasNext: Boolean = {
while (!cur.hasNext) {
if (!self.hasNext) return false
nextCur()
}
true
}
def next(): B = (if (hasNext) cur else empty).next()
}
flatMap方法是创建了一个AbstractIterator的匿名内部类,并实现了hasNext和next两个方法。每当调用next时,会先调用hasNext,在hasNext中,调上游的iterator的next方法获取上游这条数据的返回结果,再对这条结果执行用户传入的函数f并返回结果后,将其转换为iterator,再返回这个iterator的next的结果。
def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] {
def hasNext = self.hasNext
def next() = f(self.next())
}
map与flatMap的代码模板一样,逻辑更简单,只是对上游的next返回结果执行用户传入的函数,再返回。
def filter(p: A => Boolean): Iterator[A] = new AbstractIterator[A] {
private var hd: A = _
private var hdDefined: Boolean = false
def hasNext: Boolean = hdDefined || {
do {
if (!self.hasNext) return false
hd = self.next()
} while (!p(hd))
hdDefined = true
true
}
def next() = if (hasNext) { hdDefined = false; hd } else empty.next()
}
filter中,调用hasNext时,先调用上游iterator的hasNext,如果返回false,那么直接返回false。如果上游的hasNext返回true,就取出上游的next结果,并将用户传入的判断函数p作用在这个结果上,若为true,则退出循环,并将hdDefine置为true;若p的结果为false,则继续从上游取下一条数据让p判断。
def foreach[U](f: A => U) { while (hasNext) f(next()) }
遍历迭代器,将每个元素传给用户传入的函数f中执行。
2.RDD串联
在spark的每个任务中,都是以迭代器模式进行计算的。而每个迭代器的链表对应每个分区中的数据。RDD的每个算子会生成一个新的RDD,新的RDD会保存对前一个RDD的引用,并且会保存传入到算子中的用户定义函数。
例如:
def map[U: ClassTag](f: T => U): RDD[U] = withScope {
val cleanF = sc.clean(f)
new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}
这个map算子会返回一个MapPartitionsRDD,MapPartitionsRDD中含有当前this这个RDD的引用,并把用户定义函数f转换成作用于iterator的函数传入到MapPartitionsRDD中。
RDD中有个抽象方法compute,MapPartitionsRDD中实现如下:
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
从父RDD(firstParent[T])获取迭代器,这个过程需要分区信息split和任务上下文。再map算子中转换后的用户定义函数作用在这个迭代器上。
compute方法同迭代器模式类似,也是不断从上游RDD获取的迭代器,这样来获得一个迭代器的链表,这个链表就是一个task要执行的任务。
为了说明这个过程,我们从Executor源码来找寻。
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
runningTasks.put(taskDescription.taskId, tr)
threadPool.execute(tr)
}
Executor源码中有个launchTask方法,会创建TaskRunner,将TaskRunner交给线程池执行。TaskRunner是什么呢?
在Executor源码中有一个内部类,TaskRunner,它是一个线程的任务:
class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
extends Runnable {
继承Runnable必须实现run方法,找到run方法,在run方法中找到了如下代码:
val res = task.run(
taskAttemptId = taskId,
attemptNumber = taskDescription.attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
点进这里task的run,会在Task类中找到runTask(context),这个runTask是Task类的抽象方法,会被Task的子类实现。比如ResultTask,这个子类是最后collect类型的action算子出发的任务类。在ResultTask中,runTask方法调用了rdd的iterator方法来获取iterator,并将用户定义的方法作用到这个iterator上。
override def runTask(context: TaskContext): U = {
// Deserialize the RDD and the func using the broadcast variables.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
func(context, rdd.iterator(partition, context))
}
这个rdd的iterator方法会获取父rdd的迭代器或调用compute方法。
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
getOrCompute(split, context)
} else {
computeOrReadCheckpoint(split, context)
}
}
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
if (isCheckpointedAndMaterialized) {
firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
}
小结
spark每个任务都是由向前依赖串联起来RDD链表生成的iterator链表构成的,任务执行由最后的一个iterator的迭代开始,调用上游的迭代器的next,直到迭代到第一个iterator。这样避免了将所有数据先加载到内存中,而每次计算都只从源头取一条数据,大大节省了内存。