package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql01_Test {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//Import the implicit conversions; "spark" here is the name of the session object,
//which must be declared with val
import spark.implicits._ //best to always add this, whether or not it ends up being needed
//Business logic
val jsonDF: DataFrame = spark.read.json("input/user.json")
//SQL
//Register the DataFrame as a temporary view
jsonDF.createOrReplaceTempView("user")
spark.sql("select * from user").show()
//DSL
//If column names are referenced with the $ or single-quote syntax, the implicit conversions are required
jsonDF.select("name", "age").show
jsonDF.select($"name", $"age").show
jsonDF.select('name, 'age).show
val rdd = spark.sparkContext.makeRDD(List(
(1, "zhangsan", 30),
(2, "lisi", 20),
(3, "wangwu", 40),
))
//RDD<=>DataFrame
val df: DataFrame = rdd.toDF("id", "name", "age")
val dfToRDD1: RDD[Row] = df.rdd
dfToRDD1.foreach(
row=>{
println(row(0))
})
//RDD<=>DataSet
val userRDD: RDD[User] = rdd.map {
case (id, name, age) => {
User(id, name, age)
}
}
val userDS: Dataset[User] = userRDD.toDS()
val dsToRDD: RDD[User] = userDS.rdd
//DataFrame <=> DataSet
val dsToDS: Dataset[User] = df.as[User]
//type DataFrame = Dataset[Row]: a DataFrame is just a Dataset of Row
val dsToDF: DataFrame = dsToDS.toDF()
rdd.foreach(println)
df.show()
userDS.show()
//Release resources
spark.stop()
}
case class User(id: Int, name: String, age: Int)
}
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql02_Test {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//Import the implicit conversions; "spark" here is the name of the session object,
//which must be declared with val
import spark.implicits._ //best to always add this, whether or not it ends up being needed
//Business logic
val rdd = spark.sparkContext.makeRDD(List(
(1, "zhangsan", 30),
(2, "lisi", 20),
(3, "wangwu", 40),
))
//RDD<=>DataSet
// val userRDD: RDD[User] = rdd.map {
// case (id, name, age) => {
// User(id, name, age)
// }
// }
//val userDS: Dataset[User] = userRDD.toDS()
//The objects SparkSQL wraps (DataFrame/Dataset) provide a large number of processing methods,
//similar to RDD operators
//userDS.join()
//error: mapping a DataFrame to Row values fails, because there is no implicit Encoder for Row
//val df: DataFrame = rdd.toDF("id", "name", "age")
// val ds: Dataset[Row] = df.map(row => {
//   val id: Any = row(0)
//   val name: Any = row(1)
//   val age: Any = row(2)
//   Row(id, "name" + name, age)
// })
val userRDD: RDD[User] = rdd.map {
case (id, name, age) => {
User(id, name, age)
}
}
val userDS: Dataset[User] = userRDD.toDS()
val newDS: Dataset[User] = userDS.map(user => {
User(user.id, "name:" + user.name, user.age)
})
newDS.show()
//Use user-defined functions to transform the data inside SQL
val df = rdd.toDF("id", "name", "age")
df.createOrReplaceTempView("user")
spark.udf.register("addName", (x: String) => "Name:" + x)
spark.udf.register("changeAge", (x: Int) => 18)
spark.sql("select addName(name),changeAge(age) from user").show
spark.stop()
}
case class User(id: Int, name: String, age: Int)
}
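Besides registering a UDF for SQL with spark.udf.register, the same functions can be wrapped with org.apache.spark.sql.functions.udf and applied directly in DSL queries. A minimal sketch (not part of the original code; it reuses the df and spark values from the example above):

import org.apache.spark.sql.functions.udf

// Wrap plain Scala functions as column expressions.
val addName = udf((x: String) => "Name:" + x)
val changeAge = udf((_: Int) => 18)

// Apply them in a DSL query instead of a SQL string.
df.select(addName($"name").as("name"), changeAge($"age").as("age")).show()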
(User-defined aggregate function, UDAF)
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql03_UDAF {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._ //best to always add this, whether or not it ends up being needed
val rdd = spark.sparkContext.makeRDD(List(
(1, "zhangsan", 30L),
(2, "lisi", 20L),
(3, "wangwu", 40L),
))
val df = rdd.toDF("id", "name", "age")
df.createOrReplaceTempView("user")
//Create the UDAF instance
val udaf = new MyAvgAgeUDAF
//Register it with SparkSQL
spark.udf.register("avgAge",udaf)
//Use the user-defined aggregate function in SQL
spark.sql("select avgAge(age) from user").show
spark.stop()
}
//Custom aggregate function (weakly typed)
//1. Extend UserDefinedAggregateFunction
//2. Override its methods
//buffer: totalage, count
class MyAvgAgeUDAF extends UserDefinedAggregateFunction {
//Schema of the input data: the age column
override def inputSchema: StructType = {
StructType(Array(StructField("age", LongType)))
}
//Schema of the aggregation buffer: total age, number of people
override def bufferSchema: StructType = {
StructType(Array(
StructField("totalage", LongType),
StructField("count", LongType)
))
}
//Result type returned by the aggregate function
override def dataType: DataType = LongType
//Whether the function is deterministic (the same input always gives the same output)
override def deterministic: Boolean = true
//Initialize the aggregation buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 0L
}
//Update the buffer with an input row
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getLong(0) + input.getLong(0)
buffer(1) = buffer.getLong(1) + 1
}
//Merge two buffers
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
//Compute the final result from the buffer
override def evaluate(buffer: Row): Any = {
buffer.getLong(0) / buffer.getLong(1)
}
}
}
Custom aggregate function - strongly typed
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql04_UDAF_Class {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._ //best to always add this, whether or not it ends up being needed
val rdd = spark.sparkContext.makeRDD(List(
(1, "zhangsan", 30L),
(2, "lisi", 20L),
(3, "wangwu", 40L),
))
val df = rdd.toDF("id", "name", "age")
val ds: Dataset[User] = df.as[User]
//Create the UDAF instance
val udaf = new MyAvgAgeUDAFClass
//Using the aggregate function in SQL:
//because this aggregate function is strongly typed and SQL has no notion of those types, it cannot be used there (in this Spark version)
//Instead, access it through the DSL syntax
//by turning the aggregate function into a query column that the Dataset can select
ds.select(udaf.toColumn).show
spark.stop()
}
case class User(id: Int, name: String, age: Long)
case class AvgBuffer(var totalage: Long, var count: Long)
//Custom aggregate function - strongly typed
//1. Extend Aggregator and define its type parameters
//   IN:  input data type, User
//   BUF: buffer data type, AvgBuffer
//   OUT: output data type, Long
//2. Override its methods
class MyAvgAgeUDAFClass extends Aggregator[User, AvgBuffer, Long] {
//Initial value of the buffer
override def zero: AvgBuffer = {
AvgBuffer(0L, 0L)
}
//Aggregate one input value into the buffer
override def reduce(buffer: AvgBuffer, user: User): AvgBuffer = {
buffer.totalage = buffer.totalage + user.age
buffer.count = buffer.count + 1
buffer
}
//Merge two buffers
override def merge(buffer1: AvgBuffer, buffer2: AvgBuffer): AvgBuffer = {
buffer1.totalage = buffer1.totalage + buffer2.totalage
buffer1.count = buffer1.count + buffer2.count
buffer1
}
//Compute the final result
override def finish(reduction: AvgBuffer): Long = {
reduction.totalage / reduction.count
}
//Encoders used for serialization; boilerplate
override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
}
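A side note, assuming an upgrade to Spark 3.0 or later (the dependency listed further below pins 2.4.5): there the untyped UserDefinedAggregateFunction API is deprecated, and a typed Aggregator over the raw column value can be registered for SQL through functions.udaf. A sketch of that variant, not part of the original code:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

// Average age as a typed Aggregator over the bare Long column (IN = Long).
object AvgAgeForSql extends Aggregator[Long, (Long, Long), Long] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), age: Long): (Long, Long) = (b._1 + age, b._2 + 1)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Long, Long)): Long = b._1 / b._2
  def bufferEncoder: Encoder[(Long, Long)] = Encoders.tuple(Encoders.scalaLong, Encoders.scalaLong)
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Spark 3.0+ only: bridge the typed Aggregator into a SQL-usable function.
// spark.udf.register("avgAge", functions.udaf(AvgAgeForSql))
// spark.sql("select avgAge(age) from user").show()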
user.json
The file is not one valid standard JSON document; it follows Spark's convention instead: every line is a self-contained JSON object, with no "," between records
{"name": "zhangsan","age": "20"}
{"name": "lisi","age": "30"}
{"name": "wangwu","age": "40"}
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql05_LoadSave {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//SparkSQL's generic load and save API
//Generic load
//RuntimeException: file:xxxx/input/user.json is not a Parquet file.
//The default format of the generic loader is Parquet (a columnar storage format)
//val frame: DataFrame = spark.read.load("input/user.json")
//To read a different file format, the format must be specified explicitly
//When the format is JSON, Spark has its own requirements on the file layout
//JSON => JavaScript Object Notation
//Standard JSON requires the whole file to satisfy JSON syntax
//Spark, however, reads files line by line by default,
//so when reading JSON it requires every single line to be a valid JSON value
//If the file does not follow this layout, no error is raised, but the parsed result is wrong
val frame: DataFrame = spark.read.format("json").load("input/user.json") //generic API
//spark.read.json() //format-specific shortcut
frame.show()
spark.stop()
}
}
An even simpler way to read: query the file directly in SQL
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql07_LoadSave {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
spark.sql("select * from json.`input/user.json`").show()
spark.stop()
}
}
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql06_LoadSave {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
//SparkSQL's generic load and save API
//Generic save
val df = spark.read.format("json").load("input/user.json")
//The default format of the generic saver is Parquet
//To save in a specific format such as JSON, call format() accordingly
//If the output path already exists, the save operation fails by default
df.write.format("json").save("output1")
//To save into a path that already exists, specify a save mode
//df.write.mode("overwrite").format("json").save("output")
df.write.mode("append").format("json").save("output")
spark.stop()
}
}
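For reference, the four save modes behave as follows when the target path already exists. This sketch reuses the df from the code above; the string forms "error", "append", "overwrite" and "ignore" are equivalent to the enum values:

import org.apache.spark.sql.SaveMode

df.write.mode(SaveMode.ErrorIfExists).format("json").save("output") // default: fail if the path exists
df.write.mode(SaveMode.Append).format("json").save("output")        // add new files next to the old ones
df.write.mode(SaveMode.Overwrite).format("json").save("output")     // replace the existing data
df.write.mode(SaveMode.Ignore).format("json").save("output")        // do nothing if the path exists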
user.csv — fields are separated by ";" rather than "," (columns otherwise come back in dictionary order by default)
The first line holds the column names; the value types are inferred via inferSchema
name;age
zhangsan;30
wangwu;40
lisi;20
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql08_Load_CSV {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val frame: DataFrame = spark.read.format("csv")
.option("sep", ";")
.option("inferSchema", "true")
.option("header", "true")
.load("input/user.csv")
frame.show()
spark.stop()
}
}
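Writing CSV works symmetrically through the same option keys; a minimal sketch reusing the frame from above (output/csv is a hypothetical path):

frame.write
  .option("sep", ";")       // use the same ';' separator on output
  .option("header", "true") // write the column names as the first line
  .format("csv")
  .save("output/csv")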
Add the dependency
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>5.1.27</version>
</dependency>
The generic approach (read)
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql09_Load_MySQL {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
spark.read.format("jdbc")
.option("url", "jdbc:mysql://hadoop130:3306/spark-sql")
.option("driver", "com.mysql.jdbc.Driver")
.option("user", "root")
.option("password", "123456")
.option("dbtable", "user")
.load().show
spark.stop()
}
}
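The same read can also be written with the jdbc shorthand on DataFrameReader, passing the connection settings as java.util.Properties; a sketch reusing the values above:

import java.util.Properties

val props = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "123456")
props.setProperty("driver", "com.mysql.jdbc.Driver")

spark.read
  .jdbc("jdbc:mysql://hadoop130:3306/spark-sql", "user", props)
  .show()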
The generic approach (write)
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql10_Save_MySQL {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val frame: DataFrame = spark.read.format("jdbc")
.option("url", "jdbc:mysql://hadoop130:3306/spark-sql")
.option("driver", "com.mysql.jdbc.Driver")
.option("user", "root")
.option("password", "123456")
.option("dbtable", "user")
.load()
frame.write.format("jdbc")
.option("url", "jdbc:mysql://hadoop130:3306/spark-sql")
.option("driver", "com.mysql.jdbc.Driver")
.option("user", "root")
.option("password", "123456")
.option("dbtable", "user1")
.mode(SaveMode.Append) //choose a save mode to append to the existing table
.save()
spark.stop()
}
}
Add the dependencies
<dependency>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-hive_2.12</artifactId>
    <version>2.4.5</version>
</dependency>
<dependency>
    <groupId>org.apache.hive</groupId>
    <artifactId>hive-exec</artifactId>
    <version>3.1.2</version>
</dependency>
Spark's embedded Hive
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql11_Load_Hive {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
//SparkSQL supports a local (embedded) Hive by default, but Hive support must be enabled before running
//by calling enableHiveSupport
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
//Plain SQL can then be used to access the content in Hive
spark.sql("create table aa(id int)")
spark.sql("show tables").show()
spark.sql("load data local inpath'input/id.txt' into table aa")
spark.sql("select * from aa").show
spark.stop()
}
}
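When using the embedded Hive, table data lands in a local spark-warehouse directory (plus a metastore_db directory for the Derby metastore). The location can be changed with the spark.sql.warehouse.dir setting; a sketch of the builder with that option (the path shown is just an example):

val spark = SparkSession.builder()
  .enableHiveSupport()
  .config("spark.sql.warehouse.dir", "file:///opt/spark-warehouse") // example path, adjust as needed
  .config(sparkConf)
  .getOrCreate()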
Add the resource file hive-site.xml
Note: remove the Tez settings from it, as they conflict with Spark
<configuration>
    <property>
        <name>javax.jdo.option.ConnectionURL</name>
        <value>jdbc:mysql://hadoop130:3306/metastore?useSSL=false</value>
    </property>
    <property>
        <name>javax.jdo.option.ConnectionDriverName</name>
        <value>com.mysql.jdbc.Driver</value>
    </property>
    <property>
        <name>javax.jdo.option.ConnectionUserName</name>
        <value>root</value>
    </property>
    <property>
        <name>javax.jdo.option.ConnectionPassword</name>
        <value>123456</value>
    </property>
    <property>
        <name>hive.metastore.warehouse.dir</name>
        <value>/user/hive/warehouse</value>
    </property>
    <property>
        <name>hive.metastore.schema.verification</name>
        <value>false</value>
    </property>
    <property>
        <name>hive.metastore.uris</name>
        <value>thrift://hadoop130:9083</value>
    </property>
    <property>
        <name>hive.server2.thrift.port</name>
        <value>10000</value>
    </property>
    <property>
        <name>hive.server2.thrift.bind.host</name>
        <value>hadoop130</value>
    </property>
    <property>
        <name>hive.metastore.event.db.notification.api.auth</name>
        <value>false</value>
    </property>
    <property>
        <name>hive.cli.print.header</name>
        <value>true</value>
        <description>Whether to print the names of the columns in query output.</description>
    </property>
    <property>
        <name>hive.cli.print.current.db</name>
        <value>true</value>
        <description>Whether to include the current database in the Hive prompt.</description>
    </property>
</configuration>
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql12_Load_Hive {
def main(args: Array[String]): Unit = {
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Build the SparkSession with the builder
//Access the external Hive
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
//Plain SQL can then be used to access the content in Hive
spark.sql("show databases").show()
spark.stop()
}
}
| Area | Product Name | Click Count | City Remark |
| --- | --- | --- | --- |
| North China | Product A | 100000 | Beijing 21.2%, Tianjin 13.2%, Other 65.6% |
| North China | Product P | 80200 | Beijing 63.0%, Taiyuan 10%, Other 27.0% |
| North China | Product M | 40000 | Beijing 63.0%, Taiyuan 10%, Other 27.0% |
| Northeast | Product J | 92000 | Dalian 28%, Liaoning 17.0%, Other 55.0% |
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql13_Req_Mock {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "vanas")
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Access the external Hive
val spark = SparkSession.builder()
.enableHiveSupport()
.config(sparkConf).getOrCreate()
spark.sql("use bigdata0213")
spark.sql(
"""
|CREATE TABLE `user_visit_action`(
| `date` string,
| `user_id` bigint,
| `session_id` string,
| `page_id` bigint,
| `action_time` string,
| `search_keyword` string,
| `click_category_id` bigint,
| `click_product_id` bigint,
| `order_category_ids` string,
| `order_product_ids` string,
| `pay_category_ids` string,
| `pay_product_ids` string,
| `city_id` bigint)
|row format delimited fields terminated by '\t'
|""".stripMargin)
spark.sql(
"""
|load data local inpath 'input1/user_visit_action.txt' into table bigdata0213.user_visit_action
|""".stripMargin)
spark.sql(
"""
|CREATE TABLE `product_info`(
| `product_id` bigint,
| `product_name` string,
| `extend_info` string)
|row format delimited fields terminated by '\t'
|""".stripMargin).show
spark.sql(
"""
|load data local inpath 'input1/product_info.txt' into table bigdata0213.product_info
|""".stripMargin)
spark.sql(
"""
|CREATE TABLE `city_info`(
| `city_id` bigint,
| `city_name` string,
| `area` string)
|row format delimited fields terminated by '\t'
|""".stripMargin)
spark.sql(
"""
|load data local inpath 'input1/city_info.txt' into table bigdata0213.city_info
|""".stripMargin)
spark.sql(
"""
|select * from city_info
|""".stripMargin).show(10)
spark.stop()
}
}
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql14_Req {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "vanas")
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Access the external Hive
val spark = SparkSession.builder()
.enableHiveSupport()
.config(sparkConf).getOrCreate()
spark.sql("use bigdata0213")
// spark.sql(
// """
// |select
// | *
// |from (
// | select
// | *,
// | rank() over( partition by area order by clickCount desc ) as rank
// | from (
// | select
// | area,
// | product_name,
// | count(*) as clickCount
// | from (
// | select
// | a.*,
// | c.area,
// | p.product_name
// | from user_visit_action a
// | join city_info c on c.city_id = a.city_id
// | join product_info p on p.product_id = a.click_product_id
// | where a.click_product_id > -1
// | ) t1 group by area, product_name
// | ) t2
// |) t3
// |where rank <= 3
// """.stripMargin).show
spark.sql(
"""
|select *
|from(
|select *,
|rank() over(distribute by area order by sum_click desc) rank
|from(
|select area ,product_name,count(click_product_id) sum_click
|from user_visit_action a
|join city_info c on a.city_id = c.city_id
|join product_info p on p.product_id = a.click_product_id
|where click_product_id > -1
|group by area ,product_name
|)t1
|)t2
|where rank <=3
|""".stripMargin).show()
spark.stop()
}
}
Here "hot product" is measured by click count. Compute the top three hot products in each area, and annotate each product with the percentage of its clicks coming from the main cities; beyond the top two cities, the remainder is shown as "Other".
package com.vanas.bigdata.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
/**
* @author Vanas
* @create 2020-06-10 4:32 PM
*/
object SparkSql15_Req {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "vanas")
//Create the environment configuration object
val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
//Access the external Hive
val spark = SparkSession.builder()
.enableHiveSupport()
.config(sparkConf).getOrCreate()
spark.sql("use bigdata0213")
//Create the custom aggregate function
val udaf = new CityRemarkUDAF
//Register the aggregate function
spark.udf.register("cityRemark", udaf)
//Fetch the qualifying data from the Hive tables,
//then group it by area and product name and count the product clicks
spark.sql(
"""
|select area, product_name, count(click_product_id) sum_click, cityRemark(city_name)
|from user_visit_action a
|join city_info c on a.city_id = c.city_id
|join product_info p on p.product_id = a.click_product_id
|where click_product_id > -1
|group by area ,product_name
|""".stripMargin).createOrReplaceTempView("t1")
//Sort the aggregated click counts within each area (descending)
spark.sql(
"""
|select *,
|rank() over(distribute by area order by sum_click desc) rank
|from t1
|""".stripMargin).createOrReplaceTempView("t2")
//Take the top three of each group after sorting
spark.sql(
"""
|select *
|from t2
|where rank <=3
|""".stripMargin).show()
spark.stop()
}
//Example input stream: Beijing, Shanghai, Beijing, Shenzhen, ...
//in:  cityName: String
//out: remark: String
//buffer: two fields, (total, map)
//  (total clicks for the product, clicks per city)
//  (total clicks for the product, Map(city -> click sum))
//remark = city click sum / total product clicks, as a percentage
//Custom aggregate function producing the city remark
class CityRemarkUDAF extends UserDefinedAggregateFunction {
//The input is simply the city name
override def inputSchema: StructType = {
StructType(Array(StructField("cityName", StringType)))
}
//The buffer holds: totalcnt, Map[cityName, cnt]
override def bufferSchema: StructType = {
StructType(Array(
StructField("totalCount", LongType),
StructField("cityMap", MapType(StringType, LongType))
))
}
//Returns the city remark string
override def dataType: DataType = StringType
override def deterministic: Boolean = true
//Initialize the buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
//buffer.update(0,0L)
buffer(1) = Map[String, Long]()
}
//Update the buffer with one input row
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val cityName: String = input.getString(0)
//increase the total click count
buffer(0) = buffer.getLong(0) + 1
//increase the click count of this city
val cityMap: Map[String, Long] = buffer.getAs[Map[String, Long]](1)
val newClickCount = cityMap.getOrElse(cityName, 0L) + 1
buffer(1) = cityMap.updated(cityName, newClickCount)
}
//Merge two buffers
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//merge the total click counts
buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
//merge the per-city click maps
val map1 = buffer1.getAs[Map[String, Long]](1)
val map2 = buffer2.getAs[Map[String, Long]](1)
buffer1(1) = map1.foldLeft(map2) {
case (map, (k, v)) => {
map.updated(k, map.getOrElse(k, 0L) + v)
}
}
}
//Compute the remark string from the buffer
override def evaluate(buffer: Row): Any = {
val totalcnt: Long = buffer.getLong(0)
val citymap: collection.Map[String, Long] = buffer.getMap[String, Long](1)
val cityToCountList: List[(String, Long)] = citymap.toList.sortWith(
(left, right) => left._2 > right._2
).take(2)
//val hasRest = citymap.size > 2
var rest = 0L
val s = new StringBuilder
cityToCountList.foreach {
case (city, cnt) => {
val r = (cnt * 100 / totalcnt)
s.append(city + " " + r + "%,")
rest = rest + r
}
}
s.toString() + "其他" + (100 - rest) + "%"
// if (hasRest) {
// s.toString() + "其他" + (100 - rest) + "%"
// } else {
// toString
// }
}
}
}
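The map merge inside merge() above can be tricky to read; for reference, the same foldLeft pattern on plain Scala maps (a standalone sketch with made-up values):

val map1 = Map("Beijing" -> 3L, "Tianjin" -> 1L)
val map2 = Map("Beijing" -> 2L, "Taiyuan" -> 4L)

// Fold map1 into map2, adding the counts for keys present in both maps.
val merged = map1.foldLeft(map2) {
  case (acc, (city, cnt)) => acc.updated(city, acc.getOrElse(city, 0L) + cnt)
}
// merged == Map("Beijing" -> 5L, "Taiyuan" -> 4L, "Tianjin" -> 1L)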