/** Spark SQL source code analysis series */
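The previous articles covered the core execution flow of Catalyst in Spark SQL, the SqlParser, and the Analyzer. I had planned to move straight on to the Optimizer, but realized I had not yet introduced TreeNode, a core concept of Catalyst. Understanding TreeNode makes it much easier to see how the Optimizer turns an Analyzed Logical Plan into an Optimized Logical Plan, so this article explains the basic architecture of TreeNode.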
The main inheritance relationships are shown in the following class diagram:
Binary node: a node with two children, a left child and a right child.
    /**
     * A [[TreeNode]] that has two children, [[left]] and [[right]].
     */
    trait BinaryNode[BaseType <: TreeNode[BaseType]] {
      def left: BaseType
      def right: BaseType

      def children = Seq(left, right)
    }

    abstract class BinaryNode extends LogicalPlan with trees.BinaryNode[LogicalPlan] {
      self: Product =>
    }

The definition is simple: the left child and the right child are both of type BaseType, and children is Seq(left, right).
The main classes that extend the binary node are listed below; you can use the list as a quick reference :)
Unary node: a node with exactly one child.
    /**
     * A [[TreeNode]] with a single [[child]].
     */
    trait UnaryNode[BaseType <: TreeNode[BaseType]] {
      def child: BaseType
      def children = child :: Nil
    }

    abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] {
      self: Product =>
    }

The main classes that extend the unary node are listed below; you can use the list as a quick reference :)
Commonly used unary nodes include Project, Subquery, Filter, Limit, and so on.
Leaf node: a node with no children.
    /**
     * A [[TreeNode]] with no children.
     */
    trait LeafNode[BaseType <: TreeNode[BaseType]] {
      def children = Nil
    }

    abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
      self: Product =>

      // Leaf nodes by definition cannot reference any input attributes.
      override def references = Set.empty
    }

The main classes that extend the leaf node are listed below; you can use the list as a quick reference :)
Tip: commonly used leaf nodes include the Command family of classes, some function nodes, and UnresolvedRelation, among others.
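The three node kinds can be illustrated with a minimal sketch. The TreeNode class here is a stripped-down stand-in that keeps only `children`, and `Plan`, `Relation`, `Filter`, `Join`, and `depth` are hypothetical names invented for this illustration; they are not the Spark classes.

```scala
object NodeKindsSketch {
  // Simplified stand-in for Catalyst's TreeNode: only the children accessor is kept.
  abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
    def children: Seq[BaseType]
  }

  // The three node kinds, mirroring the traits shown above.
  trait BinaryNode[BaseType <: TreeNode[BaseType]] {
    def left: BaseType
    def right: BaseType
    def children: Seq[BaseType] = Seq(left, right)
  }

  trait UnaryNode[BaseType <: TreeNode[BaseType]] {
    def child: BaseType
    def children: Seq[BaseType] = child :: Nil
  }

  trait LeafNode[BaseType <: TreeNode[BaseType]] {
    def children: Seq[BaseType] = Nil
  }

  // Hypothetical plan nodes, one per node kind (assumed for this sketch only).
  abstract class Plan extends TreeNode[Plan]
  case class Relation(name: String) extends Plan with LeafNode[Plan]
  case class Filter(condition: String, child: Plan) extends Plan with UnaryNode[Plan]
  case class Join(left: Plan, right: Plan) extends Plan with BinaryNode[Plan]

  // Height of the tree, computed only through the generic children accessor.
  def depth(plan: Plan): Int =
    if (plan.children.isEmpty) 1 else 1 + plan.children.map(depth).max
}
```

Because every node kind exposes `children`, generic code such as `depth` never needs to know whether it is looking at a leaf, unary, or binary node.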
currentId
Every TreeNode in a tree has a unique id. The ids come from a counter of type java.util.concurrent.atomic.AtomicLong, so they are generated atomically.
    private val currentId = new java.util.concurrent.atomic.AtomicLong
    protected def nextId() = currentId.getAndIncrement()

sameInstance
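    def sameInstance(other: TreeNode[_]): Boolean = {
      this.id == other.id
    }

fastEquals: the more commonly used quick equality check. It short-circuits when the two nodes are the same instance. Object.equals is deliberately not overridden, because overriding it would prevent the Scala compiler from generating the case class equals method.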
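    def fastEquals(other: TreeNode[_]): Boolean = {
      sameInstance(other) || this == other
    }

map, flatMap, and collect all recursively apply a function to the node and its children (collect takes a PartialFunction). There are many other methods; for reasons of space they are not described one by one here.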
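To make the difference between sameInstance and fastEquals concrete, here is a minimal sketch on a toy tree. `Node`, `Leaf`, and `Branch` are hypothetical classes assumed for illustration, and the `foreach`, `map`, and `collect` bodies are simplified approximations of the idea rather than the Spark source.

```scala
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer

object NodeIdentitySketch {
  // Shared counter that hands out a fresh id to every node instance.
  private val currentId = new AtomicLong
  private def nextId(): Long = currentId.getAndIncrement()

  // Hypothetical, simplified stand-in for TreeNode.
  abstract class Node {
    val id: Long = nextId()
    def children: Seq[Node]

    // True only when both references point at the very same node instance.
    def sameInstance(other: Node): Boolean = this.id == other.id

    // Cheap id check first, then the structural equals generated for case classes.
    def fastEquals(other: Node): Boolean = sameInstance(other) || this == other

    // Pre-order traversal: visit this node, then each subtree.
    def foreach(f: Node => Unit): Unit = {
      f(this)
      children.foreach(_.foreach(f))
    }

    def map[A](f: Node => A): Seq[A] = {
      val ret = new ArrayBuffer[A]()
      foreach(node => ret += f(node))
      ret.toSeq
    }

    def collect[B](pf: PartialFunction[Node, B]): Seq[B] = {
      val ret = new ArrayBuffer[B]()
      foreach(node => if (pf.isDefinedAt(node)) ret += pf(node))
      ret.toSeq
    }
  }

  case class Leaf(name: String) extends Node {
    def children: Seq[Node] = Nil
  }

  case class Branch(left: Node, right: Node) extends Node {
    def children: Seq[Node] = Seq(left, right)
  }
}
```

In this sketch two separately constructed `Leaf("src")` values are not the same instance, because each receives its own id, yet they are still fastEquals, because the case class equals compares only the constructor arguments.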
Let's look at an example:
    object GlobalAggregates extends Rule[LogicalPlan] {
      // apply calls the transform method of the logical plan (a TreeNode) to apply a PartialFunction.
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case Project(projectList, child) if containsAggregates(projectList) =>
          Aggregate(Nil, projectList, child)
      }

      def containsAggregates(exprs: Seq[Expression]): Boolean = {
        exprs.foreach(_.foreach {
          case agg: AggregateExpression => return true
          case _ =>
        })

        false
      }
    }

transform delegates to transformDown, and the real work happens in transformChildrenDown: the rule is applied recursively to the node and its children using a pre-order traversal.
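What the rule matches can be sketched without Spark on a few simplified types. Everything below (`Expr`, `Attr`, `Alias`, `Sum`, `Plan`, `Relation`, `applyToNode`) is hypothetical and assumed for illustration; `Sum` stands in for any AggregateExpression, and `containsAggregates` is rewritten with `exists` instead of the early `return` used in the original.

```scala
object GlobalAggregatesSketch {
  // Hypothetical, simplified expression types (not the Spark classes).
  sealed trait Expr { def children: Seq[Expr] }
  case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
  case class Alias(child: Expr, name: String) extends Expr { def children: Seq[Expr] = Seq(child) }
  // Stands in for any AggregateExpression.
  case class Sum(child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) }

  // Hypothetical, simplified plan types.
  sealed trait Plan
  case class Relation(name: String) extends Plan
  case class Project(projectList: Seq[Expr], child: Plan) extends Plan
  case class Aggregate(grouping: Seq[Expr], aggregates: Seq[Expr], child: Plan) extends Plan

  // True if any expression tree in the list contains an aggregate somewhere.
  def containsAggregates(exprs: Seq[Expr]): Boolean = {
    def hasAggregate(e: Expr): Boolean = e match {
      case _: Sum => true
      case other  => other.children.exists(hasAggregate)
    }
    exprs.exists(hasAggregate)
  }

  // The rule is a PartialFunction: it is defined only for a Project with aggregates.
  val rule: PartialFunction[Plan, Plan] = {
    case Project(projectList, child) if containsAggregates(projectList) =>
      Aggregate(Nil, projectList, child)
  }

  // Apply the rule to one node, leaving the node unchanged when the rule is not defined.
  def applyToNode(plan: Plan): Plan = rule.applyOrElse(plan, (p: Plan) => p)
}
```

A Project whose list contains `Alias(Sum(Attr("value")), "total")` is rewritten into an Aggregate with an empty grouping list, while a Project over plain attributes passes through untouched.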
The transformDown method:
    /**
     * Returns a copy of this node where `rule` has been recursively applied to it and all of its
     * children (pre-order). When `rule` does not apply to a given node it is left unchanged.
     * @param rule the function used to transform this nodes children
     */
    def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
      val afterRule = rule.applyOrElse(this, identity[BaseType])
      // Check if unchanged and then possibly return old copy to avoid gc churn.
      if (this fastEquals afterRule) {
        transformChildrenDown(rule)  // node unchanged: this.transformChildrenDown(rule)
      } else {
        afterRule.transformChildrenDown(rule)  // node changed: run transformChildrenDown on the new node
      }
    }

The most important method is transformChildrenDown:
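The pre-order behaviour can be tried out on a toy expression tree. This is a minimal sketch under stated assumptions: `Expr`, `Lit`, `Neg`, `Add`, and the `dropZero` rule are hypothetical, and `withNewChildren` is an assumed, hand-written replacement for the reflective makeCopy used by the real TreeNode.

```scala
object TransformDownSketch {
  // Hypothetical, simplified stand-in for TreeNode.
  abstract class Expr {
    def children: Seq[Expr]

    // Assumed replacement for makeCopy: rebuild this node with new children.
    def withNewChildren(newChildren: Seq[Expr]): Expr

    // Pre-order: apply the rule to this node first, then descend into the children
    // of whatever node results.
    def transformDown(rule: PartialFunction[Expr, Expr]): Expr = {
      val afterRule = rule.applyOrElse[Expr, Expr](this, (e: Expr) => e)
      if (this == afterRule) transformChildrenDown(rule)
      else afterRule.transformChildrenDown(rule)
    }

    // Transform every child; keep the old node when no child changed.
    def transformChildrenDown(rule: PartialFunction[Expr, Expr]): Expr = {
      val newChildren = children.map(_.transformDown(rule))
      if (newChildren == children) this else withNewChildren(newChildren)
    }
  }

  case class Lit(value: Int) extends Expr {
    def children: Seq[Expr] = Nil
    def withNewChildren(newChildren: Seq[Expr]): Expr = this
  }

  case class Neg(child: Expr) extends Expr {
    def children: Seq[Expr] = Seq(child)
    def withNewChildren(newChildren: Seq[Expr]): Expr = Neg(newChildren.head)
  }

  case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
    def withNewChildren(newChildren: Seq[Expr]): Expr = Add(newChildren(0), newChildren(1))
  }

  // Hypothetical rule: 0 + x => x
  val dropZero: PartialFunction[Expr, Expr] = { case Add(Lit(0), x) => x }
}
```

In this sketch `Neg(Add(Lit(0), Lit(7))).transformDown(dropZero)` yields `Neg(Lit(7))`, and a tree the rule never matches comes back as the same instance. Note that `Add(Lit(0), Add(Lit(0), Lit(5)))` becomes `Add(Lit(0), Lit(5))`, not `Lit(5)`: after the rule fires on a node, only the children of the replacement are visited, so the replacement itself is not matched again within the same pass.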
The transformChildrenDown method:
    /**
     * Returns a copy of this node where `rule` has been recursively applied to all the children of
     * this node. When `rule` does not apply to a given node it is left unchanged.
     * @param rule the function used to transform this nodes children
     */
    def transformChildrenDown(rule: PartialFunction[BaseType, BaseType]): this.type = {
      var changed = false
      val newArgs = productIterator.map {
        case arg: TreeNode[_] if children contains arg =>
          val newChild = arg.asInstanceOf[BaseType].transformDown(rule)  // recursively apply the rule to the child
          if (!(newChild fastEquals arg)) {
            changed = true
            newChild
          } else {
            arg
          }
        case Some(arg: TreeNode[_]) if children contains arg =>
          val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
          if (!(newChild fastEquals arg)) {
            changed = true
            Some(newChild)
          } else {
            Some(arg)
          }
        case m: Map[_,_] => m
        case args: Traversable[_] => args.map {
          case arg: TreeNode[_] if children contains arg =>
            val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
            if (!(newChild fastEquals arg)) {
              changed = true
              newChild
            } else {
              arg
            }
          case other => other
        }
        case nonChild: AnyRef => nonChild
        case null => null
      }.toArray
      // If anything changed, use the resulting newArgs array to create a new copy of the node via reflection.
      if (changed) makeCopy(newArgs) else this
    }

The makeCopy method creates the copy of the node via reflection:
    /**
     * Creates a copy of this type of tree node after a transformation.
     * Must be overridden by child classes that have constructor arguments
     * that are not present in the productIterator.
     * @param newArgs the new product arguments.
     */
    def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
      try {
        val defaultCtor = getClass.getConstructors.head  // take the first constructor found via reflection
        if (otherCopyArgs.isEmpty) {
          // create a new node of the current node's type via reflection
          defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
        } else {
          // if there are additional constructor arguments, append them with ++
          defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
        }
      } catch {
        case e: java.lang.IllegalArgumentException =>
          throw new TreeNodeException(
            this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
              + s"Exception message: ${e.getMessage}.")
      }
    }
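The interplay of productIterator and reflective copying can be sketched on a toy case class hierarchy. `Node`, `Leaf`, `Branch`, and `renameLeaf` are hypothetical names assumed for illustration. The sketch also assumes the case classes are compiled as members of a top-level object, so that the first public constructor takes exactly the case class parameters; it omits otherCopyArgs and the error handling of the original.

```scala
object MakeCopySketch {
  // Hypothetical, simplified stand-in for TreeNode.
  abstract class Node {
    // Rebuild a node of the same runtime class from a new argument array,
    // using the first public constructor found via reflection.
    def makeCopy(newArgs: Array[AnyRef]): this.type = {
      val ctor = getClass.getConstructors.head
      ctor.newInstance(newArgs: _*).asInstanceOf[this.type]
    }
  }

  case class Leaf(name: String) extends Node
  case class Branch(left: Node, right: Node) extends Node

  // Walk the constructor arguments exposed by productIterator, replace matching
  // leaves, and build the copy from the resulting argument array.
  def renameLeaf(node: Branch, from: String, to: String): Branch = {
    val newArgs: Array[AnyRef] = node.productIterator.map {
      case Leaf(name) if name == from => Leaf(to)
      case other                      => other.asInstanceOf[AnyRef]
    }.toArray
    node.makeCopy(newArgs)
  }
}
```

For example, `renameLeaf(Branch(Leaf("a"), Leaf("b")), "a", "A")` builds a new `Branch(Leaf("A"), Leaf("b"))` and leaves the original node untouched. This is why a case class works well as a tree node: its constructor arguments and its productIterator elements line up one to one.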
    sbt/sbt hive/console
    Using /usr/java/default as default JAVA_HOME.
    Note, this will be overridden by -java-home if it is set.
    [info] Loading project definition from /app/hadoop/shengli/spark/project/project
    [info] Loading project definition from /app/hadoop/shengli/spark/project
    [info] Set current project to root (in build file:/app/hadoop/shengli/spark/)
    [info] Starting scala interpreter...
    [info]
    import org.apache.spark.sql.catalyst.analysis._
    import org.apache.spark.sql.catalyst.dsl._
    import org.apache.spark.sql.catalyst.errors._
    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.plans.logical._
    import org.apache.spark.sql.catalyst.rules._
    import org.apache.spark.sql.catalyst.types._
    import org.apache.spark.sql.catalyst.util._
    import org.apache.spark.sql.execution
    import org.apache.spark.sql.hive._
    import org.apache.spark.sql.hive.test.TestHive._
    import org.apache.spark.sql.parquet.ParquetTestData

    scala> val query = sql("SELECT * FROM (SELECT * FROM src) a join (select * from src)b on a.key=b.key")
    scala> query.queryExecution.logical
    res0: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
    Project [*]
     Join Inner, Some(('a.key = 'b.key))
      Subquery a
       Project [*]
        UnresolvedRelation None, src, None
      Subquery b
       Project [*]
        UnresolvedRelation None, src, None

Drawn as a tree it looks like this (this is only my own understanding):
    scala> query.queryExecution.optimizedPlan
    res3: org.apache.spark.sql.catalyst.plans.logical.LogicalPlan =
    Project [key#0,value#1,key#2,value#3]
     Join Inner, Some((key#0 = key#2))
      MetastoreRelation default, src, None
      MetastoreRelation default, src, None
    scala> query.queryExecution.executedPlan
    res4: org.apache.spark.sql.execution.SparkPlan =
    Project [key#0:0,value#1:1,key#2:2,value#3:3]
     HashJoin [key#0], [key#2], BuildRight
      Exchange (HashPartitioning [key#0:0], 150)
       HiveTableScan [key#0,value#1], (MetastoreRelation default, src, None), None
      Exchange (HashPartitioning [key#2:0], 150)
       HiveTableScan [key#2,value#3], (MetastoreRelation default, src, None), None

The generated physical execution tree is shown in the figure: