The API can be understood by looking at the source-code entry points used by the built-in JDBC data source:
// A class that extends BaseRelation must be able to expose its schema as a StructType.
// Concrete implementations should also mix in one of the Scan traits below.
abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType
  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
  def needConversion: Boolean = true
  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}

// Full table scan, equivalent to: select * from xxx
trait TableScan {
  def buildScan(): RDD[Row]
}

// Column pruning: drop the columns the query does not need
trait PrunedScan {
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

// Column pruning plus filter push-down, roughly: select col1, col2 ... where ...
trait PrunedFilteredScan {
  def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row]
}

// Writable relation
trait InsertableRelation {
  def insert(data: DataFrame, overwrite: Boolean): Unit
}
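For contrast with the full table scan relation implemented later, a column-pruning source would mix in PrunedScan instead of TableScan. The following is only a minimal sketch, not part of this example: the class name PrunedTextRelation is made up, and every column is kept as a string to stay short.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, PrunedScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class PrunedTextRelation(override val sqlContext: SQLContext,
                         path: String) extends BaseRelation with PrunedScan {

  // all columns declared as strings to keep the sketch short
  override def schema: StructType = StructType(
    StructField("id", StringType, false) ::
    StructField("name", StringType, false) :: Nil
  )

  // Spark passes in only the columns the query actually needs
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    // positions of the requested columns within the declared schema
    val indices = requiredColumns.map(schema.fieldIndex)
    sqlContext.sparkContext.textFile(path)
      .map(_.split(",").map(_.trim))
      .map(fields => Row.fromSeq(indices.map(fields(_)).toSeq))
  }
}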
The data to be read (plain text, with no field types or schema); the columns are id, name, gender code (0/1/2), salary, and comm:
101,zhansan,0,10000,200000
102,lisi,0,150000,250000
103,wangwu,1,3000,5
104,zhaoliu,2,500,6
102,lisi,0,250000,250000
The code:
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Note: the class must be named DefaultSource; if you use another name, callers have to spell out
// the full data source class name. Spark SQL appends "DefaultSource" to the package path to keep
// calls short: for example, "org.apache.spark.sql.json" resolves to the data source
// "org.apache.spark.sql.json.DefaultSource".
class DefaultSource extends RelationProvider with SchemaRelationProvider {

  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(x) => new TextDataSourceRelation(sqlContext, x, schema)
      case _ => throw new IllegalArgumentException("path is required...")
    }
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }
}
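Optionally, the source can also be given a short name so that callers do not have to type the full package path. This is not used in this example; a minimal sketch, where the class ShortNameSource and the short name "text-demo" are made up for illustration:
import org.apache.spark.sql.sources.DataSourceRegister

// The class must also be listed in a
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file so that Spark's
// ServiceLoader can discover it; then spark.read.format("text-demo") resolves to this source.
class ShortNameSource extends DefaultSource with DataSourceRegister {
  override def shortName(): String = "text-demo"
}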
Next, the relation itself:
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._

// Only the full table scan (TableScan) is implemented here.
class TextDataSourceRelation(override val sqlContext: SQLContext,
                             path: String,
                             userSchema: StructType)
  extends BaseRelation with TableScan with Logging {

  override def schema: StructType = {
    if (null != userSchema) {
      userSchema
    } else {
      // default schema used when the caller does not supply one
      StructType(
        StructField("id", LongType, false) ::
        StructField("name", StringType, false) ::
        StructField("gender", StringType, false) ::
        StructField("salary", LongType, false) ::
        StructField("comm", LongType, false) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    logInfo("this is custom buildScan")
    // wholeTextFiles reads each file as a single (fileName, content) pair; keep only the content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)
    val fieldsSchema = schema.fields
    val rows = rdd.map(fileContent => {
      // split the file content into lines, dropping blank lines (e.g. a trailing newline)
      val lines = fileContent.split("\n").filter(_.trim.nonEmpty)
      // split each line on commas
      val data = lines.map(_.split(",").map(x => x.trim)).toSeq
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) => {
          val columnName = fieldsSchema(index).name
          // map the numeric gender code to a label; leave every other column untouched
          val castValue = if (columnName.equalsIgnoreCase("gender")) {
            if (value.equalsIgnoreCase("0")) {
              "man"
            } else if (value.equalsIgnoreCase("1")) {
              "woman"
            } else {
              "unknown"
            }
          } else {
            value
          }
          SqlUtil.castTo(castValue, fieldsSchema(index).dataType)
        }
      })
      result.map(x => Row.fromSeq(x))
    })
    rows.flatMap(x => x)
  }
}
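The relation above is read-only. For completeness, the InsertableRelation trait listed at the top could be mixed in so that INSERT INTO statements against a table backed by this relation can be handled. The class below is only a sketch, not part of the example; the overwrite flag is ignored and the output path suffix ".out" is made up.
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.StructType

class WritableTextRelation(sqlContext: SQLContext,
                           path: String,
                           userSchema: StructType)
  extends TextDataSourceRelation(sqlContext, path, userSchema) with InsertableRelation {

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // serialize each Row back into a comma-separated line
    val lines = data.rdd.map(_.mkString(","))
    // naive write: a real implementation would honour `overwrite` and write atomically
    lines.saveAsTextFile(path + ".out")
  }
}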
import org.apache.spark.sql.types._

object SqlUtil {
  // cast a raw string value to the Scala type expected by the column's DataType
  def castTo(value: String, dataType: DataType): Any = {
    dataType match {
      case LongType => value.toLong
      case StringType => value
      case other => throw new IllegalArgumentException(s"Unsupported data type: $other")
    }
  }
}
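If the source ever needs more column types, castTo can be extended along the same lines. A sketch; the extra types are not used in this example and the object name ExtendedSqlUtil is made up:
import org.apache.spark.sql.types._

object ExtendedSqlUtil {
  def castTo(value: String, dataType: DataType): Any = dataType match {
    case IntegerType => value.toInt
    case LongType    => value.toLong
    case DoubleType  => value.toDouble
    case BooleanType => value.toBoolean
    case StringType  => value
    case other       => throw new IllegalArgumentException(s"Unsupported data type: $other")
  }
}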
The main method:
import org.apache.spark.sql.SparkSession

object TestCustomSource {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("TextApp").master("local[2]").getOrCreate()
    // the format string is the package that contains DefaultSource
    val df = spark.read.format("com.kzw.bigdata.spark.sql04").option("path", "input/custom.txt").load()
    //df.show()
    df.printSchema()
    df.createOrReplaceTempView("customTable")
    val sql = "select * from customTable"
    spark.sql(sql).show()
    val sql2 = "select id,sum(salary) from customTable group by id"
    spark.sql(sql2).show()
    spark.stop()
  }
}
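Because DefaultSource also implements SchemaRelationProvider, the caller may pass its own schema, which is then routed to the three-argument createRelation instead of the built-in default. A usage sketch; the field definitions simply mirror the defaults above, and gender stays a string because buildScan maps the 0/1/2 codes to labels:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

object TestCustomSourceWithSchema {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("TextAppWithSchema").master("local[2]").getOrCreate()

    val userSchema = StructType(
      StructField("id", LongType, false) ::
      StructField("name", StringType, false) ::
      StructField("gender", StringType, false) ::
      StructField("salary", LongType, false) ::
      StructField("comm", LongType, false) :: Nil
    )

    val df = spark.read
      .format("com.kzw.bigdata.spark.sql04")
      .schema(userSchema)                 // handed to createRelation(sqlContext, parameters, schema)
      .option("path", "input/custom.txt")
      .load()

    df.printSchema()
    df.show()
    spark.stop()
  }
}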
The output:
root
|-- id: long (nullable = false)
|-- name: string (nullable = false)
|-- gender: string (nullable = false)
|-- salary: long (nullable = false)
|-- comm: long (nullable = false)
+---+-------+-------+------+------+
| id| name| gender|salary| comm|
+---+-------+-------+------+------+
|101|zhansan| man| 10000|200000|
|102| lisi| man|150000|250000|
|103| wangwu| woman| 3000| 5|
|104|zhaoliu|unknown| 500| 6|
|102| lisi| man|150000|250000|
+---+-------+-------+------+------+
+---+-----------+
| id|sum(salary)|
+---+-----------+
|103| 3000|
|104| 500|
|101| 10000|
|102| 300000|
+---+-----------+