Spark源码拜读之RDD的迭代器串联

1.迭代器模式

在计算时,为了节省内存,不把所有的数据一次全部加载到内存中,有一种设计模式叫迭代器模式。

迭代器模式:在逻辑代码执行时,真正的逻辑并未执行,而是创建了新的迭代器,新的迭代器保存着对当前迭代器的引用从而形成链表,每个迭代器需要实现hasNext(),next()两个方法。当触发计算时,最后一个创建的迭代器会调用next方法,next方法会调用父迭代器的next方法。

例如:

val list = List("a a", "b d", "c e")
val it = list.iterator
it.flatMap(_.split(" ")).map((_, 1)).filter(_._1 != "").foreach(println)

这个例子中it是初始迭代器,后面每个方法都会生成一个新的迭代器,但并不进行迭代计算,到最后foreach方法(类似action算子),开始执行迭代计算了。

我们依次展开:

def flatMap[B](f: A => GenTraversableOnce[B]): Iterator[B] = new AbstractIterator[B] {
  // f作用在上游单条数据的结果转换成的iterator
  private var cur: Iterator[B] = empty
  private def nextCur() { cur = f(self.next()).toIterator }
  def hasNext: Boolean = {
    while (!cur.hasNext) {
      if (!self.hasNext) return false
      nextCur()
    }
    true
  }
  def next(): B = (if (hasNext) cur else empty).next()
}

flatMap方法是创建了一个AbstractIterator的匿名内部类,并实现了hasNext和next两个方法。每当调用next时,会先调用hasNext,在hasNext中,调上游的iterator的next方法获取上游这条数据的返回结果,再对这条结果执行用户传入的函数f并返回结果后,将其转换为iterator,再返回这个iterator的next的结果。

def map[B](f: A => B): Iterator[B] = new AbstractIterator[B] {
  def hasNext = self.hasNext
  def next() = f(self.next())
}

map与flatMap的代码模板一样,逻辑更简单,只是对上游的next返回结果执行用户传入的函数,再返回。

def filter(p: A => Boolean): Iterator[A] = new AbstractIterator[A] {
  private var hd: A = _
  private var hdDefined: Boolean = false

  def hasNext: Boolean = hdDefined || {
    do {
      if (!self.hasNext) return false
      hd = self.next()
    } while (!p(hd))
    hdDefined = true
    true
  }

  def next() = if (hasNext) { hdDefined = false; hd } else empty.next()
}

filter中,调用hasNext时,先调用上游iterator的hasNext,如果返回false,那么直接返回false。如果上游的hasNext返回true,就取出上游的next结果,并将用户传入的判断函数p作用在这个结果上,若为true,则退出循环,并将hdDefine置为true;若p的结果为false,则继续从上游取下一条数据让p判断。

def foreach[U](f: A => U) { while (hasNext) f(next()) }

遍历迭代器,将每个元素传给用户传入的函数f中执行。

2.RDD串联

在spark的每个任务中,都是以迭代器模式进行计算的。而每个迭代器的链表对应每个分区中的数据。RDD的每个算子会生成一个新的RDD,新的RDD会保存对前一个RDD的引用,并且会保存传入到算子中的用户定义函数。

例如:

def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}

这个map算子会返回一个MapPartitionsRDD,MapPartitionsRDD中含有当前this这个RDD的引用,并把用户定义函数f转换成作用于iterator的函数传入到MapPartitionsRDD中。

RDD中有个抽象方法compute,MapPartitionsRDD中实现如下:

override def compute(split: Partition, context: TaskContext): Iterator[U] =
  f(context, split.index, firstParent[T].iterator(split, context))

从父RDD(firstParent[T])获取迭代器,这个过程需要分区信息split和任务上下文。再map算子中转换后的用户定义函数作用在这个迭代器上。

compute方法同迭代器模式类似,也是不断从上游RDD获取的迭代器,这样来获得一个迭代器的链表,这个链表就是一个task要执行的任务。

为了说明这个过程,我们从Executor源码来找寻。

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  val tr = new TaskRunner(context, taskDescription)
  runningTasks.put(taskDescription.taskId, tr)
  threadPool.execute(tr)
}

Executor源码中有个launchTask方法,会创建TaskRunner,将TaskRunner交给线程池执行。TaskRunner是什么呢?

在Executor源码中有一个内部类,TaskRunner,它是一个线程的任务:

class TaskRunner(
    execBackend: ExecutorBackend,
    private val taskDescription: TaskDescription)
  extends Runnable {

继承Runnable必须实现run方法,找到run方法,在run方法中找到了如下代码:

val res = task.run(
  taskAttemptId = taskId,
  attemptNumber = taskDescription.attemptNumber,
  metricsSystem = env.metricsSystem)
threwException = false

点进这里task的run,会在Task类中找到runTask(context),这个runTask是Task类的抽象方法,会被Task的子类实现。比如ResultTask,这个子类是最后collect类型的action算子出发的任务类。在ResultTask中,runTask方法调用了rdd的iterator方法来获取iterator,并将用户定义的方法作用到这个iterator上。

override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTime = System.currentTimeMillis()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
        ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
        threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    func(context, rdd.iterator(partition, context))
}

这个rdd的iterator方法会获取父rdd的迭代器或调用compute方法。

final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
    if (isCheckpointedAndMaterialized) {
        firstParent[T].iterator(split, context)
    } else {
        compute(split, context)
    }
}

小结

spark每个任务都是由向前依赖串联起来RDD链表生成的iterator链表构成的,任务执行由最后的一个iterator的迭代开始,调用上游的迭代器的next,直到迭代到第一个iterator。这样避免了将所有数据先加载到内存中,而每次计算都只从源头取一条数据,大大节省了内存。

你可能感兴趣的:(scala,spark)