➢ Easy to integrate
Seamlessly mixes SQL queries with Spark programs
➢ Unified data access
Connects to different data sources in the same way
➢ Hive compatible
Runs SQL or HiveQL directly on existing warehouses
➢ Standard data connectivity
Connects through JDBC or ODBC
➢ DataFrame
A DataFrame is also a distributed dataset built on RDDs; the difference from an RDD is that a DataFrame carries the metadata (schema) of the data
A DataFrame can be thought of as a two-dimensional table in a traditional database, where every column has a name and a type
➢ DataSet
A DataSet is also a distributed dataset; it is an extension of DataFrame, comparable to a ResultSet in traditional JDBC
In Spark Core you have to create the context object SparkContext yourself
Spark SQL is a layer on top of Spark Core, and it wraps not only the functionality but also the context object
Older versions had two entry points: SQLContext for Spark's own queries and HiveContext for queries against Hive
Newer versions use SparkSession, which combines SQLContext and HiveContext, so their APIs are interchangeable
A SparkSession can also hand you the underlying SparkContext directly
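A minimal sketch of creating a SparkSession and getting the SparkContext from it (the object name, app name and master are just placeholders):

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object SparkSessionDemo {
  def main(args: Array[String]): Unit = {
    // One unified entry point instead of SQLContext / HiveContext
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("SparkSessionDemo")
      .getOrCreate()

    // The underlying SparkContext is still reachable when the RDD API is needed
    val sc: SparkContext = spark.sparkContext
    println(sc.appName)

    spark.close()
  }
}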
Three concepts:
Data: an RDD only cares about the data itself, e.g. (1,"jack",20); it does not care what each field means
Structure: a DataFrame cares about data + structure, e.g. {"id":1,"name":"jack","age":20}; it knows the name and type of every column
Type: a DataSet cares about data + structure + type, e.g. DataSet[Person], where Person is a class we define, carrying the type, the fields and the data
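A small sketch contrasting the three abstractions on the same data (the Person class and the sample values are made up for illustration):

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

object AbstractionDemo {
  case class Person(id: Long, name: String, age: Long)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("abstractions").getOrCreate()
    import spark.implicits._

    // RDD: only the data, no schema
    val rdd = spark.sparkContext.makeRDD(List((1L, "jack", 20L)))

    // DataFrame: data + structure (column names and types), each row is an untyped Row
    val df: DataFrame = rdd.toDF("id", "name", "age")

    // Dataset: data + structure + a compile-time type for each row
    val ds: Dataset[Person] = df.as[Person]

    df.printSchema()
    ds.show()
    spark.close()
  }
}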
➢ Create a DataFrame from a data source
scala> var df = spark.read.json("data/info.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, id: bigint, name: string]
➢ Convert from an RDD (covered in a later section)
➢ Query from a Hive table (covered in a later section)
There are two ways to work with a DataFrame: SQL syntax and DSL syntax
➢ SQL syntax
1. SQL is run against a "temporary view", so create the view first
2. Then execute SQL through the SparkSession object to query the data
scala> df.createOrReplaceTempView("user") //create a temporary view
scala> var viewdf = spark.sql("select id,name,age from user") //run SQL through spark
viewdf: org.apache.spark.sql.DataFrame = [id: bigint, name: string, age: bigint] //the SQL again returns a DF
scala> viewdf.show //display the data in the DF
scala> spark.sql("select id,name,age from user").show //or run the SQL and show the result directly
+---+-----+---+
| id| name|age|
+---+-----+---+
| 1|jack1| 18|
| 2|jack2| 28|
| 3|jack3| 38|
+---+-----+---+
Note:
df.createOrReplaceTempView creates a temporary view that is only valid in the current session
df.createOrReplaceGlobalTempView creates a temporary view that is valid across all sessions
When querying it, prefix the view name with global_temp., i.e. global_temp.viewName, for example:
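(the view name user2 below is made up for illustration; a global temporary view stays visible even from a new session, as long as the global_temp prefix is used)
scala> df.createGlobalTempView("user2")
scala> spark.sql("select * from global_temp.user2").show
scala> spark.newSession().sql("select * from global_temp.user2").show //also visible from another session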
➢ DSL syntax
DSL stands for Domain-Specific Language
It is the API the DataFrame provides for working with structured data; the methods below are called directly on the DataFrame
scala> df.printSchema
root
|-- age: long (nullable = true)
|-- id: long (nullable = true)
|-- name: string (nullable = true)
scala> df.select("name")
res20: org.apache.spark.sql.DataFrame = [name: string]
//basic queries
scala> df.select("name").show
+-----+
| name|
+-----+
|jack1|
|jack2|
|jack3|
+-----+
//column expressions
scala> df.select($"age" + 1).show
scala> df.select('age + 1).show
+---------+
|(age + 1)|
+---------+
| 19|
| 29|
| 39|
+---------+
//aliasing
scala> df.select('name,'age + 1 as "aa").show
+-----+---+
| name| aa|
+-----+---+
|jack1| 19|
|jack2| 29|
|jack3| 39|
+-----+---+
//aggregate functions
scala> df.select(avg("age") as "avg_age").show
+-------+
|avg_age|
+-------+
|   28.0|
+-------+
//filtering
scala> df.filter('age > 25).show
+---+---+-----+
|age| id| name|
+---+---+-----+
| 28| 2|jack2|
| 38| 3|jack3|
+---+---+-----+
//grouping + aggregation
scala> df.groupBy("id").count.show
+---+-----+
| id|count|
+---+-----+
| 1| 1|
| 3| 1|
| 2| 1|
+---+-----+
➢ Converting between RDD and DF requires importing the implicit conversions
import spark.implicits._
Here spark is the name of the SparkSession object, so the import must come after the SparkSession has been created, and that object must be declared as a val
1. RDD ==> DF: the RDD lacks the structure, i.e. the column names, so they have to be supplied
scala> var rdd = spark.sparkContext.makeRDD(List(1,2,3))
rdd: org.apache.spark.rdd.RDD[Int] = ParallelCollectionRDD[129] at makeRDD at <console>:23
scala> var df = rdd.toDF("id")
df: org.apache.spark.sql.DataFrame = [id: int]
scala> df.show
+---+
| id|
+---+
| 1|
| 2|
| 3|
+---+
2. DF ==> RDD: a DF wraps an RDD internally, so the RDD can be obtained directly
Once the structure is stripped away, every row of the DF becomes a Row object
Use the Row object's get(index) or getAs[Type](index) method to read the values inside it
scala> var df = spark.read.json("data/info.json");
df: org.apache.spark.sql.DataFrame = [age: bigint, id: bigint]
scala> var rdd = df.rdd
rdd: org.apache.spark.rdd.RDD[org.apache.spark.sql.Row]
scala> var arr = rdd.collect
arr: Array[org.apache.spark.sql.Row] = Array([18,1,jack1], [28,2,jack2], [38,3,jack3])
scala> arr(0)
res62: org.apache.spark.sql.Row = [18,1,jack1]
scala> arr(0).get(0)
res63: Any = 18
scala> arr(0).getAs[String](2)
res67: String = jack1
➢ A Seq or a List can be turned into a DS directly
1. From a collection of primitive types
scala> var ds = List(1,2,3).toDS
ds: org.apache.spark.sql.Dataset[Int] = [value: int]
scala> var ds = List(1.1,2.2).toDS
ds: org.apache.spark.sql.Dataset[Double] = [value: double]
2. From a collection of a user-defined type
scala> case class User(age:Int,name:String)
defined class User
scala> var ds = List(User(10,"jack"),User(20,"rose")).toDS
ds: org.apache.spark.sql.Dataset[User] = [age: int, name: string]
In practice a sequence is rarely converted into a DataSet; DataSets are usually obtained from RDDs
1. RDD ==> DS: the RDD lacks both structure and type
a. To convert an RDD to a DS, we usually map it to an RDD of a concrete type first and then call toDS
scala> case class User(name:String, age:Int)
defined class User
scala> sc.makeRDD(List(("zhangsan",30), ("lisi",49))).map(t=>User(t._1,t._2)).toDS
res11: org.apache.spark.sql.Dataset[User] = [name: string, age: int]
b. Or call toDS directly
scala> sc.makeRDD(List(("zhangsan",30), ("lisi",49))).toDS
res11: org.apache.spark.sql.Dataset[(String, Int)] = [_1: string, _2: int]
2. DS ==> RDD: a DS wraps an RDD internally, so the RDD can be obtained directly, and the resulting RDD keeps the element type
scala> var ds = List(User("aa",11),User("bb",22)).toDS
ds: org.apache.spark.sql.Dataset[User] = [name: string, age: int]
scala> var rdd = ds.rdd
rdd: org.apache.spark.rdd.RDD[User]
➢ DataFrame ==> DataSet requires a type
scala> case class User(name:String, age:Int)
defined class User
scala> val df = sc.makeRDD(List(("zhangsan",30),("lisi",49))).toDF("name","age")
df: org.apache.spark.sql.DataFrame = [name: string, age: int]
scala> val ds = df.as[User]
ds: org.apache.spark.sql.Dataset[User] = [name: string, age: int]
➢ DataSet ==> DataFrame drops the type; the result is effectively a DataSet[Row]
scala> val df = ds.toDF
df: org.apache.spark.sql.DataFrame = [name: string, age: int]
DF ==rdd==> RDD[Row]
DS ==rdd==> RDD[Type]
When converting to an RDD: from a DF the element type is Row; from a DS it is the DS's own type parameter
RDD ==toDF==> DF[Row]
DS ==toDF==> DF[Row]
When converting to a DF: a DF never carries a type, so the rows are always Row
DF ==as[Type]==> DS[Type]
RDD ==toDS==> DS[the RDD's element type]
When converting to a DS: from a DF the type is supplied via as[Type]; from an RDD the type parameter is the RDD's element type
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

object Spark_SQL_Start {
def main(args: Array[String]): Unit = {
//1. Create the SparkSession
val spark: SparkSession = SparkSession.builder()
.config(new SparkConf().setMaster("local[*]")
.setAppName("start01")).getOrCreate()
import spark.implicits._
//2. Create and use a DF
val df: DataFrame = spark.read.json("datas/info.json").cache()
//SQL
df.createOrReplaceTempView("User")
spark.sql("select * from User").show()
//DSL
df.select("name").show()
df.groupBy("id").count().show()
println("-----------------------")
//3.DF ==> DS DF ==> RDD
val ds: Dataset[User] = df.as[User]
ds.show()
val rdd: RDD[Row] = df.rdd
rdd.collect().foreach(println)
println("-----------------------")
//4.RDD ==> DF RDD ==> DS
val rdd1: RDD[(Int, String)] = spark.sparkContext.makeRDD(Seq((10,"tom"),(20,"jack")))
var df1: DataFrame = rdd1.toDF("id","name")
df1.show()
val ds1: Dataset[User] = rdd1.map(t=>User(t._1,t._2)).toDS()
ds1.show()
println("-----------------------")
//5.DS==>RDD DS==>DF
val rdd2: RDD[User] = ds1.rdd
rdd2.collect().foreach(println)
val df2: DataFrame = ds1.toDF()
df2.show()
//6. Close the session
spark.close()
}
case class User(id:Long,name:String)
}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}

object Spark_Sql_UDF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.config(new SparkConf().setAppName("UDF").setMaster("local[*]")).getOrCreate()
import spark.implicits._
//Register a user-defined function (the parameter needs an explicit type)
spark.udf.register("getWithName",(x:String)=>{"Name:"+x})
//Create the DF
val df: DataFrame = spark.read.format("json").load("datas/info.json")
//Create a temporary view
df.createOrReplaceTempView("user")
//Use the UDF in SQL
spark.sql("select getWithName(name),id from user").show()
spark.close()
}
}
➢ info.json:
{"id": 10,"name": "jack"}
{"id": 20,"name": "rose"}
{"id": 30,"name": "tom"}
➢ console:
+-----------------+---+
|getWithName(name)| id|
+-----------------+---+
| Name:jack| 10|
| Name:rose| 20|
| Name:tom| 30|
+-----------------+---+
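The same function can also be used in DSL style through org.apache.spark.sql.functions.udf, without registering it by name; a small sketch assuming the same df and spark.implicits._ import as above (the val name getWithName is just for illustration):

import org.apache.spark.sql.functions.udf

// Wrap the Scala function as a UserDefinedFunction and apply it to a column
val getWithName = udf((x: String) => "Name:" + x)
df.select(getWithName($"name"), $"id").show()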
Requirement: implement a custom average function avgAge
➢ Implemented with an RDD
val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(Seq(("jack", 10), ("rose", 20), ("tom", 30)))
val ageOneRdd: RDD[(Int, Int)] = rdd.map { case (_, age) => (age, 1) }
val ageCount: (Int, Int) = ageOneRdd.reduce((t1, t2) => (t1._1 + t2._1, t1._2 + t2._2))
println(ageCount._1 / ageCount._2)
➢ Implemented with an accumulator
val ageAcc = new AgeAccumulator
spark.sparkContext.register(ageAcc)
val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(Seq(("jack", 10), ("rose", 20), ("tom", 30)))
rdd.foreach { case (_, age) => ageAcc.add(age) }
val ageCount: (Int, Int) = ageAcc.value
println(ageCount._1 / ageCount._2)
class AgeAccumulator extends AccumulatorV2[Int, (Int, Int)] {
private var ageSum: Int = 0
private var ageCnt: Int = 0
override def isZero: Boolean = ageSum == 0 && ageCnt == 0
override def copy(): AccumulatorV2[Int, (Int, Int)] = new AgeAccumulator
override def reset(): Unit = {
ageSum = 0
ageCnt = 0
}
override def add(age: Int): Unit = {
ageSum += age
ageCnt += 1
}
override def merge(other: AccumulatorV2[Int, (Int, Int)]): Unit = {
ageSum += other.value._1
ageCnt += other.value._2
}
override def value: (Int, Int) = (ageSum, ageCnt)
}
➢ Implemented by extending the UDAF abstract class UserDefinedAggregateFunction (before Spark 3.0)
//Create the DF
val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(Seq(("jack", 10), ("rose", 20), ("tom", 30)))
val df: DataFrame = rdd.toDF("name", "age")
//Create the custom aggregate function object
val ageUDAF = new AgeUDAF
//Register the UDAF
spark.udf.register("ageAVG", ageUDAF)
//Create a temporary view
df.createOrReplaceTempView("user")
//Run the SQL
spark.sql("select ageAVG(age) from user").show()
class AgeUDAF extends UserDefinedAggregateFunction {
//Input data type of the aggregate function
override def inputSchema: StructType = {
StructType(
Array(
StructField("age", IntegerType)
)
)
}
//Buffer used during the computation
override def bufferSchema: StructType = {
StructType(
Array(
StructField("ageSum", LongType),
StructField("ageCnt", LongType)
)
)
}
//Return type of the aggregate function
override def dataType: DataType = DoubleType
// Deterministic: whether the same input always produces the same output
override def deterministic: Boolean = true
// Initialize the buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
//Accumulate one input row into the buffer
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(0) + input.getInt(0)
buffer(1) = buffer.getLong(1) + 1
}
//Merge two buffers
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
//Compute the final result
override def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
➢ Implemented by extending Aggregator (Spark 3.0+) as a strongly typed aggregate function
//Create the strongly typed UDAF object
val ageAggr = new AgeAggregator
//Register the strongly typed UDAF; it has to be wrapped with functions.udaf
spark.udf.register("ageAggr", functions.udaf(ageAggr))
//Create the RDD
val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(Seq(("jack", 10), ("rose", 20), ("tom", 30)))
//Convert it to a DF
val df: DataFrame = rdd.toDF("name", "age")
//Create a temporary view
df.createOrReplaceTempView("user")
//Run SQL that uses the aggregate function directly
spark.sql("select ageAggr(age) from user").show()
class AgeAggregator extends Aggregator[Int, (Long, Long), Double] {
//Initial value of the buffer
override def zero: (Long, Long) = (0, 0)
//Fold one input age into the buffer
override def reduce(buff: (Long, Long), age: Int): (Long, Long) = {
(buff._1 + age, buff._2 + 1)
}
//Merge buffers from different partitions
override def merge(buff1: (Long, Long), buff2: (Long, Long)): (Long, Long) = {
(buff1._1 + buff2._1, buff1._2 + buff2._2)
}
//Compute the final result
override def finish(buff: (Long, Long)): Double = {
buff._1.toDouble / buff._2
}
//Encoder for the buffer: Encoders.product for custom case classes, Encoders.scalaXxx for built-in types
override def bufferEncoder: Encoder[(Long, Long)] = Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
//Encoder for the output: Encoders.product for custom case classes, Encoders.scalaXxx for built-in types
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
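Besides registering it for SQL, a strongly typed Aggregator can also be applied directly to a Dataset through toColumn; a minimal sketch assuming the same df and spark.implicits._ import as above:

// Cast the single age column to a Dataset[Int], matching the Aggregator's input type
val ageDS = df.select($"age").as[Int]
// Apply the Aggregator as a typed column and name the result
ageDS.select(new AgeAggregator().toColumn.name("avgAge")).show()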
Spark SQL provides a generic way of saving and loading data. The default format for both is parquet
scala> spark.read.load( "parquet path" )   and   df.write.save( "parquet path" )
Different data source formats can be selected through parameters, both when reading and when saving
scala> spark.read.format( "json" ) [ .option( "key","value" ) ] .load( "filepath" )
➢ format( "..." ) specifies the data source format, including "csv", "jdbc", "json", "orc", "parquet" and "text"
➢ option( "key","value" ) when the format is jdbc, the JDBC parameters are passed through multiple option calls
➢ There are also shortcut methods for specific file formats that skip the format call
spark.read.json("json file path")   spark.read.csv("csv file path")
scala> spark.read.
csv jdbc load options parquet table textFile
format json option orc schema text
Save operations:
df.write.mode("SaveMode String").save("parquet file path")
In addition to the same parameters as the read side, there is a mode that controls what happens when the target already exists
SaveMode.ErrorIfExists (default) ==> "error"     throw an exception if the file already exists
SaveMode.Append ==> "append"     append if the file already exists (additional files are generated)
SaveMode.Overwrite ==> "overwrite"     overwrite if the file already exists (the old files are deleted)
SaveMode.Ignore ==> "ignore"     do nothing if the file already exists (no new files, the old ones are kept)
scala> df.write.mode("append").json("/output")
➢ Connecting from the spark-shell command line
1. Put the MySQL driver jar into Spark's jars directory
2. scala> val jdbcDF = spark.read.format("jdbc").options(Map("url" -> "jdbc:mysql://localhost:3306/mysql", "driver" -> "com.mysql.jdbc.Driver", "dbtable" -> "plugin", "user" -> "root", "password" -> "1234")).load()
➢ Connecting from Scala code
1. Add the MySQL driver dependency
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>5.1.27</version>
</dependency>
2. Read data from MySQL (note: com.mysql.cj.jdbc.Driver belongs to Connector/J 8.x; with the 5.1.x dependency above use com.mysql.jdbc.Driver)
spark.read.format("jdbc").options(Map(
"url" -> "jdbc:mysql://192.168.189.90:3306/mysql",
"driver" -> "com.mysql.cj.jdbc.Driver",
"user" -> "root",
"password" -> "1234",
"dbtable" -> "help_topic"
)).load().show()
spark.read.format("jdbc")
.option("url", "jdbc:mysql://192.168.189.90:3306/mysql")
.option("driver", "com.mysql.cj.jdbc.Driver")
.option("user", "root")
.option("password", "1234")
.option("dbtable", "plugin")
.load().show
val props: Properties = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "1234")
spark.read.jdbc("jdbc:mysql://node0:3306/mysql",
"plugin", props).show()
3. Write data to MySQL
val df: DataFrame = spark.sparkContext.makeRDD(List(1,2,3)).toDF("id")
df.write.format("jdbc").options(Map(
"url" -> "jdbc:mysql://192.168.189.90:3306/test",
"driver" -> "com.mysql.cj.jdbc.Driver",
"user" -> "root",
"password" -> "1234",
"dbtable" -> "ids"
)).mode(SaveMode.Append) //the default mode throws an error if the table already exists; append adds the rows to the existing table instead
.save()
Nothing needs to be configured to use the embedded Hive: Derby is used as the metastore by default and the local file system as the warehouse
Run two commands:
//Running a "show tables" query automatically creates the metastore_db metadata database under $SPARK_HOME
scala> spark.sql("show tables").show
//Creating a table or inserting data automatically creates $SPARK_HOME/spark-warehouse and stores the data there
scala> spark.sql("create table test(id int)")
spark.sql("insert into test values(1),(2)")
scala> spark.sql("show tables").show
+--------+--------------------+-----------+
|database| tableName|isTemporary|
+--------+--------------------+-----------+
| default| aa| false|
| default| live_events| false|
| default| login_events| false|
| default|order_amount_by_p...| false|
| default| order_detail| false|
| default| page_view_events| false|
| default| payment_detail| false|
| default| product_info| false|
| default| promotion_info| false|
| default| province_info| false|
+--------+--------------------+-----------+
Once Spark is able to connect to an external Hive, the spark-sql shell can be used to work with Hive directly
[zhyp@node0 spark-local]$ bin/spark-sql
spark-sql> show tables;
default aa false
default live_events false
default login_events false
default order_amount_by_province false
default order_detail false
default page_view_events false
default payment_detail false
default product_info false
default promotion_info false
default province_info false
Time taken: 2.433 seconds, Fetched 10 row(s)
[zhyp@node0 spark-local]$ sbin/start-thriftserver.sh
[zhyp@node0 spark-local]$ bin/beeline -u jdbc:hive2://linux1:10000 -n zhyp
or
[zhyp@node0 spark-local]$ bin/beeline
beeline> !connect jdbc:hive2://node0:10000
def main(args: Array[String]): Unit = {
//Create the SparkSession
System.setProperty("HADOOP_USER_NAME", "zhyp")
val spark: SparkSession = SparkSession
.builder()
.enableHiveSupport() //enable Hive support (off by default); when enabled, the hive-site.xml on the classpath is read
.config("spark.sql.warehouse.dir", "hdfs://node0:8020/user/hive/warehouse") //通过spark创建数据库需要写这个地址,因为新建的数据库默认是在本地路径中找/user/hive/warehouse
.master("local[*]")
.appName("sql")
.getOrCreate()
import spark.implicits._
spark.sql("create table abc(id int)").show()
spark.sql("insert into abc values(1),(111)").show()
spark.sql("show tables").show()
spark.sql("select * from abc").show()
spark.close()
}
package com.zhyp.spark.sql.start
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Dataset, SparkSession}
object Spark_SQL_Test01{
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "zhyp")
//1. Create the SparkSession
val spark: SparkSession = SparkSession.builder().enableHiveSupport()
.config("spark.sql.warehouse.dir","hdfs://node0:8020/user/hive/warehouse")
.config(new SparkConf().setMaster("local[*]")
.setAppName("start01")).getOrCreate()
import spark.implicits._
spark.sql("create database test")
spark.sql("use test")
spark.sql(
"""
|CREATE TABLE `user_visit_action`(
| `date` string,
| `user_id` bigint,
| `session_id` string,
| `page_id` bigint,
| `action_time` string,
| `search_keyword` string,
| `click_category_id` bigint,
| `click_product_id` bigint,
| `order_category_ids` string,
| `order_product_ids` string,
| `pay_category_ids` string,
| `pay_product_ids` string,
| `city_id` bigint)
|row format delimited fields terminated by '\t'
""".stripMargin)
spark.sql(
"""
|load data local inpath 'datas/user_visit_action.txt' into table
|user_visit_action
""".stripMargin)
spark.sql(
"""
|CREATE TABLE `product_info`(
| `product_id` bigint,
| `product_name` string,
| `extend_info` string)
|row format delimited fields terminated by '\t'
""".stripMargin)
spark.sql(
"""
|load data local inpath 'datas/product_info.txt' into table product_info
""".stripMargin)
spark.sql(
"""
|CREATE TABLE `city_info`(
| `city_id` bigint,
| `city_name` string,
| `area` string)
|row format delimited fields terminated by '\t'
""".stripMargin)
spark.sql(
"""
|load data local inpath 'datas/city_info.txt' into table city_info
""".stripMargin)
spark.close()
}
}
package com.zhyp.spark.sql.start
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
object Spark_SQL_Test03 {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "zhyp")
//1. Create the SparkSession
val spark: SparkSession = SparkSession.builder().enableHiveSupport()
.config("spark.sql.warehouse.dir", "hdfs://node0:8020/user/hive/warehouse")
.config(new SparkConf().setMaster("local[*]")
.setAppName("start01")).getOrCreate()
//2. Run the queries
spark.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF))
spark.sql("use test")
//1. Join the three tables and filter out non-click records
spark.sql(
"""
|select
|c.area,c.city_name,p.product_name
|from
|user_visit_action a
|join
|city_info c
|on a.city_id = c.city_id
|join
|product_info p
|on a.click_product_id = p.product_id
|where a.click_product_id > -1
""".stripMargin).createOrReplaceTempView("t1")
//2. Group by area and product name
spark.sql(
"""
|select
|area,product_name,
|cityRemark(city_name) city_remark,
|count(*) clickCnt
|from t1
|group by area,product_name
""".stripMargin).createOrReplaceTempView("t2")
//3. Rank the products by click count within each area
spark.sql(
"""
|select
|*,
|rank() over(partition by area order by clickCnt desc) rank
|from t2
""".stripMargin).createOrReplaceTempView("t3")
//4. Take the top three in each area
spark.sql(
"""
|select
|*
|from t3
|where rank <= 3
""".stripMargin).show()
spark.close()
}
case class CityAndCntBuff(map:mutable.Map[String,Long])
class CityRemarkUDAF extends Aggregator[String, CityAndCntBuff, String] {
override def zero: CityAndCntBuff = CityAndCntBuff(mutable.Map())
override def reduce(buff: CityAndCntBuff, city: String): CityAndCntBuff = {
val map: mutable.Map[String, Long] = buff.map
map.update(city, map.getOrElse(city, 0L) + 1L)
buff
}
override def merge(buff1: CityAndCntBuff, buff2: CityAndCntBuff): CityAndCntBuff = {
val map1 = buff1.map
val map2 = buff2.map
map2.foreach({
case (city, cnt) => {
map1.update(city, map1.getOrElse(city, 0L) + cnt)
}
})
buff1
}
override def finish(resultBuff: CityAndCntBuff): String = {
val resultMap: mutable.Map[String, Long] = resultBuff.map
val totalCnt: Long = resultMap.values.reduce(_ + _)
val top2: List[(String, Long)] = resultMap.toList.sortBy(_._2)(Ordering.Long.reverse).take(2)
val cityBuffer = new ListBuffer[String]
var percentSum = 0L
top2.foreach({
case (city, cnt) => {
val percent: Long = cnt * 100 / totalCnt
cityBuffer.append(s"${city} ${percent}%")
percentSum += percent
}
})
if (resultMap.size > 2) {
cityBuffer.append(s"其他 ${1 - percentSum}%")
}
cityBuffer.mkString(", ")
}
override def bufferEncoder: Encoder[CityAndCntBuff] = Encoders.product
override def outputEncoder: Encoder[String] = Encoders.STRING
}
}