Spark SQL Custom External Data Source: Source Code Analysis and Example Implementation

Starting from the entry point of the built-in JDBC data source, the relevant interfaces in the Spark source code are analyzed below:

Source Code Analysis

// A class that extends BaseRelation must be able to produce its schema as a
// StructType. Concrete implementations should also mix in one of the Scan traits below.
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType

  // Estimated size of this relation in bytes, used by the optimizer
  // (for example, when deciding whether to broadcast a join side)
  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
  // Whether rows returned by buildScan() need converting to Spark's internal
  // row representation (leave true unless the scan returns InternalRow)
  def needConversion: Boolean = true
  // Filters this relation cannot evaluate itself; by default none are handled,
  // so Spark re-evaluates every filter after the scan
  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}


// Full table scan, equivalent to: select * from xxx
trait TableScan {
  def buildScan(): RDD[Row]
}

// Column pruning: drops the columns the query does not need
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

// Column pruning plus row filtering, similar to: select col1, col2 from t where id > 100
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}
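The Filter values handed to buildScan are simple, serializable predicate descriptions (EqualTo, GreaterThan, In, and so on, all under org.apache.spark.sql.sources). Below is a minimal sketch of evaluating one of them, assuming rows are modeled as column-name-to-value maps; the helper object and that row representation are hypothetical, not part of Spark's API:

import org.apache.spark.sql.sources.{EqualTo, Filter}

object FilterEval {
  // Evaluate one pushed-down filter against a row represented as a
  // column-name -> value map. Returning true for filters we do not
  // understand is safe: Spark re-applies every filter the source does not
  // declare as handled via unhandledFilters.
  def evaluate(filter: Filter, row: Map[String, Any]): Boolean = filter match {
    case EqualTo(attribute, value) => row.get(attribute).contains(value)
    case _                         => true
  }
}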

// Writable: supports inserting a DataFrame into the relation
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}
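For completeness, here is a hypothetical skeleton of a writable relation; the class name and the shortcut of delegating to Spark's built-in CSV writer are illustrative assumptions, not part of the original example:

import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation}
import org.apache.spark.sql.types.StructType

// Hypothetical skeleton: insert() receives the DataFrame to persist and a
// flag saying whether data already at the target path should be replaced.
class WritableTextRelation(override val sqlContext: SQLContext,
                           path: String,
                           override val schema: StructType)
  extends BaseRelation with InsertableRelation {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
    // Delegate to the built-in CSV writer for brevity; a real source would
    // serialize rows in its own format.
    data.write.mode(mode).csv(path)
  }
}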

Example Implementation

The data to be read (plain text, with no field types or schema):

101,zhansan,0,10000,200000
102,lisi,0,150000,250000
103,wangwu,1,3000,5
104,zhaoliu,2,500,6
102,lisi,0,150000,250000

Code

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Note: the class must be named DefaultSource, or you must register an explicit
// data source name, because Spark SQL appends ".DefaultSource" to the package
// path to shorten calls: "org.apache.spark.sql.json", for example, resolves to
// the data source "org.apache.spark.sql.json.DefaultSource".
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  override def createRelation(
                      sqlContext: SQLContext,
                      parameters: Map[String, String],
                      schema: StructType): BaseRelation = {

    val path = parameters.get("path")

    path match {
      case Some(x) => new TextDataSourceRelation(sqlContext,x,schema)
      case _ => throw new IllegalArgumentException("path is required...")
    }
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

}
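Optionally, callers can be spared the long package name: Spark's org.apache.spark.sql.sources.DataSourceRegister trait lets a source register a short alias, provided the class is also listed in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister resource file. A sketch, with the alias "customtext" chosen purely for illustration:

import org.apache.spark.sql.sources.DataSourceRegister

// Hypothetical: after service registration, spark.read.format("customtext")
// resolves to this class instead of the full package name.
class CustomTextSource extends DefaultSource with DataSourceRegister {
  override def shortName(): String = "customtext"
}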

Next, the relation implementation itself:

import org.apache.spark.internal.Logging // Spark-internal trait; swap in your own logger if your build cannot access it
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._

// Only the full table scan (TableScan) is implemented here; see the PrunedScan
// sketch after this class for column pruning.
class TextDataSourceRelation(override val sqlContext: SQLContext,
                             path: String,
                             userSchema: StructType)
  extends BaseRelation with TableScan with Logging {


  override def schema: StructType = {
    if (null != userSchema) {
      userSchema
    } else {
      // Default schema used when the caller does not supply one
      StructType(
        StructField("id", LongType, false) ::
          StructField("name", StringType, false) ::
          StructField("gender", StringType, false) ::
          StructField("salary", LongType, false) ::
          StructField("comm", LongType, false) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    logInfo("this is custom buildScan")
    // wholeTextFiles reads each file as a single (path, content) pair;
    // map(_._2) keeps only the file content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)

    val fieldsSchema = schema.fields

    val rows = rdd.map(fileContent => {
      // Split the file content into lines, skipping any empty trailing line
      val lines = fileContent.split("\n").filter(_.trim.nonEmpty)
      // Split each line on commas and trim every field
      val data = lines.map(_.split(",").map(_.trim)).toSeq
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) =>
          val columnName = fieldsSchema(index).name
          // Translate the raw gender code into a readable label
          val castValue = if (columnName.equalsIgnoreCase("gender")) {
            if (value == "0") "man"
            else if (value == "1") "woman"
            else "unknown"
          } else {
            value
          }
          SqlUtil.castTo(castValue, fieldsSchema(index).dataType)
      })

      result.map(x => Row.fromSeq(x))
    })

    rows.flatMap(x => x)
  }
}
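Because only TableScan is implemented, every query materializes all five columns. Below is a sketch of layering column pruning on top by also mixing in PrunedScan; the subclass and its prune-after-full-scan strategy are assumptions for illustration (a serious implementation would avoid parsing unneeded columns in the first place):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.PrunedScan
import org.apache.spark.sql.types.StructType

// Hypothetical: reuse the full scan, then keep only the columns Spark asked
// for, in the order it asked for them.
class PrunedTextDataSourceRelation(sqlContext: SQLContext,
                                   path: String,
                                   userSchema: StructType)
  extends TextDataSourceRelation(sqlContext, path, userSchema) with PrunedScan {

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val indices = requiredColumns.map(schema.fieldIndex)
    buildScan().map(row => Row.fromSeq(indices.map(i => row.get(i))))
  }
}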

object SqlUtil {
  // Cast a raw string field to the Catalyst type declared in the schema
  def castTo(value: String, dataType: DataType): Any = {
    dataType match {
      case _: LongType   => value.toLong
      case _: StringType => value
      case other         => throw new IllegalArgumentException(s"Unsupported data type: $other")
    }
  }
}
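A quick sanity check of the helper in the REPL (illustrative values):

SqlUtil.castTo("101", LongType)    // 101L, a Long
SqlUtil.castTo("lisi", StringType) // "lisi", unchanged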

Main method

import org.apache.spark.sql.SparkSession

object TestCustomSource {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("TextApp").master("local[2]").getOrCreate()

    // format() takes the package that contains the DefaultSource class
    val df = spark.read.format("com.kzw.bigdata.spark.sql04").option("path", "input/custom.txt").load()
    //df.show()
    df.printSchema()


    df.createOrReplaceTempView("customTable")

    val sql = "select * from customTable"
    spark.sql(sql).show()


    val sql2 = "select id,sum(salary) from customTable group by id"
    spark.sql(sql2).show()

    spark.stop()

  }

}

Output:

root
 |-- id: long (nullable = false)
 |-- name: string (nullable = false)
 |-- gender: string (nullable = false)
 |-- salary: long (nullable = false)
 |-- comm: long (nullable = false)

+---+-------+-------+------+------+
| id|   name| gender|salary|  comm|
+---+-------+-------+------+------+
|101|zhansan|    man| 10000|200000|
|102|   lisi|    man|150000|250000|
|103| wangwu|  woman|  3000|     5|
|104|zhaoliu|unknown|   500|     6|
|102|   lisi|    man|150000|250000|
+---+-------+-------+------+------+

+---+-----------+
| id|sum(salary)|
+---+-----------+
|103|       3000|
|104|        500|
|101|      10000|
|102|     300000|
+---+-----------+
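Because DefaultSource also implements SchemaRelationProvider, the caller may pass an explicit schema instead of relying on the built-in default: .schema(...) makes Spark invoke the three-argument createRelation. A hypothetical variation of the read above, reusing the spark session from the main method:

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val userSchema = StructType(
  StructField("id", LongType, nullable = false) ::
    StructField("name", StringType, nullable = false) ::
    StructField("gender", StringType, nullable = false) ::
    StructField("salary", LongType, nullable = false) ::
    StructField("comm", LongType, nullable = false) :: Nil
)

// With a user-specified schema, Spark calls
// createRelation(sqlContext, parameters, schema) on DefaultSource.
val dfWithSchema = spark.read
  .format("com.kzw.bigdata.spark.sql04")
  .schema(userSchema)
  .option("path", "input/custom.txt")
  .load()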
