Spark JDBC系列--源码简析

本文旨在简析 Spark 读取数据库的一些关键源码

Spark如何读取数据库数据

像其他的数据映射框架一样(如hibernate,mybatis等),spark如果想读取数据库数据,也绕不开JDBC链接,毕竟这是代码与数据库“交流”的官方途径。spark如果想快速读取数据库中的数据,需要解决的事情包括但不限于:

  • 分布式读取
  • 原始数据到RDD/DataFrame的映射

所以这篇小文主要围绕这两个方面做下源码的简析

关于spark操作数据库API,可以参考这篇文档:Spark JDBC系列--取数的四种方式

源码简析

1.JDBC API公共入口

入口源码:

org.apache.spark.sql.DataFrameReader
...
private def jdbc(
  url: String,
  table: String,
  parts: Array[Partition],
  connectionProperties: Properties): DataFrame = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) =>
      props.put(key, value)
    }
    // connectionProperties should override settings in extraOptions
    props.putAll(connectionProperties)
    //关键点
    val relation = JDBCRelation(url, table, parts, props)(sparkSession)
    //逻辑分区的创建,action后会触发读取
    sparkSession.baseRelationToDataFrame(relation)
}

通过观察源码可知,四种取数API的参数虽然略有不同,但最终都转换成了一个Array[Partition],即分区条件数组。

2.指定column的取数API分区原理简析

此处列举提供long型column的分区模式的API的分区原理,先看源码:

def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
    if (partitioning == null || partitioning.numPartitions <= 1 ||
      partitioning.lowerBound == partitioning.upperBound) {
      //单分区模式会进入此条件
      return Array[Partition](JDBCPartition(null, 0))
    }
    
    //合法性校验
    val lowerBound = partitioning.lowerBound
    val upperBound = partitioning.upperBound
    ....
      
    //分区调整
    val numPartitions =
      if ((upperBound - lowerBound) >= partitioning.numPartitions) {
        partitioning.numPartitions
      } else {
        upperBound - lowerBound
      }
      
    //计算步长
    val stride: Long = upperBound / numPartitions - lowerBound / numPartitions
    val column = partitioning.column
    var i: Int = 0
    var currentValue: Long = lowerBound
    var ans = new ArrayBuffer[Partition]()
    
    //根据步长,根据提供的最大、最小值做步长累计,确定边界后组装where查询条件
    while (i < numPartitions) {
      //注意此处,会存在单边限制条件的情况,如:JDBCPartition(id >= 901,9)
      val lBound = if (i != 0) s"$column >= $currentValue" else null
      currentValue += stride
      val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
      val whereClause =
        if (uBound == null) {
          lBound
        } else if (lBound == null) {
          s"$uBound or $column is null"
        } else {
          s"$lBound AND $uBound"
        }
      ans += JDBCPartition(whereClause, i)
      i = i + 1
    }
    ans.toArray
  }

测试代码与分区结果如下:

入参为:
lowerBound=1, upperBound=1000, numPartitions=10

对应分区数组为:
JDBCPartition(id < 101 or id is null,0), 
JDBCPartition(id >= 101 AND id < 201,1), 
JDBCPartition(id >= 201 AND id < 301,2), 
JDBCPartition(id >= 301 AND id < 401,3), 
JDBCPartition(id >= 401 AND id < 501,4), 
JDBCPartition(id >= 501 AND id < 601,5), 
JDBCPartition(id >= 601 AND id < 701,6), 
JDBCPartition(id >= 701 AND id < 801,7), 
JDBCPartition(id >= 801 AND id < 901,8), 
JDBCPartition(id >= 901,9)

这种使用方式存在误用场景,即通过指定一段ID的最大最小值(而非整张表真正的最大最小值去取数据),则依然会取出全表数据,且发生数据倾斜,原因就在于第一个分区和最后一个分区的where条件处理,所以如果需要指定范围或更多条件,建议使用支持自定义分区条件的API。

3.数据结果映射

函数:

org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
//获取dataframe的schema,即对数据库的字段类型和spark的数据类型做映射
override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)

//具体实现
org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
def resolveTable(url: String, table: String, properties: Properties): StructType = {
  //url中识别出需要使用的方言
 val dialect = JdbcDialects.get(url)
  val ncols = rsmd.getColumnCount
  val fields = new Array[StructField](ncols)
  var i = 0
  ....
  
  while (i < ncols) {
    val columnName = rsmd.getColumnLabel(i + 1)
    val dataType = rsmd.getColumnType(i + 1)
    val typeName = rsmd.getColumnTypeName(i + 1)
    val fieldSize = rsmd.getPrecision(i + 1)
    val fieldScale = rsmd.getScale(i + 1)
    ....
    
    //根据不同方言的约定做映射,未找到时使用默认映射规则
    val columnType =dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
        getCatalystType(dataType, fieldSize, fieldScale, isSigned))
    fields(i) = StructField(columnName, columnType, nullable, metadata.build())
    i = i + 1
  }
  return new StructType(fields)
  
  字段映射的默认配置例举:
  val answer = sqlType match {
  ....   
  case java.sql.Types.BLOB          => BinaryType
  case java.sql.Types.BOOLEAN       => BooleanType
  case java.sql.Types.CHAR          => StringType
  case java.sql.Types.CLOB          => StringType
  case java.sql.Types.DATALINK      => null
  case java.sql.Types.DATE          => DateType
  case java.sql.Types.DECIMAL
    if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
  case java.sql.Types.DECIMAL       => DecimalType.SYSTEM_DEFAULT
  case java.sql.Types.DISTINCT      => null
  case java.sql.Types.DOUBLE        => DoubleType
  case java.sql.Types.FLOAT         => FloatType
  ....
}

此处例举MySQL的方言实现:

所有的方言实现都此包下:org.apache.spark.sql.jdbc.*,实现请自行参考。

MySQL方言:
private case object MySQLDialect extends JdbcDialect {

  override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")

  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
    //关键实现
    if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
      // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
      // byte arrays instead of longs.
      md.putLong("binarylong", 1)
      Option(LongType)
    } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
      Option(BooleanType)
    } else None
  }
  ....
}

从源码可以看出,MySQL只对bit和tinyint类型进行了约束,其他类型使用了spark的默认配置,所以在读取数据时,需要考虑spark中的方言映射,是否对已存在的数据造成影响,避免数据失真。
此时 JDBCRelation 对象已经完成构造。

4.RDD构造与逻辑分区生成

根据之前生成的 JDBCRelation,sparkSession会把任务加入逻辑执行计划。当遇到action操作时,会转为物理执行计划,

org.apache.spark.sql.SparkSession
//逻辑执行计划构建,细节不写了,源码我也没怎么研究过
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
Dataset.ofRows(self, LogicalRelation(baseRelation))
}

org.apache.spark.sql.execution.datasources.DataSourceStrategy
//物理执行计划
object DataSourceStrategy extends Strategy with Logging {
  def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
    case PhysicalOperation.....

    //JDBCRelation继承了PrunedFilteredScan,进入此case分支,并调用buildScan方法
    case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
      pruneFilterProject(
        l,
        projects,
        filters,
        (a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil

    case PhysicalOperation..... 
  }    

JDBCRelation 的 buildScan 方法执行时,会调用JDBCRDD的 scanTable 方法新建 RDD,其中计算前加入的 filter 条件,会合并到JDBC查询where条件中,使用AND连接:

private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    fqTable: String,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    properties: Properties)
  extends RDD[InternalRow](sc, Nil) {

  override def getPartitions: Array[Partition] = partitions
  
  .....
    
  private def getWhereClause(part: JDBCPartition): String = {
    if (part.whereClause != null && filterWhereClause.length > 0) {
      "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})"
    } else if (part.whereClause != null) {
      "WHERE " + part.whereClause
    } else if (filterWhereClause.length > 0) {
      "WHERE " + filterWhereClause
    } else {
      ""
    }
  }
  
  //compute方法为action触发时,执行的SQL语句,并对结果按之前的约定做数据映射
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
    new Iterator[InternalRow] {
    。。。。
    //实现细节不再展开,主要是JDBC查询操作和数据类型映射
}

filter条件使用示例:

val url = "jdbc:mysql://mysqlHost:3306/database"
val tableName = "table"
val columnName = "id"
val lowerBound = getMinId()
val upperBound = getMaxId()
val numPartitions = 200

// 设置连接用户&密码
val prop = new java.util.Properties
prop.setProperty("user","username")
prop.setProperty("password","pwd")

// 对mysql数据进行过滤
val jdbcDF = sqlContext.read.jdbc(url,tableName, columnName, lowerBound, upperBound,prop).where("date='2017-11-30'").filter("name is not null")

where 和 filter 是等价的,过滤条件将在 where 语句中生效,多个条件会用And进行拼接。

结语

读取数据库数据时,可以到对应的源码中,debug分析。

你可能感兴趣的:(Spark JDBC系列--源码简析)