Spark-SQL之自定义数据源的构建

自定义数据源的构建

常见的trait

下面是interfaces.scala中常见的一些接口:

下面各种类、方法,在源码里面都有详细的注释。

//BaseRelation是Spark提供的一个标准的接口
//由于是抽象类,如果要实现自己的外部数据源,必须要实现它里面的一些方法
//这个里面是含有schema的元组集合(字段:字段类型)
//继承了BaseRelation的类,必须以StructType这个形式产生数据的schema
//继承了`Scan`类之后,要实现它里面的相应的方法
@InterfaceStability.Stable
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType
.....
}

//在BaseRelation的子类来返回下面几种方式的scan,这个东西由谁创建?由RelationProvider创建

//A BaseRelation that can produce all of its tuples as an RDD of Row objects.
//读取数据,构建RDD[ROW]
//可以理解为select * from xxx   把所有数据读取出来变成RDD[Row]
trait TableScan {
  def buildScan(): RDD[Row]
}

//A BaseRelation that can eliminate unneeded columns before producing an RDD 
//containing all of its tuples as Row objects.
//可以理解为select a,b from xxx   裁剪的scan,读取需要的列变成RDD[Row]
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

//可以理解为select a,b from xxx where a>10  读取需要的列,再进行过滤,变成RDD[Row]
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}

//写数据,插入数据,无返回
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}

trait CatalystScan {
  def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row]
}



//用来创建上面的BaseRelation
//传进来指定数据源的参数:比如url、dbtable、user、password等(这个就是你要连接的那个数据源)
//最后返回BaseRelation(已经带有了传进来参数的属性了)
trait RelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}

// Saves a DataFrame to a destination (using data source-specific parameters)
//mode: SaveMode,当目标已经存在,是用什么方式保存
//parameters: Map[String, String] :指定的数据源参数
//要保存的DataFrame,比如执行查询之后的rows
//返回BaseRelation
trait CreatableRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation
}

//把你的数据源起一个简短的别名
trait DataSourceRegister {
//override def shortName(): String = "parquet"(举例)
  def shortName(): String
}

//比CreatableRelationProvider多了个schema参数
trait SchemaRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation
}

通过JDBCRelation的源码了解外部数据源的执行

点击JdbcRelationProvider ,可以看到它是如何实现的

class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {
  
  override def shortName(): String = "jdbc"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {

	//这个option就是去连接JDBC的那些信息,比如url、dbtable、user等等
	//具体看JDBCOptions的源码
    val jdbcOptions = new JDBCOptions(parameters)
    val resolver = sqlContext.conf.resolver
    val timeZoneId = sqlContext.conf.sessionLocalTimeZone

	//这个schema如何拿到的???
	//通过JDBC metastore获取得到的
	//具体可以getSchema的源码
    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
	
	//创建JDBCRelation,JDBCRelation这个是把上面说的那些scan的东西给实现出来
    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
  }
  override def createRelation(
  ........

下面是JDBCRelation.scala

//可以看一下它里面实现的方法,底层就是拼sql
private[sql] case class JDBCRelation(
    override val schema: StructType,
    parts: Array[Partition],
    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)  //可以点击scanTable具体分析一下,都是拼SQL
  extends BaseRelation
  with PrunedFilteredScan
  with InsertableRelation {
....................

//实现Scan
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    JDBCRDD.scanTable(
      sparkSession.sparkContext,
      schema,
      requiredColumns,
      filters,
      parts,
      jdbcOptions).asInstanceOf[RDD[Row]]
  }
  //实现写数据
    override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    data.write
      .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
      .jdbc(jdbcOptions.url, jdbcOptions.tableOrQuery, jdbcOptions.asProperties)
  }
  ..........
  }

总结:Spark去处理JDBC数据源就是:拼sql,然后交给JDBC API编程,然后产生DataFrame。 上面是JDBC数据源的如何实现的,还有其它数据源比如json、parquet、text等等。

自己实现一个外部数据源(核心重要)

现在有个文本:

//编号、名字、性别、工资、年终奖
101,zhansan,0,10000,200000
102,lisi,0,150000,250000
103,wangwu,1,3000,5
104,zhaoliu,2,500,6

这个文本是没有schema的,之前有两种方式把它转换成DataFrame。一种是通过case class反射的方式,另一种是通过创建带有Rows的RDD,自定义一个schema,然后再用通过createDataFrame来创建DataFrame。

现在通外部数据源把它来实现。

上面的JDBCRelation是通过JdbcRelationProvider来实现的。

定义一个DefaultSource(必须写DefaultSource,源码里定死了,不然会报找不到数据源),继承CreatableRelationProvider,参考上面JDBC的JdbcRelationProvider。 定义一个TextDataSourceRelation,继承BaseRelation和TableScan,并实现TableScan,参考上面JDBC的JdbcRelation。

下面是完整代码:

object TextApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("TextApp")
      .master("local[2]")
      .getOrCreate()

    //只要写到包名就可以了...sql.text,不用这样写...sql.text.DefaultSource
    val df = spark.sqlContext.read.format("com.ruozedata.spark.sql.text")
      .load("D:\\data.txt")


    df.show()

    spark.stop()
  }
}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}

object Utils {
  def castTo(value:String,dataType:DataType) ={
    dataType match {
      case _ : IntegerType => value.toInt
      case _:LongType =>value.toLong
      case _:StringType => value
    }
  }
}
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType


//DefaultSource这个名字不能乱写,底层就是自动去拼包名加Datasource,with SchemaRelationProvider  //最佳实践
//RelationProvider用来创建数据的关系,SchemaRelationProvider用来明确schema信息。
class DefaultSource
  extends RelationProvider
    with SchemaRelationProvider {
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext,parameters,null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path=parameters.get("path")
    path match {
      case Some(p) => new TextDataSourceRelation(sqlContext,p,schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }
}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

//在编写Relation时,需要实现BaseRelation来重写自定数据源的schema信息,然后实现序列化接口,为了网络传输
class TextDataSourceRelation(override val sqlContext: SQLContext,path:String,userSchema: StructType) extends BaseRelation
  with   Serializable
   with TableScan with Logging {


  //如果传进来的schema不为空,就用传进来的schema,否则就用自定义的schema
  override def schema: StructType = {
    if(userSchema != null){
      userSchema
    }else{
      StructType(
        StructField("id",LongType,false) ::
          StructField("name",StringType,false) ::
          StructField("gender",StringType,false) ::
          StructField("salary",LongType,false) ::
          StructField("comm",LongType,false) :: Nil
      )
    }
  }

  //把数据读进来,读进来之后把它转换成 RDD[Row]
  override def buildScan(): RDD[Row] = {
    logWarning("this is ruozedata buildScan....")
    //读取数据,变成为RDD
    //wholeTextFiles会把文件名读进来,可以通过map(_._2)把文件名去掉,第一位是文件名,第二位是内容
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)

    //拿到schema
    val schemaField = schema.fields

    //rdd.collect().foreach(println)

    //rdd + schemaField 把rdd和schemaField解析出来拼起来
    val rows = rdd.map(fileContent => {
      //拿到每一行的数据
      val lines = fileContent.split("\n")
      //每一行数据按照逗号分隔,分隔之后去空格,然后转成一个seq集合
      val data = lines.map(_.split(",").map(_.trim)).toSeq

      //zipWithIndex
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) => {

          val columnName = schemaField(index).name
          //castTo里面有两个参数,第一个参数需要给个判断,如果是字段是性别,里面再进行判断再转换一下,如果不是性别就直接用这个字段
          Utils.castTo(if(columnName.equalsIgnoreCase("gender")){
            if(value == "0"){
              "man"
            }else if(value == "1"){
              "woman"
            } else{
              "unknown"
            }
          }else{
            value
          },schemaField(index).dataType)

        }
      })

      result.map(x => Row.fromSeq(x))
    })

    rows.flatMap(x => x)

  }
}

你可能感兴趣的:(Spark)