Spark SQL: Writing a Custom Data Source

The main pipeline of a streaming job usually relies on a fair amount of lookup (mapping) data, and that data most often arrives as plain text files. Reading the file is easy; attaching the matching schema every time is the tedious part. This article builds a custom text data source that supplies a default schema automatically when the caller does not provide one. It consists of four files: DefaultSource (the entry point Spark resolves from the format name), TextDatasourceRelation (schema definition and scan), Utils (string-to-type casting), and TextApp (the driver).
DefaultSource.scala

package com.wxx.bigdata.sql_custome_source

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource extends RelationProvider with SchemaRelationProvider {

  // Called when the caller supplies a schema via spark.read.schema(...)
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String],
                              schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new TextDatasourceRelation(sqlContext, p, schema)
      case _ => throw new IllegalArgumentException("path is required")
    }
  }

  // Called when no schema is supplied; the relation falls back to its built-in schema
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }
}
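
The provider class must be named DefaultSource: spark.read.format(...) receives the package name and Spark appends .DefaultSource when resolving the provider. If a short format name is preferred, Spark also offers the DataSourceRegister trait. The sketch below is not part of the original example; ShortNameSource and the "wxx-text" name are hypothetical, and the class additionally has to be listed in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister before the short name can be discovered.

package com.wxx.bigdata.sql_custome_source

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// Hypothetical variant: expose the same relation under a short format name ("wxx-text")
class ShortNameSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "wxx-text"

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    val path = parameters.getOrElse("path", throw new IllegalArgumentException("path is required"))
    new TextDatasourceRelation(sqlContext, path, null)
  }
}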

TextDatasourceRelation.scala
 

package com.wxx.bigdata.sql_custome_source

import com.wxx.bigdata.utils.Utils
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

class TextDatasourceRelation(override val sqlContext: SQLContext,
                             path: String,
                             userSchema: StructType) extends BaseRelation with TableScan with Logging {

  // Use the caller-supplied schema when present; otherwise fall back to the built-in one
  override def schema: StructType = {
    if (userSchema != null) {
      userSchema
    } else {
      StructType(
        StructField("id", LongType, false) ::
          StructField("name", StringType, false) ::
          StructField("gender", StringType, false) ::
          StructField("salary", LongType, false) ::
          StructField("comm", LongType, false) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    logWarning("this is a custom buildScan")
    // wholeTextFiles yields (filePath, fileContent) pairs; only the content is needed here
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(_._2)
    val schemaFields = schema.fields

    val rows = rdd.map(fileContent => {
      // Split each file into non-blank lines, then split each line into trimmed fields
      val lines = fileContent.split("\n").filter(_.trim.nonEmpty)
      val data = lines.map(_.split(",").map(_.trim)).toSeq

      // Cast every field to the type declared in the schema; the gender column is first
      // translated from its 0/1 code into a readable value
      val result = data.map(x => x.zipWithIndex.map {
        case (value, index) =>
          val columnName = schemaFields(index).name
          println(value + " | " + index + " | " + columnName)
          val normalized = if (columnName.equals("gender")) {
            if (value == "0") "男"
            else if (value == "1") "女"
            else "未知"
          } else {
            value
          }
          Utils.castTo(normalized, schemaFields(index).dataType)
      })
      result.map(x => Row.fromSeq(x))
    })

    // Each file contributed a collection of Rows; flatten them into a single RDD[Row]
    rows.flatMap(x => x)
  }
}
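
TableScan always materializes every column. As a sketch that is not in the original code, the same relation could also mix in PrunedScan so that a query such as SELECT name FROM ... only carries the requested columns through the scan; PrunedTextRelation and its projection logic below are illustrative assumptions, and DefaultSource would have to return this relation for Spark to actually use it.

package com.wxx.bigdata.sql_custome_source

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.PrunedScan
import org.apache.spark.sql.types.StructType

// Hypothetical extension: reuse the full TableScan, then project each Row down to the requested columns
class PrunedTextRelation(sqlContext: SQLContext, path: String, userSchema: StructType)
  extends TextDatasourceRelation(sqlContext, path, userSchema) with PrunedScan {

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    // Positions of the requested columns within the full schema
    val indices = requiredColumns.map(schema.fieldIndex)
    buildScan().map(row => Row.fromSeq(indices.map(row.get)))
  }
}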

Utils.scala

package com.wxx.bigdata.utils

import org.apache.spark.sql.types.{DataType, LongType, StringType}

object Utils {

  // Cast a raw string field to the Catalyst type declared in the schema
  def castTo(value: String, dataType: DataType): Any = {
    dataType match {
      case _: LongType   => value.toLong
      case _: StringType => value
      case _             => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
    }
  }

}


TextApp.scala

package com.wxx.bigdata.sql_custome_source

import org.apache.spark.sql.SparkSession

object TextApp {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("TextApp").master("local[2]").getOrCreate()
    // format takes the package that contains DefaultSource
    val df = spark.read.format("com.wxx.bigdata.sql_custome_source").option("path", "D:\\inputs").load()

    df.show()
    spark.stop()
  }

}
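
Because DefaultSource also implements SchemaRelationProvider, the caller may override the built-in schema. Below is a minimal sketch, meant to run inside TextApp's main after the SparkSession is created; the column names are illustrative. Note that the gender translation in buildScan only fires for a column literally named gender, and the column types still have to line up with the file's contents.

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val userSchema = StructType(
  StructField("emp_id", LongType, false) ::
    StructField("emp_name", StringType, false) ::
    StructField("gender", StringType, false) ::
    StructField("emp_salary", LongType, false) ::
    StructField("emp_comm", LongType, false) :: Nil
)

val dfWithSchema = spark.read
  .format("com.wxx.bigdata.sql_custome_source")
  .schema(userSchema)
  .option("path", "D:\\inputs")
  .load()

dfWithSchema.printSchema()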

Test file
The columns are, in order: id, name, gender (0 = male, 1 = female), salary, and bonus (comm).

10000,zhangsan,0,100000,200000
10001,lisi,0,99999,199999
10002,wangwu,0,2000,5
10003,zhaoliu,0,2001,6
10004,tianqi,0,2007,7
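
With this sample file, df.show() should return five rows: every gender code of 0 is rendered as 男, and the numeric columns come back as longs, matching the default schema.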

 
