Spark Catalyst中的TreeNode源码解析

简介

源码基于Spark-2.0.2版本

TreeNode累是Catalyst中和执行计划相关的所有AST,包括表达式Expression、逻辑执行计划LogicalPlan,物理执行计划SparkPlan的基类。

TreeNode继承Scala中的Product类,其目的是可通过Product类中的方法(productArity、productElement、productIterator)来操纵TreeNode实现类的参数,这些实现类一般都是case class。

TreeNode是抽象类,用于被其他抽象或实现类继承,如Expression、LogicalPlan。TreeNode类声明如下:

abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product

并且定义了this的别名:

self: BaseType =>

这里约定TreeNode的泛型参数类型BaseType应是TreeNode[BasType]的子类型,并且继承TreeNode的实现类的类型就是传入的BaseType。比如Aggregate是LogicalPlan的实现类,LogicalPlan继承QueryPlan[LogicalPlan],又有QueryPlan[PlanType <: QueryPlan[PlanType]]继承TreeNode[PlanType],因此Aggregate的self类型就是Aggregate。

TreeNode中的成员和方法

TreeNode中定义的方法都是对AST的基本操作方法。

children方法

def children: Seq[BaseType]

children方法需要在子类中实现。返回当前节点所有孩子节点的Seq。

containsChild成员

lazy val containsChild: Set[TreeNode[_]] = children.toSet

孩子节点的集合。因为Set(value)返回Boolean值,表示是否在集合内,因此containsChild这样的集合命名更符合自然语言的语法特点。

fastEquals方法

def fastEquals(other: TreeNode[_]): Boolean = {
  this.eq(other) || this == other
}

比调用equals更快比较连个TreeNode是否是同一个实例的短路版本。这里没有覆盖Object的equals方法,这样不会阻止scala编译器为case class生成equals方法,而TreeNode的实现类都是case class。

这里采用短路的做法是:先比较两个TreeNode的引用是否相等,相等就直接返回,否则才会去比较值(如果other是null的花也是用eq比较引用,具体可以看下==方法的定义,也就是当other为null时用eq做引用比较,不为null则用equals做值比较)。

find方法

def find(f: BaseType => Boolean) = f(this) match {
  case true => Some(this)
  case false => children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
}

用于找到满足函数对象f指定条件的第一个TreeNode。

指定的函数对象地柜应用于当前节点以及它的所有子节点,采用前序遍历的方式,也就是先检查当前节点,再递归检查子节点。

下面我们分析下代码。如果f(this)返回true,则返回当前节点的实例也就是Some(this);如果f(this)返回false,也就是当前节点不满足条件,这时就要递归检查其所有孩子节点。这里用到了foldLeft。

children在上面定义为Seq[BaseType],基于children这个Seq调用foldLeft也就是基于这个Seq从第一个元素向最后一个元素迭代执行(比foldRight效率高,因为找到Seq的最后一个元素的时间是O(n))。这里要传入两个参数给foldLeft函数,一个是初始值,这里是Option.empty[BaseType],也就是一个Option[BaseType]类型的None值(再如果children为空Seq就直接返回None了);另一个参数是以迭代过程中当前基于前次迭代结果和当前元素为参数的匿名函数。这个匿名函数的第一个参数也就是l就是基于前次迭代的计算结果,第二个参数r是当前元素的值。l.orElse(r.find(f))就很好理解了,也就是如果前次的计算结果不为None的话就取这个结果返回(不会直接返回,而是继续递归遍历到整个树的最后一个孩子节点,但是由于返回了非None因此每次对r.orElse的调用都会返回r),否则在当前迭代的元素节点上递归调用find方法并传入f函数。

foreach方法

def foreach(f: BaseType => Unit): Unit = {
  f(this)
  children.foreach(_.foreach(f))
}

在当前节点以及它的所有孩子节点上递归运行一个给定的函数f。

这是对以当前节点为根的AST的前序遍历。

forearchUp方法

def foreachUp(f: BaseType => Unit): Unit = {
  children.foreach(_.foreachUp(f))
  f(this)
}

在当前节点的所有孩子节点上地轨运行一个给定的函数f,之后再在当前节点上运行函数f。

这是对以当前节点为根节点的AST的后序遍历。

map方法

def map[A](f: BaseType => A): Seq[A] = {
  val ret = new collection.mutable.ArrayBuffer[A]()
  foreach(ret += f(_))
  ret
}

先序遍历以当前节点为根的子树,在这个过程中,对每个节点应用给定的函数f并将返回值放入一个Seq,最后返回这个A类型的Seq。

flatMap方法

def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = {
  val ret = new collection.mutable.ArrayBuffer[A]()
  foreach(ret ++= f(_))
  ret
}

先序遍历以当前节点为根的子树,在这个过程中,对每个节点应用给定的函数f并将返回的A类型的列表中的值放入一个Seq钟,最后返回这个A类型的Seq。

和map方法不同的是,flatMap的参数是一个返回值为TraverseOnce[A]序列类型的变量。

collect方法

def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = {
  val ret = new collection.mutable.ArrayBuffer[B]()
  val lifted = pf.lift
  foreach(node => lifted(node).foreach(ret.+=))
  ret
}

收集AST中满足偏函数pf的节点。

前序遍历以当前节点为根的子树,在这个过程中,对每个节点应用给定的偏函数pf,返回所有可以应用偏函数的节点序列。

在collect方法实现中,调用偏函数的lift方法将偏函数转换成普通函数,转换之后的方法会加入case _ => None的模式匹配以防止遇到未定义的情况抛出异常。对None值做foreach就不会把未成功应用偏函数的节点加入结果集。

collectLeaves方法

def collectLeaves(): Seq[BaseType] = {
  this.collect { case p if p.children.isEmpty => p }
}

返回一个包含以当前节点为根的AST中的所有叶子节点的序列。

collectFirst方法

def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
  val lifted = pf.lift
  lifted(this).orElse {
    children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
  }
}

前序遍历以当前节点为根的AST,在这个过程中找到第一个应用给定偏函数pf的结果有定义的节点,并返回此节点调用偏函数之后的结果。

mapProductIterator方法

protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = {
  val arr = Array.ofDim[B](productArity)
  val i = 0
  while (i < arr.length) {
    arr[i] = f(productElement(i))
    i += 1
  }
  arr
}

对TreeNode实现类的每个参数应用函数f(因为不知道每个类参数的具体类型所有函数f的输入类型是Any)并返回应用后的类型为B的数组。

这个是productIterator.map(f).toArray的高效版本,对实例参数列表只需要一次遍历。

mapChildren方法

def mapChildren(f: BaseType => BaseType): BaseType = {
  var changed = false
  val newArgs = mapProductIterator {
    case arg: TreeNode[_] if containsChild(arg) =>
      val newChild = f(arg.asInstanceOf[BaseType])
      if (newChild fastEques arg) {
        arg
      } else {
        changed = true
        newChild
      }
    case nonChild: AnyRef => nonChild
    case null => null
  }
  if (changed) makeCopy(newArgs) else this
}

这个方法是对以当前节点为根的子树的所有子节点做map操作,在每个子节点上执行函数f,最后生成一棵新的子树(或者保持原子树不变,前提是应用函数f之后的节点还是原先的节点)。

这个方法是对当前TreeNode实现类的参数上应用函数f。这里调用的mapProductIterator方法的参数类型是Any => Any。TreeNode实现类的参数可能类型是TreeNode,也可能不是,比如Add这个Expression的实现类的声明是case class Add(left: Expression, right: Expression),它的参数都是Expression类型的;对于Aggregate这个LogicalPlan的实现类的声明是case class Aggregate(groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], child: LogicalPlan),它只有第3个参数是LogicalPlan类型的。

withNewChildren方法

def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
  assert(newChildren.size == children.size, "Incorrect number of children")
  var changed = false
  val remainingNewChildren = newChildren.toBuffer
  val remainingOldChildren = children.toBuffer
  val newArgs = mapProductIterator {
    case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
    // Handle Seq[TreeNode] in TreeNode parameters.
    case s: Seq[_] => s.map {
      case arg: TreeNode[_] if containsChild(arg) =>
        val newChild = remainingNewChildren.remove(0)
        val oldChild = remainingOldChildren.remove(0)
        if (newChild fastEquals oldChild) {
          oldChild
        } else {
          changed = true
          newChild
        }
      case nonChild: AnyRef => nonChild
      case null => null  
    }
    case m: Map[_, _] => m.mapValues {
      case arg: TreeNode[_] if containsChild(arg) =>
        val newChild = remainingNewChildren.remove(0)
        val oldChild = remainingOldChildren.remove(0)
        if (newChild fastEquals oldChild) {
          oldChild
        } else {
          changed = true
          newChild
        }
      case nonChild: AnyRef => nonChild
      case null => null
    }.view.force // `mapValues` is lazy and we need to force it to materialize
    case arg: TreeNode[_] if containsChild(arg) =>
      val newChild = remainingNewChildren.remove(0)
      val oldChild = remainingOldChildren.remove(0)
      if (newChild fastEquals oldChild) {
        oldChild
      } else {
        changed = true
        newChild
      }
    case nonChild: AnyRef => nonChild
    case null => null  
  }

  if (changed) makeCopy(newArgs) else this
}

返回当前节点的副本,匹配的子节点用传入的newChildren的成员替换。

对于TreeNode的实现类参数,可能有如下几种类型:

  • StructType
  • Seq[TreeNode]
  • Seq[AnyRef]
  • Map[_, TreeNode]
  • Map[_, AnyRef]
  • TreeNode
  • AnyRef

有几个TreeNOde的实现类可以覆盖上面几种情况:

case class Add(left: Expression, right: Expression)
case class Coalesce(children: Seq[Expression])
case class AtLeasetNNonNulls(n: Int, children: Seq[Expression])
...

transformDown方法

def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
  val afterRule = CurrentOrigin.withOrigin(origin) {
    rule.applyOrElse(this, identity[BaseType])
  }

  // Check if unchanged and then possibly return old copy to avoid gc churn
  if (this fastEquals afterRule) {
    transformChildren(rule, (t, r) => t.transformDown(r))
  } else {
    afterRule.transformChildren(rule, (t, r) => t.transformDown(r))
  }
}

前序比那里以当前节点为根的子树,在遍历过程中,对当前节点以及它的子节点递归应用传入的规则rule,当规则rule不能应用于给定节点时,保持节点不变。

这里的规则使用scala的模式匹配,可以认为一个规则是一个或一组模式匹配语句。

下面来看一下方法定义。首先对当前节点应用rule:

rule.applyOrElse(this, identity[BaseType])

这里用到了PartialFunction的applyOrElse方法,用来避免undefined的情况发生。如果当前节点应用rule没有匹配的话,则返回默认的当前节点本身。对于PartialFunction[A, B]类型的偏函数来说,applyOrElse方法的签名应该是:

def applyOrElse(x: A, default: A => B)

如果当前节点没有应用到规则的话。则对当前节点调用transformChildren方法,在这个方法中,对当前节点的每个孩子节点递归调用transformDown方法;如果当前节点应用了规则,则对新节点调用transformChildren方法。

transformUp方法

def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
  val afterRuleOnChildren = transformChildren(rule, (t, r) => t.transformUp(r))
  if (this fastEquals afterRuleOnChildren) {
    CurrentOrigin,withOrigin(origin) {
      rule.applyOrElse(this, Identity[BaseType])
    }
  } else {
    CurrentOrigin.withOrigin(origin) {
      rule.applyOrElse(afterRuleOnChildren, Identity[BaseType])
    }
  }
}

后续遍历以当前节点为根的子树,在便利过程中,先对当前节点调用transformChildren,对当前节点的每个孩子节点递归调用transformUp。最后对当前节点(规则对每个孩子节点都没有成功应用)或者新节点(规则对一个或多个孩子节点应用)应用rule。

transform方法

def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
  transformDown(rule)
}

传入的规则rule递归应用到以当前节点为根的子树,然后返回应用之后的当前节点或其拷贝。

这里用户不应该期望有特定的规则应用顺序(前序or后续)。如果需要一个特定的匹配顺序,则应该使用transformDown或者transformUp。

transformChildren方法

protected def transformChildren(rule: PartialFunction[BaseType, BaseType], nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
  var changed = false
  val newArgs = mapProductIterator {
    case arg: TreeNode[_] if containsChild(arg) =>
      val newArg = nextOperation(arg.asInstanceOf[BaseType], rule)
      if (newArg fastEquals arg) {
        arg
      } else {
        changed = true
        newArg
      }
    case Some(arg: TreeNode[_]) if containsChild(arg) =>
      val newArg = nextOperation(arg.asInstanceOf[BaseType], rule)
      if (newArg fastEquals arg) {
        Some(arg)
      } else {
        changed = true
        Some(newArg)
      }
    case m: Map[_. _] => m.mapValues {
      case arg: TreeNode[_] if containsChild(arg) =>
        val newArg = nextOperation(arg.asInstanceOf[BaseType], rule)
        if (newArg fastEquals arg) {
          arg
        } else {
          changed = true
          newArg
        }
      case other => other  
    }.view.force // `mapValues` is lazy and we need to force it to materialize
    case d: DataType => d // Avoid unpacking structs
    case args: Traversable[_] => args.map {
      case arg: TreeNode[_] if containsChild(arg) =>
        val newArg = nextOperation(arg.asInstanceOf[BaseType], rule)
        if (newArg fastEquals arg) {
          arg
        } else {
          changed = true
          newArg
        }
      case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
        val newArg1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
        val newArg2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
        if (!(newArg1 fastEquals arg1) || !(newArg2 fastEquals arg2)) {
          changed = true
          (newArg1, newArg2)
        } else {
          tuple
        }
      case other => other  
    }
    case nonChild: AnyRef => nonChild
    case null => null
  }
  if (changed) makeCopy(newArgs) else this  
}

这个类是配合transformtransformUptransformDown方法使用的。在这些方法里,会调用transformChildren方法,传入rule以及一个匿名函数对象(t, r) => t.transformDown(r),这个函数的signiture为(BaseType, PartialFunction[BaseType, BaseType]) => BaseType。在transformChildren方法中,对于类型为TreeNode[_]的孩子节点,会对其调用传入的匿名函数。

你可能感兴趣的:(spark)