Implementation of the CreateTableAsSelect Statement in Spark 1.x

This article walks through the job submission and execution path of a CreateTableAsSelect command as an example of how Spark SQL submits and executes jobs.
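
As a starting point, here is a minimal sketch of issuing such a statement programmatically (the application name, table names, and object name are hypothetical). The trace below follows the spark-sql CLI path (SparkSQLCLIDriver), but it converges on the same SQLContext.sql() / DataFrame construction shown here:

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.sql.hive.HiveContext

  object CtasExample {
    def main(args: Array[String]): Unit = {
      // Hypothetical application and table names, for illustration only.
      val sc = new SparkContext(new SparkConf().setAppName("ctas-example"))
      val hiveContext = new HiveContext(sc)
      // Because the parsed plan is a command with side effects, the DataFrame
      // constructor forces execution eagerly (see the logicalPlan val below).
      hiveContext.sql("CREATE TABLE t_copy AS SELECT * FROM t_source")
      sc.stop()
    }
  }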


SparkSQLCLIDriver:main():   ret = cli.processLine(line, true)
SparkSQLCLIDriver:main():   val rc = driver.run(cmd)
SparkSQLDriver:run():       val execution = context.executePlan(context.sql(command).logicalPlan)
SQLContext:sql():           DataFrame(this, parseSql(sqlText))  // parseSql parses the SQL text into a logical plan
DataFrame:apply():           new DataFrame(sqlContext, logicalPlan) // apply() instantiates the DataFrame from the parsed plan

DataFrame constructor:

  // logicalPlan here is a plain (non-lazy) value, yet it is still analyzed inside the DataFrame constructor...
  def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
    this(sqlContext, {
      val qe = sqlContext.executePlan(logicalPlan)
      if (sqlContext.conf.dataFrameEagerAnalysis) {
        qe.assertAnalyzed()  // This should force analysis and throw errors if there are any
      }
      qe
    })
  }
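
qe.assertAnalyzed() forces analysis errors to surface as soon as the DataFrame is constructed. Whether this eager check runs is controlled by sqlContext.conf.dataFrameEagerAnalysis; as a hedged example, the config key spark.sql.eagerAnalysis is assumed to back this flag in this Spark 1.x version:

  // Assumed config key for dataFrameEagerAnalysis; defaults to true.
  sqlContext.setConf("spark.sql.eagerAnalysis", "false")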

  @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
    // For various commands (like DDL) and queries with side effects, we force query optimization to
    // happen right away to let these side effects take place eagerly.
    case _: Command |
         _: InsertIntoTable |
         _: CreateTableUsingAsSelect =>
      // This is where the branch happens: the result is wrapped in a LogicalRDD and a "CTAS" callback is registered
      LogicalRDD(queryExecution.analyzed.output, withCallback("CTAS", queryExecution))(sqlContext)
    case _ =>
      queryExecution.analyzed
  }

DataFrame:withCallback():           val rdd = qe.toRdd
QueryExecution:toRdd:       lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
SparkPlan:execute():

    RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
      prepare()
      doExecute()
    }

  // ExecutedCommand
  protected override def doExecute(): RDD[InternalRow] = {
    sqlContext.sparkContext.parallelize(sideEffectResult, 1)
  }

  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
    val converter = CatalystTypeConverters.createToCatalystConverter(schema)
    cmd.run(sqlContext).map(converter(_).asInstanceOf[InternalRow])
  }
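
sideEffectResult simply invokes cmd.run(sqlContext) and converts the returned rows. The contract the command must satisfy is the RunnableCommand trait; as a rough sketch (simplified from Spark 1.x, not the verbatim source), it looks like this:

  import org.apache.spark.sql.{Row, SQLContext}
  import org.apache.spark.sql.catalyst.expressions.Attribute
  import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}

  // Simplified sketch of the RunnableCommand contract that ExecutedCommand relies on.
  trait RunnableCommand extends LogicalPlan with Command {
    override def output: Seq[Attribute] = Seq.empty
    override def children: Seq[LogicalPlan] = Seq.empty
    def run(sqlContext: SQLContext): Seq[Row]
  }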

  // Here cmd is a CreateTableAsSelect

case class CreateTableAsSelect(
    tableDesc: HiveTable,
    query: LogicalPlan,
    allowExisting: Boolean)
  extends RunnableCommand {

  val tableIdentifier = TableIdentifier(tableDesc.name, Some(tableDesc.database))

  override def children: Seq[LogicalPlan] = Seq(query)

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val hiveContext = sqlContext.asInstanceOf[HiveContext]
    // Look up the corresponding table object in the Hive metastore via hiveContext.catalog
    lazy val metastoreRelation: MetastoreRelation = ...

    if (hiveContext.catalog.tableExists(tableIdentifier)) {
      // Table-already-exists check (throw unless allowExisting)
    } else {
      // a. metastoreRelation is the relation for the table as stored in the Hive metastore
      // b. An InsertIntoTable LogicalPlan is built from metastoreRelation and query
      // c. hiveContext.executePlan() does not execute anything itself; it only wraps the LogicalPlan in a QueryExecution object
      // d. toRdd is what actually calls executedPlan.execute() to submit the job
      hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd
    }
    Seq.empty[Row]
  }
}
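
As comments c and d above note, executePlan() merely wraps the plan, and all real work is deferred to lazy vals inside QueryExecution. The following is a simplified sketch of that lazy pipeline (member names follow Spark 1.x's QueryExecution, but this is not the verbatim source; in the real code these are protected members of an inner class of SQLContext, and the cached-data substitution step is omitted here):

  import org.apache.spark.rdd.RDD
  import org.apache.spark.sql.SQLContext
  import org.apache.spark.sql.catalyst.InternalRow
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
  import org.apache.spark.sql.execution.SparkPlan

  // Each stage is a lazy val, so nothing runs until toRdd (or another downstream val) is touched.
  class QueryExecutionSketch(sqlContext: SQLContext, val logical: LogicalPlan) {
    lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical)
    lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(analyzed)
    lazy val sparkPlan: SparkPlan = sqlContext.planner.plan(optimizedPlan).next()
    lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan)
    lazy val toRdd: RDD[InternalRow] = executedPlan.execute()
  }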

The toRdd call above goes back into SparkPlan.execute(), and this time the call lands in the doExecute() method of the InsertIntoHiveTable plan.

// InsertIntoHiveTable
  protected override def doExecute(): RDD[InternalRow] = {
    sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1)
  }

  /**
   * Inserts all the rows in the table into Hive.  Row objects are properly serialized with the
   * `org.apache.hadoop.hive.serde2.SerDe` and the
   * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition.
   *
   * Note: this is run once and then kept to avoid double insertions.
   */
  protected[sql] lazy val sideEffectResult: Seq[InternalRow] = {
    // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
    // instances within the closure, since Serializer is not serializable while TableDesc is.
    val tableDesc = table.tableDesc
    val tableLocation = table.hiveQlTable.getDataLocation
    val jobConf = new JobConf(sc.hiveconf)
    val tmpLocation = getExternalTmpPath(tableLocation, jobConf)

    val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
    val isCompressed = sc.hiveconf.getBoolean(
      ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal)

    if (isCompressed) {
      // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec",
      // and "mapred.output.compression.type" have no impact on ORC because it uses table properties
      // to store compression information.
      sc.hiveconf.set("mapred.output.compress", "true")
      fileSinkConf.setCompressed(true)
      fileSinkConf.setCompressCodec(sc.hiveconf.get("mapred.output.compression.codec"))
      fileSinkConf.setCompressType(sc.hiveconf.get("mapred.output.compression.type"))
    }

    val numDynamicPartitions = partition.values.count(_.isEmpty)
    val numStaticPartitions = partition.values.count(_.nonEmpty)
    val partitionSpec = partition.map {
      case (key, Some(value)) => key -> value
      case (key, None) => key -> ""
    }
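    // For example (hypothetical spec): partition = Map("dt" -> Some("2016-01-01"), "hour" -> None)
    // gives numStaticPartitions = 1, numDynamicPartitions = 1,
    // and partitionSpec = Map("dt" -> "2016-01-01", "hour" -> "").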

    // All partition column names in the format of "<column name 1>/<column name 2>/..."
    val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns")
    val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull

    // Validate partition spec if there exist any dynamic partitions
    if (numDynamicPartitions > 0) {
      // Report error if dynamic partitioning is not enabled
      if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) {
        throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg)
      }

      // Report error if dynamic partition strict mode is on but no static partition is found
      if (numStaticPartitions == 0 &&
        sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) {
        throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg)
      }

      // Report error if any static partition appears after a dynamic partition
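      // For example (hypothetical), with partition columns (dt, hour) where dt is dynamic and
      // hour is static, isDynamic = Array(true, false); the adjacent pair (true, false) is
      // present, so the error below is thrown.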
      val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty)
      if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) {
        throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg)
      }
    }

    val jobConfSer = new SerializableJobConf(jobConf)

    // When speculation is on and output committer class name contains "Direct", we should warn
    // users that they may lose data if they are using a direct output committer.
    val speculationEnabled = sqlContext.sparkContext.conf.getBoolean("spark.speculation", false)
    val outputCommitterClass = jobConf.get("mapred.output.committer.class", "")
    if (speculationEnabled && outputCommitterClass.contains("Direct")) {
      val warningMessage =
        s"$outputCommitterClass may be an output committer that writes data directly to " +
          "the final location. Because speculation is enabled, this output committer may " +
          "cause data loss (see the case in SPARK-10063). If possible, please use a output " +
          "committer that does not have this behavior (e.g. FileOutputCommitter)."
      logWarning(warningMessage)
    }

    val writerContainer = if (numDynamicPartitions > 0) {
      val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions)
      new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames)
    } else {
      new SparkHiveWriterContainer(jobConf, fileSinkConf)
    }

    // Submit the job for execution
    saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer)

    val outputPath = FileOutputFormat.getOutputPath(jobConf)
    // Have to construct the format of dbname.tablename.
    val qualifiedTableName = s"${table.databaseName}.${table.tableName}"
    // TODO: Correctly set holdDDLTime.
    // In most of the time, we should have holdDDLTime = false.
    // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint.
    val holdDDLTime = false
    if (partition.nonEmpty) {

      // loadPartition call orders directories created on the iteration order of this map
      val orderedPartitionSpec = new util.LinkedHashMap[String, String]()
      table.hiveQlTable.getPartCols.asScala.foreach { entry =>
        orderedPartitionSpec.put(entry.getName, partitionSpec.get(entry.getName).getOrElse(""))
      }

      // inheritTableSpecs is set to true. It should be set to false for a IMPORT query
      // which is currently considered as a Hive native command.
      val inheritTableSpecs = true
      // TODO: Correctly set isSkewedStoreAsSubdir.
      val isSkewedStoreAsSubdir = false
      if (numDynamicPartitions > 0) {
        catalog.synchronized {
          catalog.client.loadDynamicPartitions(
            outputPath.toString,
            qualifiedTableName,
            orderedPartitionSpec,
            overwrite,
            numDynamicPartitions,
            holdDDLTime,
            isSkewedStoreAsSubdir)
        }
      } else {
        // scalastyle:off
        // ifNotExists is only valid with static partition, refer to
        // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DML#LanguageManualDML-InsertingdataintoHiveTablesfromqueries
        // scalastyle:on
        val oldPart =
          catalog.client.getPartitionOption(
            catalog.client.getTable(table.databaseName, table.tableName),
            partitionSpec.asJava)

        if (oldPart.isEmpty || !ifNotExists) {
            catalog.client.loadPartition(
              outputPath.toString,
              qualifiedTableName,
              orderedPartitionSpec,
              overwrite,
              holdDDLTime,
              inheritTableSpecs,
              isSkewedStoreAsSubdir)
        }
      }
    } else {
      catalog.client.loadTable(
        outputPath.toString, // TODO: URI
        qualifiedTableName,
        overwrite,
        holdDDLTime)
    }

    // Attempt to delete the staging directory and the inclusive files. If failed, the files are
    // expected to be dropped at the normal termination of VM since deleteOnExit is used.
    try {
      createdTempDir.foreach { path => path.getFileSystem(jobConf).delete(path, true) }
    } catch {
      case NonFatal(e) =>
        logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
    }

    // Invalidate the cache.
    sqlContext.cacheManager.invalidateCache(table)

    // It would be nice to just return the childRdd unchanged so insert operations could be chained,
    // however for now we return an empty list to simplify compatibility checks with hive, which
    // does not return anything for insert operations.
    // TODO: implement hive compatibility as rules.
    Seq.empty[InternalRow]
  }

  private def saveAsHiveFile(
      rdd: RDD[InternalRow],
      valueClass: Class[_],
      fileSinkConf: FileSinkDesc,
      conf: SerializableJobConf,
      writerContainer: SparkHiveWriterContainer): Unit = {
    assert(valueClass != null, "Output value class not set")
    conf.value.setOutputValueClass(valueClass)

    val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName
    assert(outputFileFormatClassName != null, "Output format class not set")
    conf.value.set("mapred.output.format.class", outputFileFormatClassName)

    FileOutputFormat.setOutputPath(
      conf.value,
      SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value))
    log.debug("Saving as hadoop file of type " + valueClass.getSimpleName)

    writerContainer.driverSideSetup()
    // Submit the RDD computation job to the cluster
    sc.sparkContext.runJob(rdd, writeToFile _)
    writerContainer.commitJob()

    // Note that this function is executed on executor side
    def writeToFile(context: TaskContext, iterator: Iterator[InternalRow]): Unit = {
      val serializer = newSerializer(fileSinkConf.getTableInfo)
      val standardOI = ObjectInspectorUtils
        .getStandardObjectInspector(
          fileSinkConf.getTableInfo.getDeserializer.getObjectInspector,
          ObjectInspectorCopyOption.JAVA)
        .asInstanceOf[StructObjectInspector]

      val fieldOIs = standardOI.getAllStructFieldRefs.asScala
        .map(_.getFieldObjectInspector).toArray
      val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray
      val wrappers = fieldOIs.zip(dataTypes).map { case (f, dt) => wrapperFor(f, dt)}
      val outputData = new Array[Any](fieldOIs.length)

      writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)

      iterator.foreach { row =>
        var i = 0
        while (i < fieldOIs.length) {
          outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i)))
          i += 1
        }

        writerContainer
          .getLocalFileWriter(row, table.schema)
          .write(serializer.serialize(outputData, standardOI))
      }

      writerContainer.close()
    }
  }
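
The key line above is sc.sparkContext.runJob(rdd, writeToFile _): runJob ships the closure to the executors and runs it once per partition. A self-contained sketch of this pattern (the names and data below are hypothetical, for illustration only):

  import org.apache.spark.{SparkConf, SparkContext, TaskContext}

  object RunJobPattern {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setAppName("runjob-pattern").setMaster("local[2]"))
      val rdd = sc.parallelize(1 to 10, numSlices = 2)

      // Like writeToFile above, this function is executed on the executor side,
      // once for each partition of the RDD.
      def writePartition(context: TaskContext, iter: Iterator[Int]): Unit = {
        iter.foreach(i => println(s"partition ${context.partitionId} -> $i"))
      }

      sc.runJob(rdd, writePartition _)
      sc.stop()
    }
  }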

SparkSQLDriver:run():       hiveResponse = execution.stringResult() // another job is submitted here

  // HiveContext
  // Note that executedPlan here is Scan ExistingRDD[]: because the original plan was a command,
  // the DataFrame's logicalPlan was replaced by a LogicalRDD holding the already-computed result
  // (see the logicalPlan val above), so stringResult() falls into the `case other` branch below.
    def stringResult(): Seq[String] = executedPlan match {
      case ExecutedCommand(desc: DescribeHiveTableCommand) =>
        // If it is a describe command for a Hive table, we want to have the output format
        // be similar with Hive.
        desc.run(self).map {
          case Row(name: String, dataType: String, comment) =>
            Seq(name, dataType,
              Option(comment.asInstanceOf[String]).getOrElse(""))
              .map(s => String.format(s"%-20s", s))
              .mkString("\t")
        }
      case command: ExecutedCommand =>
        command.executeCollect().map(_.getString(0))

      case other =>
        // other = Scan ExistingRDD[]; execution continues here
        val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq
        // We need the types so we can output struct field names
        val types = analyzed.output.map(_.dataType)
        // Reformat to match hive tab delimited output.
        result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq
    }

BroadcastHashJoin

Let's look at how the broadcast join is implemented for the operator BroadcastHashJoin [appid#26], [app_id#41], BuildRight.

The entry point is again the SparkPlan.execute() method, which then dispatches into BroadcastHashJoin.doExecute().
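
For context, here is a hedged sketch (hypothetical table and column names, reusing the hiveContext from the earlier example) of a query that produces such a plan when the right-hand table is smaller than spark.sql.autoBroadcastJoinThreshold:

  // Hypothetical tables: "events" is large (streamed side), "apps" is small (build side).
  val big = hiveContext.table("events")
  val small = hiveContext.table("apps")
  val joined = big.join(small, big("appid") === small("app_id"))
  // explain() should show something like: BroadcastHashJoin [appid#..], [app_id#..], BuildRight
  joined.explain()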


  // A dedicated thread pool is used for the broadcast
  // First execute() and collect() the small (build-side) table
  // Then build a HashedRelation from the collected rows and broadcast it
  // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
  // for the same query.
  @transient
  private lazy val broadcastFuture = {
    val numBuildRows = buildSide match {
      case BuildLeft => longMetric("numLeftRows")
      case BuildRight => longMetric("numRightRows")
    }

    // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    future {
      // This will run in another thread. Set the execution id so that we can connect these jobs
      // with the correct execution.
      SQLExecution.withExecutionId(sparkContext, executionId) {
        // Note that we use .execute().collect() because we don't want to convert data to Scala
        // types
        val input: Array[InternalRow] = buildPlan.execute().map { row =>
          numBuildRows += 1
          row.copy()
        }.collect()
        // The following line doesn't run in a job so we cannot track the metric value. However, we
        // have already tracked it in the above lines. So here we can use
        // `SQLMetrics.nullLongMetric` to ignore it.
        val hashed = HashedRelation(
          input.iterator, SQLMetrics.nullLongMetric, buildSideKeyGenerator, input.size)
        sparkContext.broadcast(hashed)
      }
    }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
  }

  // doExecute(): stream the large table and join each partition against the broadcast relation
  protected override def doExecute(): RDD[InternalRow] = {
    val numStreamedRows = buildSide match {
      case BuildLeft => longMetric("numRightRows")
      case BuildRight => longMetric("numLeftRows")
    }
    val numOutputRows = longMetric("numOutputRows")

    // broadcastFuture is used to broadcast the small (build-side) table
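    // timeout is assumed to come from spark.sql.broadcastTimeout (300 seconds by default in this version)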
    val broadcastRelation = Await.result(broadcastFuture, timeout)

    // For the large (streamed) table, run hashJoin on each partition against the broadcast relation
    streamedPlan.execute().mapPartitions { streamedIter =>
      val hashedRelation = broadcastRelation.value
      hashedRelation match {
        case unsafe: UnsafeHashedRelation =>
          TaskContext.get().internalMetricsToAccumulators(
            InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
        case _ =>
      }
      hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
    }
  }
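
The hashJoin() call at the end of doExecute() performs a standard per-partition hash join: it streams the big table's rows and probes the broadcast hash table. A simplified, self-contained sketch of the idea (the types and key extraction here are hypothetical, not the actual HashedRelation API):

  // Simplified per-partition hash join: probe each streamed row against the build-side map.
  def hashJoinPartition[K, V, W](
      streamedIter: Iterator[(K, V)],
      buildTable: Map[K, Seq[W]]): Iterator[(K, (V, W))] = {
    streamedIter.flatMap { case (k, v) =>
      buildTable.getOrElse(k, Seq.empty[W]).map(w => (k, (v, w)))
    }
  }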
