Source code analysis of how Spark SQL reads files

On the Spark jobs monitoring page you often see a job like this:
    Listing leaf files and directories for 100 paths:

This is actually Spark SQL reading a large batch of files.

The simplest demo statement reads a file like this:
    val df = session.read.json("path/to/your/resources/data.json")
or session.read.parquet(file_path), or session.read.csv(file_path).

This article looks in detail at how read.* is implemented.


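The call starts with the read function in SparkSession.scala, which is defined as def read: DataFrameReader = new DataFrameReader(self), so read merely returns a DataFrameReader object. The subsequent ".parquet" or ".csv" call is then really a call to the json/csv/parquet functions in DataFrameReader.scala. For example, format(), parquet() and orc() look like this:

def format(source: String): DataFrameReader = {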
    this.source = source
    this
  }

def parquet(path: String): DataFrame = {
    format("parquet").load(path)
  }

def orc(path: String): DataFrame = {
    format("orc").load(path)
  }
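
The delegation pattern in these methods can be illustrated without Spark. The following is only a minimal sketch: MiniReader is a hypothetical stand-in for DataFrameReader, a plain string stands in for the DataFrame, and the "parquet" default is an assumption made for the sketch.

```scala
object ReaderSketch {
  // Hypothetical stand-in for DataFrameReader; not Spark code.
  final class MiniReader {
    private var source: String = "parquet" // assumed default for this sketch

    // Records the format and returns this, so calls can be chained.
    def format(source: String): MiniReader = {
      this.source = source
      this
    }

    // Every shortcut ends up here; a string stands in for the DataFrame.
    def load(paths: String*): String =
      s"$source:${paths.mkString(",")}"

    // Shortcuts only pick the format and delegate to load.
    def parquet(path: String): String = format("parquet").load(path)
    def orc(path: String): String = format("orc").load(path)
  }
}
```

Calling new ReaderSketch.MiniReader().orc("/a") and new ReaderSketch.MiniReader().format("orc").load("/a") yields the same value, which is the point of the shortcut methods.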

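As you can see, once the file format is set, everything ends up calling the load function that takes paths as its argument. load() is defined as follows: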

def load(paths: String*): DataFrame = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "read files of Hive data source directly.")
    }

    val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
    if (classOf[DataSourceV2].isAssignableFrom(cls)) {
      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
      if (ds.isInstanceOf[ReadSupport]) {
        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
          ds = ds, conf = sparkSession.sessionState.conf)
        val pathsOption = {
          val objectMapper = new ObjectMapper()
          DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
        }
        Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
          ds, sessionOptions ++ extraOptions.toMap + pathsOption,
          userSpecifiedSchema = userSpecifiedSchema))
      } else {
        loadV1Source(paths: _*)
      }
    } else {
      loadV1Source(paths: _*)
    }
  }

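DataSource.lookupDataSource() takes two arguments: source is the format name (csv/json/parquet, etc.), and the other is the session's SQLConf, taken from sessionState. The source argument is treated as the data source provider: the method first looks up the corresponding class name in backwardCompatibilityMap, then loads that class with the class loader returned by Utils.getContextOrSparkClassLoader. See lookupDataSource() in DataSource.scala: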

  /** Given a provider name, look up the data source class definition. */
  def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
    val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" =>
        classOf[OrcFileFormat].getCanonicalName
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" =>
        "org.apache.spark.sql.hive.orc.OrcFileFormat"
      case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled =>
        "org.apache.spark.sql.avro.AvroFileFormat"
      case name => name
    }
    val provider2 = s"$provider1.DefaultSource"
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

    try {
      serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
        // the provider format did not match any given registered aliases
        case Nil =>
          try {
            Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
              case Success(dataSource) =>
                // Found the data source using fully qualified path
                dataSource
              case ...
            }
          }
    } catch {...}

The backwardCompatibilityMap used in this method shows that Spark SQL supports data sources such as jdbc/json/parquet/csv/libsvm/orc/socket.
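
The fallback order in lookupDataSource (alias map first, then the provider name as a class, then the provider name plus ".DefaultSource") can be sketched in plain Scala. This is a simplified illustration only: the alias table below is hypothetical, JDK classes stand in for data source classes, and the ServiceLoader short-name matching is left out.

```scala
import scala.util.Try

object LookupSketch {
  // Hypothetical alias table standing in for backwardCompatibilityMap.
  val compatibility: Map[String, String] = Map(
    "old.pkg.ListSource" -> "java.util.ArrayList"
  )

  def lookup(provider: String): Option[Class[_]] = {
    // Step 1: translate a legacy name if there is one, else keep the name.
    val provider1 = compatibility.getOrElse(provider, provider)
    // Step 2: the conventional fallback class name.
    val provider2 = s"$provider1.DefaultSource"
    // Prefer the context class loader, fall back to this class's loader.
    val loader = Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)
    // Step 3: try the name itself, then the DefaultSource variant.
    Try(loader.loadClass(provider1))
      .orElse(Try(loader.loadClass(provider2)))
      .toOption
  }
}
```

With this sketch, lookup("java.lang.String") resolves directly, lookup("old.pkg.ListSource") resolves through the alias table, and an unknown name yields None instead of the AnalysisException that the real method would raise.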

In the common case, DataFrameReader's load() calls the private loadV1Source method and returns its result.

  private def loadV1Source(paths: String*) = {
    // Code path for data source v1.
    sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
        paths = paths,
        userSpecifiedSchema = userSpecifiedSchema,
        className = source,
        options = extraOptions.toMap).resolveRelation())
  }

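First look at DataSource.apply(). DataSource is a case class, so the compiler generates a companion object with an apply method (playing a role similar to a constructor in C++); calling it is equivalent to new-ing a DataSource object with the given arguments. DataSource.resolveRelation is then called on that object: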
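
A short illustration of the case-class point: MiniDataSource below is a hypothetical, much smaller stand-in for the real DataSource class, used only to show that the generated apply and the constructor build equal instances.

```scala
object ApplySketch {
  // Hypothetical stand-in for the real DataSource case class.
  case class MiniDataSource(
      paths: Seq[String] = Nil,
      className: String,
      options: Map[String, String] = Map.empty)

  // The compiler-generated companion apply forwards to the constructor.
  val viaApply = MiniDataSource.apply(paths = Seq("/data/a"), className = "json")
  val viaNew = new MiniDataSource(paths = Seq("/data/a"), className = "json")
}
```

Because case classes compare structurally, ApplySketch.viaApply == ApplySketch.viaNew holds.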

  /**
   * Creates a resolved [[BaseRelation]] that can be used to read data from, or write data into, this [[DataSource]].
   */
  def resolveRelation(checkFilesExist: Boolean = true): BaseRelation = {
    // Here providingClass is a lazy DataSource.lookupDataSource(source_className, session.conf).
    // It is the same lookup used earlier when checking for a ReadSupport instance, only here it is a member of DataSource.
    val relation = (providingClass.newInstance(), userSpecifiedSchema) match {
      // TODO: Throw when too much is given.
      case (dataSource: SchemaRelationProvider, Some(schema)) =>
        dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions, schema)
      case (dataSource: RelationProvider, None) =>
        dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
      case (_: SchemaRelationProvider, None) =>
        throw new AnalysisException(s"A schema needs to be specified when using $className.")
      case (dataSource: RelationProvider, Some(schema)) =>
        val baseRelation =
          dataSource.createRelation(sparkSession.sqlContext, caseInsensitiveOptions)
        if (baseRelation.schema != schema) {
          throw new AnalysisException(s"$className does not allow user-specified schemas.")
        }
        baseRelation

      // We are reading from the results of a streaming query. Load files from the metadata log
      // instead of listing them using HDFS APIs.
      case (format: FileFormat, _)
          if FileStreamSink.hasMetadata(
            caseInsensitiveOptions.get("path").toSeq ++ paths,
            sparkSession.sessionState.newHadoopConf()) =>
        //... this part is skipped here

        HadoopFsRelation(
          )(sparkSession)

      // Ordinary files take this branch
      // This is a non-streaming file based datasource.
      case (format: FileFormat, _) =>
        val globbedPaths =
          checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
        val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions &&
          catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog &&
          catalogTable.get.partitionColumnNames.nonEmpty
        val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) {
          val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes
          val index = new CatalogFileIndex(
            sparkSession,
            catalogTable.get,
            catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize))
          (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema)
        } else {
          val index = createInMemoryFileIndex(globbedPaths)
          val (resultDataSchema, resultPartitionSchema) =
            getOrInferFileFormatSchema(format, Some(index))
          (index, resultDataSchema, resultPartitionSchema)
        }

        HadoopFsRelation(
          fileCatalog,
          partitionSchema = partitionSchema,
          dataSchema = dataSchema.asNullable,
          bucketSpec = bucketSpec,
          format,
          caseInsensitiveOptions)(sparkSession)

      case _ =>
        throw new AnalysisException(
          s"$className is not a valid Spark SQL Data Source.")
    }

    relation match {
      case hs: HadoopFsRelation =>
        SchemaUtils.checkColumnNameDuplication(
          hs.dataSchema.map(_.name),
          "in the data schema",
          equality)
        SchemaUtils.checkColumnNameDuplication(
          hs.partitionSchema.map(_.name),
          "in the partition schema",
          equality)
        DataSourceUtils.verifyReadSchema(hs.fileFormat, hs.dataSchema)
      case _ =>
        SchemaUtils.checkColumnNameDuplication(
          relation.schema.map(_.name),
          "in the data schema",
          equality)
    }

    relation
  }

This function is rather long-winded. Note that for ordinary static files it takes the case (format: FileFormat, _) => branch, and in particular these lines:

        {
          val index = createInMemoryFileIndex(globbedPaths)
          val (resultDataSchema, resultPartitionSchema) =
            getOrInferFileFormatSchema(format, Some(index))
          (index, resultDataSchema, resultPartitionSchema)
        }
        and finally returns a HadoopFsRelation object.

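This is because DataSource.lookupDataSource() earlier turned the source argument into provider1, and the provider is the file format class for that source, i.e. an implementation of FileFormat (such as CSVFileFormat or ParquetFileFormat).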
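
The branch selection in resolveRelation is a pattern match on the pair (provider instance, optional user schema). The sketch below mirrors only that dispatch: the three traits are hypothetical markers named after the roles in the real code, the streaming-sink branch is omitted, and each branch returns a description instead of building a relation.

```scala
object ResolveSketch {
  // Hypothetical marker traits mirroring the roles in the real match.
  trait SchemaRelationProvider
  trait RelationProvider
  trait FileFormat

  // Returns a description of the branch that would be taken.
  def branch(provider: Any, userSchema: Option[String]): String =
    (provider, userSchema) match {
      case (_: SchemaRelationProvider, Some(_)) => "createRelation with schema"
      case (_: RelationProvider, None)          => "createRelation without schema"
      case (_: SchemaRelationProvider, None)    => "error: schema required"
      case (_: RelationProvider, Some(_))       => "createRelation, then compare schemas"
      case (_: FileFormat, _)                   => "build file index, return HadoopFsRelation"
      case _                                    => "error: not a valid data source"
    }

  // Two sample providers for trying the dispatch.
  object Csv extends FileFormat
  object Jdbc extends RelationProvider
}
```

A file-based provider such as ResolveSketch.Csv lands in the FileFormat branch regardless of whether a schema was supplied, which is why read.csv and read.parquet always go through the file index.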

Here createInMemoryFileIndex builds an InMemoryFileIndex object from globbedPaths, userSpecifiedSchema, and so on:

  private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = {
    val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
    new InMemoryFileIndex(
      sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache)
  }

On initialization, the InMemoryFileIndex object immediately runs refresh0(), which in turn calls listLeafFiles().

  private def refresh0(): Unit = {
    val files = listLeafFiles(rootPaths)
    cachedLeafFiles =
      new mutable.LinkedHashMap[Path, FileStatus]() ++= files.map(f => f.getPath -> f)
    cachedLeafDirToChildrenFiles = files.toArray.groupBy(_.getPath.getParent)
    cachedPartitionSpec = null
  }

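listLeafFiles(rootPaths) lists all leaf files under the given paths. When the number of paths to list exceeds the threshold,
the method submits a Spark job to list them in parallel (the threshold check itself is in bulkListLeafFiles, shown further down).
When it finishes, the results are cached in the fileStatusCache object and returned as a mutable.LinkedHashSet[FileStatus] to the caller, i.e. the InMemoryFileIndex object that owns refresh0().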

  /**
   * List leaf files of given paths. This method will submit a Spark job to do parallel
   * listing whenever there is a path having more files than the parallel partition discovery
   * discovery threshold.
   */
  def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
    val output = mutable.LinkedHashSet[FileStatus]()
    val pathsToFetch = mutable.ArrayBuffer[Path]()
    for (path <- paths) {
      fileStatusCache.getLeafFiles(path) match {
        case Some(files) =>
          HiveCatalogMetrics.incrementFileCacheHits(files.length)
          output ++= files
        case None =>
          pathsToFetch += path
      }
      Unit // for some reasons scalac 2.12 needs this; return type doesn't matter
    }
    val filter = FileInputFormat.getInputPathFilter(new JobConf(hadoopConf, this.getClass))
    val discovered = InMemoryFileIndex.bulkListLeafFiles(
      pathsToFetch, hadoopConf, filter, sparkSession)
    discovered.foreach { case (path, leafFiles) =>
      HiveCatalogMetrics.incrementFilesDiscovered(leafFiles.size)
      fileStatusCache.putLeafFiles(path, leafFiles.toArray)
      output ++= leafFiles
    }
    output
  }
}
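
The structure of listLeafFiles (serve cache hits first, collect the misses, list all misses in one bulk call, then fill the cache) can be sketched without Hadoop or Spark. Everything here is a stand-in: strings replace Path and FileStatus, the cache is a plain mutable map, and bulkList is a hypothetical function that pretends every path contains two files.

```scala
import scala.collection.mutable

object ListingCacheSketch {
  // Hypothetical in-memory cache: path -> leaf file names.
  val cache = mutable.Map[String, Seq[String]]()

  // Stand-in for bulkListLeafFiles; pretends each path holds two files.
  def bulkList(paths: Seq[String]): Seq[(String, Seq[String])] =
    paths.map(p => p -> Seq(s"$p/part-0", s"$p/part-1"))

  def listLeafFiles(paths: Seq[String]): mutable.LinkedHashSet[String] = {
    val output = mutable.LinkedHashSet[String]()
    val pathsToFetch = mutable.ArrayBuffer[String]()
    // Cache hits go straight to the output; misses are queued.
    paths.foreach { path =>
      cache.get(path) match {
        case Some(files) => output ++= files
        case None        => pathsToFetch += path
      }
    }
    // All misses are listed in a single bulk call and then cached.
    bulkList(pathsToFetch.toSeq).foreach { case (path, files) =>
      cache(path) = files
      output ++= files
    }
    output
  }
}
```

Batching the misses matters because the bulk call is where the serial-versus-parallel decision is made: it needs to see all uncached paths at once.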

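Next, look at InMemoryFileIndex.bulkListLeafFiles(). If the number of paths passed in is small (no more than 32 by default), the listing runs directly on the driver by calling: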

private def listLeafFiles(
      path: Path,
      hadoopConf: Configuration,
      filter: PathFilter,
      sessionOpt: Option[SparkSession]): Seq[FileStatus] = {} 

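for each path, which does the work for the def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {} shown above.
Otherwise it launches a job and uses multiple executors in the cluster to list all the leaf files in parallel. The code is as follows: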

/**
   * Lists a collection of paths recursively. Picks the listing strategy adaptively depending
   * on the number of paths to list.
   *
   * This may only be called on the driver.
   *
   * @return for each input path, the set of discovered files for the path
   */
  private[sql] def bulkListLeafFiles(
      paths: Seq[Path],
      hadoopConf: Configuration,
      filter: PathFilter,
      sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = {

    // Short-circuits parallel listing when serial listing is likely to be faster.
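    // If the number of paths is at most spark.sql.sources.parallelPartitionDiscovery.threshold (default 32),
    //     call the four-argument listLeafFiles directly and return.
    // Otherwise (more than 32 paths), a dedicated job is launched to find the files in parallel.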
    if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) {
      return paths.map { path =>
        (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession)))
      }
    }

    logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
    HiveCatalogMetrics.incrementParallelListingJobCount(1)

    val sparkContext = sparkSession.sparkContext
    val serializableConfiguration = new SerializableConfiguration(hadoopConf)
    val serializedPaths = paths.map(_.toString)
    val parallelPartitionDiscoveryParallelism =
      sparkSession.sessionState.conf.parallelPartitionDiscoveryParallelism

    // Set the number of parallelism to prevent following file listing from generating many tasks
    // in case of large #defaultParallelism.
    val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism)

    val previousJobDescription = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
    val statusMap = try {
      val description = paths.size match {
        case 0 =>
          s"Listing leaf files and directories 0 paths"
        case 1 =>
          s"Listing leaf files and directories for 1 path:<br/>${paths(0)}"
        case s =>
          s"Listing leaf files and directories for $s paths:<br/>${paths(0)}, ..."
      }
      sparkContext.setJobDescription(description)
      sparkContext
        .parallelize(serializedPaths, numParallelism)
        .mapPartitions { pathStrings =>
          val hadoopConf = serializableConfiguration.value
          pathStrings.map(new Path(_)).toSeq.map { path =>
            (path, listLeafFiles(path, hadoopConf, filter, None))
          }.iterator
        }.map { case (path, statuses) =>
        val serializableStatuses = statuses.map { status =>
          // Turn FileStatus into SerializableFileStatus so we can send it back to the driver
          val blockLocations = status match {
            case f: LocatedFileStatus =>
              f.getBlockLocations.map { loc =>
                SerializableBlockLocation(
                  loc.getNames,
                  loc.getHosts,
                  loc.getOffset,
                  loc.getLength)
              }
            case _ =>
              Array.empty[SerializableBlockLocation]
          }
          SerializableFileStatus(
            status.getPath.toString,
            status.getLen,
            status.isDirectory,
            status.getReplication,
            status.getBlockSize,
            status.getModificationTime,
            status.getAccessTime,
            blockLocations)
        }
        (path.toString, serializableStatuses)
      }.collect()
    } finally {
      sparkContext.setJobDescription(previousJobDescription)
    }

    // turn SerializableFileStatus back to Status
    statusMap.map { case (path, serializableStatuses) =>
      val statuses = serializableStatuses.map { f =>
        val blockLocations = f.blockLocations.map { loc =>
          new BlockLocation(loc.names, loc.hosts, loc.offset, loc.length)
        }
        new LocatedFileStatus(
          new FileStatus(
            f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime,
            new Path(f.path)),
          blockLocations)
      }
      (new Path(path), statuses)
    }
  }

The default value of the configuration parameter spark.sql.sources.parallelPartitionDiscovery.threshold is still 32 in Spark 3.1.2.

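When private[sql] def bulkListLeafFiles() launches the job, it first sets up various parameters; for example, the parallelism numParallelism is capped at spark.sql.sources.parallelPartitionDiscovery.parallelism (default 10000) and is never larger than the number of paths.
It then sets the job description and finally calls collect() to trigger the job.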
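
The decision made at the top of bulkListLeafFiles reduces to two comparisons, worked through here as a small function. This is a sketch, not Spark code: the Plan type is hypothetical, and the default values 32 and 10000 are assumptions based on the Spark defaults for the two parallelPartitionDiscovery settings as I understand them.

```scala
object ListingPlanSketch {
  // Hypothetical result type describing how the listing would run.
  sealed trait Plan
  case object Serial extends Plan
  final case class Parallel(numTasks: Int) extends Plan

  // threshold and maxParallelism mirror the two configuration values;
  // the defaults here are assumptions, check them for your Spark version.
  def plan(numPaths: Int, threshold: Int = 32, maxParallelism: Int = 10000): Plan =
    if (numPaths <= threshold) Serial                 // list on the driver
    else Parallel(math.min(numPaths, maxParallelism)) // one job, capped task count
}
```

For the job description from the beginning of the article ("for 100 paths"), plan(100) gives Parallel(100): 100 exceeds the threshold of 32, and min(100, 10000) is 100 tasks.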

The file attributes obtained by the listing can be seen from the LocatedFileStatus object built in listLeafFiles():

    new LocatedFileStatus(f.getLen, f.isDirectory, f.getReplication, f.getBlockSize,
            f.getModificationTime, 0, null, null, null, null, f.getPath, locations)

Where are we? Here again is the loadV1Source() method called from DataFrameReader's load():

  private def loadV1Source(paths: String*) = {
    // Code path for data source v1.
    sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
        paths = paths,
        userSpecifiedSchema = userSpecifiedSchema,
        className = source,
        options = extraOptions.toMap).resolveRelation())
  }

Now we know what DataSource's resolveRelation() does: based on the file format it calls createInMemoryFileIndex and finally returns a HadoopFsRelation, at which point the relation counts as resolved.
HadoopFsRelation is a container; in the words of its Scaladoc, it "acts as a container for all of the metadata required to read from a datasource. All discovery, resolution and merging logic for schemas and partitions has been removed." From here on, this HadoopFsRelation can be used to read and write files on the cluster.

baseRelationToDataFrame(), which wraps the resolveRelation() call, is defined as:

  def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
    Dataset.ofRows(self, LogicalRelation(baseRelation))
  }

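This function converts a BaseRelation created for an external data source (such as Hadoop files) into a DataFrame.
As the code shows, it first wraps baseRelation in a LogicalRelation and then calls ofRows(), which analyzes this logical plan and wraps it in a Dataset: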

  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
    val qe = sparkSession.sessionState.executePlan(logicalPlan)
    qe.assertAnalyzed()
    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
  }

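Compared with .sql("select * from table"), reading directly with .read.json() skips lexing, parsing, AST generation, and resolving the various relations through relational algebra. Instead, DataSource creates the BaseRelation directly and resolves it into a logical plan; when that plan is eventually executed, the resulting DataFrame is backed by an RDD. See the following code: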

// case class FileSourceScanExec in DataSourceScanExec.scala
private lazy val inputRDD: RDD[InternalRow] = {
    // Update metrics for taking effect in both code generation node and normal node.
    updateDriverMetrics()
    val readFile: (PartitionedFile) => Iterator[InternalRow] =
      relation.fileFormat.buildReaderWithPartitionValues(
        sparkSession = relation.sparkSession,
        dataSchema = relation.dataSchema,
        partitionSchema = relation.partitionSchema,
        requiredSchema = requiredSchema,
        filters = pushedDownFilters,
        options = relation.options,
        hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))

    relation.bucketSpec match {
      case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
        createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
      case _ =>
        createNonBucketedReadRDD(readFile, selectedPartitions, relation)
    }
  }
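
Note that inputRDD is declared as a lazy val, so the reader function and the RDD are only built on first access and then reused. A minimal sketch of that Scala behaviour, with a hypothetical Scan class and a counter added purely to make the evaluation visible:

```scala
object LazySketch {
  // Counts how many times the lazy body runs (for illustration only).
  var builds = 0

  // Hypothetical stand-in for a scan node holding a lazily built input.
  class Scan {
    lazy val inputRows: Seq[Int] = {
      builds += 1
      Seq(1, 2, 3)
    }
  }
}
```

Creating a Scan does not run the body; the first read of inputRows does, and later reads return the cached value.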

That is roughly the overall flow of how Spark SQL reads files.

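Some details have been left out. Note in particular that createInMemoryFileIndex() can launch a job (the parallel listing shown above), and that the subsequent getOrInferFileFormatSchema(), when inferring the schema for formats such as table/csv/json, also has to launch at least one job. For example, Parquet files need a job that parses the schema information out of the files passed in; for a source walkthrough of that part, the blog of 夜月行者 is recommended and explains it clearly.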

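Another example is CSV: CSVFileFormat's inferSchema relies on TextInputCSVDataSource or MultiLineCSVDataSource, both derived from the abstract class CSVDataSource, to infer the schema. It first calls .take(1) to fetch one line and derive the header (take is an action), and then runs the following two lines to return the final schema: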

    val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
    CSVInferSchema.infer(sampled, header, parsedOptions)

 CSVInferSchema.infer then aggregates over the sampled RDD, which triggers yet another job.
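
The idea behind this two-step inference (header from the first line, column types by folding over the remaining rows and widening as needed) can be sketched in plain Scala. This is a simplified illustration under stated assumptions, not the CSVInferSchema algorithm: it knows only three hypothetical types, starts every column as an integer, splits naively on commas, and skips sampling entirely.

```scala
import scala.util.Try

object CsvInferSketch {
  // Hypothetical, reduced type lattice: Int < Double < String.
  sealed trait ColType
  case object IntType extends ColType
  case object DoubleType extends ColType
  case object StringType extends ColType

  // Narrowest type able to hold one field.
  def typeOf(field: String): ColType =
    if (Try(field.toInt).isSuccess) IntType
    else if (Try(field.toDouble).isSuccess) DoubleType
    else StringType

  // Widen two candidate types to the more general one.
  def merge(a: ColType, b: ColType): ColType = (a, b) match {
    case (StringType, _) | (_, StringType) => StringType
    case (DoubleType, _) | (_, DoubleType) => DoubleType
    case _                                 => IntType
  }

  def infer(lines: Seq[String]): Seq[(String, ColType)] = {
    val header = lines.head.split(",").toSeq       // first line, like take(1)
    val rows = lines.tail.map(_.split(",").toSeq)  // remaining data rows
    val start: Seq[ColType] = header.map(_ => IntType)
    // Fold over all rows, widening each column's type field by field.
    val types = rows.foldLeft(start) { (acc, row) =>
      acc.zip(row).map { case (t, f) => merge(t, typeOf(f)) }
    }
    header.zip(types)
  }
}
```

In Spark the fold runs as a distributed aggregation over an RDD, which is why inference costs a job; here it is an in-memory foldLeft over a Seq.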

 

 
