1. JDK 1.8
2. Spark 2.1
窗口函数,顾名思义,这里存在一个窗口的概念。也就是指表内数据参与到函数计算的一个区间。这里说的计算区间,我理解是有两个意思。第一是看是否需要按指定的列来对数据进行分区。第二是看分区确定后是否还指定了对分区数据的进一步的限定。包括rows区间和range区间两种限定,后面会一一举例说明。具体对应到SQL语句中,就是over语句的部分。
窗口定义的相关类有两个:
Spark的文档里关于Window类的说明很奇怪,类的函数一个都没提到。实际上写代码时需要import的就是Window类,查看了Spark源码里这个类文件,这些函数确实都是有的。这是一个普通的工具类,用来定义窗口,相关函数如下:
函数原型 | 功能说明 |
---|---|
partitionBy(colName: String, colNames: String*): WindowSpec | 按给定的列名对整个数据分区 |
partitionBy(cols: Column*): WindowSpec | 同上,只是参数类型的差异 |
orderBy(colName: String, colNames: String*): WindowSpec | 分区内的数据按指定列排序 |
orderBy(cols: Column*): WindowSpec | 同上 |
rowsBetween(start: Long, end: Long): WindowSpec | 指定区间内的进一步区间限制 |
rangeBetween(start: Long, end: Long): WindowSpec | 指定区间内的进一步区间限制 |
rowsBetween函数和rangeBetween函数的参数都是Long型,就是靠start和end这两个值来约束区间。看Spark的文档里,把这个区间叫做frame。要注意的是,窗口函数是针对每一行来处理的,所以这里的start和end都是相对于当前行这个概念。
文档中functions类里单列出的窗口函数不多,看了一下Spark v2.1.0里面提供的总共是8个,主要是排序相关的函数。实际上,到最新版的v2.4.3也一样。一般的聚合函数,像是first, last, count, sum之类的也都是可以用于窗口计算的。
函数名 | 函数功能 | 函数原型 |
cume_dist | 计算窗口范围内的累积分布 | cume_dist(): [Column] |
dense_rank | 返回窗口内的排名。和下面的rank的区别是不会跳过排名的序号。比如有两个并列第一,那么第三个的rank是2,而rank的话,是3,跳过了2 | dense_rand(): [Column] |
rank | 返回窗口内的排名,主要是注意与dense_rank的区别 | rank(): [Column] |
row_number | 返回窗口内从1开始的序号 | row_number(): [Column] |
percent_rank | 返回窗口内的相对排名 | percent_rank(): [Column] |
lag | 返回当前行之前某几行的值,不同的原型有些细节不一样,比如null值怎么处理之类的。 | lag(e: [Column], offset: Int, defaultValue: Any): [Column] |
lag(columnName: String, offset: Int, defaultValue: Any): [Column] | ||
lag(columnName: String, offset: Int): [Column] | ||
lag(e: [Column], offset: Int): [Column] | ||
lead | 返回当前行之后的某几行的值,不同的原型有些细节不一样,比如null值怎么处理之类的。跟lag的处理一样,只是方向是反的 | lead(e: [Column], offset: Int, defaultValue: Any): [Column] |
lead(columnName: String, offset: Int, defaultValue: Any): [Column] | ||
lead(e: [Column], offset: Int): [Column] | ||
lead(columnName: String, offset: Int): [Column] | ||
ntile | ntile是个将窗口内的记录尽量均匀分组的函数,返回分组后记录对应的组id。 | ntile(n: Int): [Column] |
概念性的东西介绍完了,再来结合数据看一下就应该能彻底弄清楚窗口函数到底是怎么回事了。
scala> val df=Seq((0,1,100),(1,2,50),(2,3,40),
| (3,1,70),(4,2,60),(5,3,40),
| (6,1,30),(7,2,50),(8,3,100),
| (9,1,51),(10,2,72), (11,3,35),
| (12,1,45),(13,2,25)
| ).toDF("id","cate","key")
df: org.apache.spark.sql.DataFrame = [id: int, cate: int ... 1 more field]
scala> df.show
+---+----+---+
| id|cate|key|
+---+----+---+
| 0| 1|100|
| 1| 2| 50|
| 2| 3| 40|
| 3| 1| 70|
| 4| 2| 60|
| 5| 3| 40|
| 6| 1| 30|
| 7| 2| 50|
| 8| 3|100|
| 9| 1| 51|
| 10| 2| 72|
| 11| 3| 35|
| 12| 1| 45|
| 13| 2| 25|
+---+----+---+
### 导入Window工具类
scala> import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.expressions.Window
先来看下rowsBetween函数和rangeBetween函数的区别。按cate列分区,分区内根据key列排序,所有其他条件都一样,对比一下。
### 等价sql (rowsBetween样例)
select *, first(id) over w_rows first_value,
collect_list(id) over w_rows list_value
from tbl
window w_rows as (partition by cate
order by key
rows between 1 preceding and 1 following)
###
scala> val w_rows = Window.partitionBy("cate").orderBy("key").rowsBetween(-1,1)
w_rows: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@45e4debe
scala> df.withColumn("first_value",first($"id").over(w_rows)
| ).withColumn("list_value",collect_list($"id").over(w_rows)
| ).show
+---+----+---+-----------+----------+
| id|cate|key|first_value|list_value|
+---+----+---+-----------+----------+
| 6| 1| 30| 6| [6, 12]|
| 12| 1| 45| 6|[6, 12, 9]|
| 9| 1| 51| 12|[12, 9, 3]|
| 3| 1| 70| 9| [9, 3, 0]|
| 0| 1|100| 3| [3, 0]|
| 11| 3| 35| 11| [11, 2]|
| 2| 3| 40| 11|[11, 2, 5]|
| 5| 3| 40| 2| [2, 5, 8]| <-- frame是3行,所以list_value里面是(2,5,8)
| 8| 3|100| 5| [5, 8]|
| 13| 2| 25| 13| [13, 1]|
| 1| 2| 50| 13|[13, 1, 7]|
| 7| 2| 50| 1| [1, 7, 4]|
| 4| 2| 60| 7|[7, 4, 10]|
| 10| 2| 72| 4| [4, 10]|
+---+----+---+-----------+----------+
### 等价sql (rangeBetween样例)
select *, first(id) over w_range first_value,
collect_list(id) over w_range list_value
from tbl
window w_range as (partition by cate
order by key
range between 1 preceding and 1 following)
###
scala> val w_range = Window.partitionBy("cate").orderBy("key").rangeBetween(-1,1)
w_range: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@4548f47
scala> df.withColumn("first_value",first($"id").over(w_range)
| ).withColumn("list_value",collect_list($"id").over(w_range)
| ).show
+---+----+---+-----------+----------+
| id|cate|key|first_value|list_value|
+---+----+---+-----------+----------+
| 6| 1| 30| 6| [6]|
| 12| 1| 45| 12| [12]|
| 9| 1| 51| 9| [9]|
| 3| 1| 70| 3| [3]|
| 0| 1|100| 0| [0]|
| 11| 3| 35| 11| [11]|
| 2| 3| 40| 2| [2, 5]|
| 5| 3| 40| 2| [2, 5]| <-- frame是2行,所以list_value里面是(2,5)
| 8| 3|100| 8| [8]|
| 13| 2| 25| 13| [13]|
| 1| 2| 50| 1| [1, 7]|
| 7| 2| 50| 1| [1, 7]|
| 4| 2| 60| 4| [4]|
| 10| 2| 72| 10| [10]|
+---+----+---+-----------+----------+
frame只有2行,前面的key:35<40-1,后面的key:100>40+1,所以list_value里面是key=40的两行的id(2,5)
可以看到这两个操作,只有区间指定的不一样,前面调用的rowsBetween,后面调用的是rangeBetween。可以看到结果并不一样。就是因为这俩的frame不相同:
上图是对应的rowsBetween frame的情况。partitionBy(“cate”)语句导致数据对cate字段做分区,整个数据集被分为三个部分,分别对应cate=1,2,3的情况,和group语句一样。随后的orderBy(“key”)语句是在分区之后,对每个分区内部的所有行按照key字段的值排序,默认升序。最后来看rowsBetween(-1,1)的效果。这里的(-1,1)就是对frame的边界限制,上面已经提到了这个区间是相对于当前行的,也就是当前行的-1行,当前行,以及当前行的+1行。就像上图中所示,经过partitionBy和orderBy之后,当前行为id=5的这一行的frame就是红色区域的三行数据。最后的first(id)语句就是针对这个frame的数据集得到了frame里面的第一行,所以返回的是2。从collect_list函数的结果也能清楚的看到frame就是对应id为(2,5,8)的这三行。
再看看rangeBetween的示意图:
和上面的rowsBetween一样,同样的partition,同样的order,只不过frame是由rangeBetween(-1,1)来限定边界。可以看到上面的红色frame区域比之前rowsBetween少了一行。这是因为rowsBetween是根据分区排序后的物理行范围来确定区间,而rangeBetween是根据orderBy的值的范围来确定区间的。rangeBetween(-1,1)表示的是相对于当前行的排序字段key的值的范围。对上图来说,id=5这一行的排序字段key=40,所以frame限定的是当前分区中,key值范围属于[40-1,40+1]区间的所有行。所以,cate=3的这个分区中,key=35和key=100的行都不在范围内,被排除。list_value字段中也可以明确的看到整个frame的候选id就是(2,5)两行。最后,first(id)函数看到的就是红色frame部分的id=2和id=5这两行,巧了,结果也是2。
关于frame的界限,再补充一下。因为可能实际数据中不知道分区内总共有多少行,如果要表达整个区间,可以用( Window.unboundedPreceding, Window.unboundedFollowing)来确定。也就是说,Window.unboundedPreceding表示partition内的第一行,Window.unboundedFollowing表示partition内的最后一行。而当前行建议是采用Window.currentRow来表示,尽管写个0来表示当前行目前暂时也是支持的。
另外,rowsBetween和rangeBetween不是定义窗口时必须的语句。在不写的情况下,默认的frame是rowsBetween(Window.unboundedPreceding, Window.currentRow),代码验证如下:
### 等价sql
select *, first(id) over w first_value,
collect_list(id) over w list_value
from tbl
window w as (partition by cate
order by key)
###
scala> val w = Window.partitionBy("cate").orderBy("key")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@a530424
scala> df.withColumn("first_value",first($"id").over(w)
| ).withColumn("list_value",collect_list($"id").over(w)
| ).show
+---+----+---+-----------+-----------------+
| id|cate|key|first_value| list_value|
+---+----+---+-----------+-----------------+
| 6| 1| 30| 6| [6]|
| 12| 1| 45| 6| [6, 12]|
| 9| 1| 51| 6| [6, 12, 9]|
| 3| 1| 70| 6| [6, 12, 9, 3]|
| 0| 1|100| 6| [6, 12, 9, 3, 0]|
| 11| 3| 35| 11| [11]|
| 2| 3| 40| 11| [11, 2, 5]|
| 5| 3| 40| 11| [11, 2, 5]|
| 8| 3|100| 11| [11, 2, 5, 8]|
| 13| 2| 25| 13| [13]|
| 1| 2| 50| 13| [13, 1, 7]|
| 7| 2| 50| 13| [13, 1, 7]|
| 4| 2| 60| 13| [13, 1, 7, 4]|
| 10| 2| 72| 13|[13, 1, 7, 4, 10]|
+---+----+---+-----------+-----------------+
重要的概念清楚了,剩下的就简单了。看些示例熟悉一下用法就好了。
lag()和lead()函数可以把当前行的前后几行数据列填充到当前行。
## 常用于统计日增量,月增量这种需要跨行加减来对比的情况。
### 等价sql
select *, lag(key,1,0) over w as lag_row_key,
lead(key,1,0) over w as lead_row_key
from tbl
window w as (partition by cate order by key)
###
scala> val w_rows = Window.partitionBy("cate").orderBy("key")
w_rows: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@2c5748e1
scala> df.withColumn("lag_row_key",lag($"key",1,0).over(w)
| ).withColumn("lead_row_key",lead($"key",1,0).over(w)
| ).show
+---+----+---+-----------+------------+
| id|cate|key|lag_row_key|lead_row_key|
+---+----+---+-----------+------------+
| 6| 1| 30| 0| 45|
| 12| 1| 45| 30| 51|
| 9| 1| 51| 45| 70|
| 3| 1| 70| 51| 100|
| 0| 1|100| 70| 0|
| 11| 3| 35| 0| 40|
| 2| 3| 40| 35| 40|
| 5| 3| 40| 40| 100|
| 8| 3|100| 40| 0|
| 13| 2| 25| 0| 50|
| 1| 2| 50| 25| 50|
| 7| 2| 50| 50| 60|
| 4| 2| 60| 50| 72|
| 10| 2| 72| 60| 0|
+---+----+---+-----------+------------+
scala> val w = Window.partitionBy("cate").orderBy("key")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@2c5748e1
scala> df.withColumn("dense_rank",dense_rank().over(w)
| ).withColumn("rank",rank().over(w)
| ).withColumn("rows_number", row_number().over(w)
| ).show
+---+----+---+----------+----+-----------+
| id|cate|key|dense_rank|rank|rows_number|
+---+----+---+----------+----+-----------+
| 6| 1| 30| 1| 1| 1|
| 12| 1| 45| 2| 2| 2|
| 9| 1| 51| 3| 3| 3|
| 3| 1| 70| 4| 4| 4|
| 0| 1|100| 5| 5| 5|
| 11| 3| 35| 1| 1| 1|
| 2| 3| 40| 2| 2| 2|
| 5| 3| 40| 2| 2| 3| <-- rows_number不管相同key值,就是往下依次递增
| 8| 3|100| 3| 4| 4| <-- dense_rank和rank的区别为是否跳过因相同排名造成的gap
| 13| 2| 25| 1| 1| 1|
| 1| 2| 50| 2| 2| 2|
| 7| 2| 50| 2| 2| 3|
| 4| 2| 60| 3| 4| 4|
| 10| 2| 72| 4| 5| 5|
+---+----+---+----------+----+-----------+
把partition内的n行数据等分成k份,就好像n个鸡蛋,放到k个篮子里。只需要依次在每个篮子里放一个,一轮结束以后还剩下鸡蛋就再来一轮,再来,再来。。。完事。得到的结果就是前面若干个篮子里的鸡蛋数为ceil(n/k),后面篮子里的鸡蛋数为floor(n/k),各个篮子里的鸡蛋数最多差1个,十分均匀。
scala> val w = Window.partitionBy("cate").orderBy("key")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@2c5748e1
###等价sql
select *, ntile(3) over (partition by cate order by key) as ntile_3
from tbl
###
scala> df.withColumn("ntile_3",ntile(3).over(w)).show
+---+----+---+-------+
| id|cate|key|ntile_3|
+---+----+---+-------+
| 6| 1| 30| 1|
| 12| 1| 45| 1|
| 9| 1| 51| 2|
| 3| 1| 70| 2|
| 0| 1|100| 3|
| 11| 3| 35| 1|
| 2| 3| 40| 1|
| 5| 3| 40| 2|
| 8| 3|100| 3|
| 13| 2| 25| 1|
| 1| 2| 50| 1|
| 7| 2| 50| 2|
| 4| 2| 60| 2|
| 10| 2| 72| 3|
+---+----+---+-------+
scala> val w = Window.partitionBy("cate").orderBy("key")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@2c5748e1
###等价sql
select *, rank() over (partition by cate order by key) as rank,
round(percent_rank() over (partition by cate order by key),2) as percent_rank
from tbl
###
scala> df.withColumn("rank",rank().over(w)).withColumn("percent_rank", round(percent_rank().over(w),3)).show
+---+----+---+----+------------+
| id|cate|key|rank|percent_rank|
+---+----+---+----+------------+
| 6| 1| 30| 1| 0.0|
| 12| 1| 45| 2| 0.25|
| 9| 1| 51| 3| 0.5|
| 3| 1| 70| 4| 0.75|
| 0| 1|100| 5| 1.0|
| 11| 3| 35| 1| 0.0|
| 2| 3| 40| 2| 0.333| <-- (rank_in_partition - 1)/(max_rank_in_partition - 1)
| 5| 3| 40| 2| 0.333|
| 8| 3|100| 4| 1.0|
| 13| 2| 25| 1| 0.0|
| 1| 2| 50| 2| 0.25|
| 7| 2| 50| 2| 0.25|
| 4| 2| 60| 4| 0.75|
| 10| 2| 72| 5| 1.0|
+---+----+---+----+------------+
准确的说就是partition内小于等于(降序的话就是大于等于)当前行的rank值的行数除以这个partition的总行数。看到很多资料里都说的是小于等于当前值的个数除以分区内总行数,对于单个排序列的情况,这么说是没错的。但是对于partition内按多个列排序的情况要稍微复杂一点,不是单看主排序列的值,而是所有的排序值都要考虑进去。本质上看,还是多列排序后的rank值的大小分布。
### 单个列排序的情况下,等价sql
select *, rank() over (partition by cate order by key) as rank,
cume_dist() over (partition by cate order by key) as cume_dist
from tbl
###
scala> val w = Window.partitionBy("cate").orderBy("key")
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@2c5748e1
scala> df.withColumn("rank",rank().over(w)).withColumn("cume_dist", cume_dist().over(w)).show
+---+----+---+----+---------+
| id|cate|key|rank|cume_dist|
+---+----+---+----+---------+
| 6| 1| 30| 1| 0.2|
| 12| 1| 45| 2| 0.4|
| 9| 1| 51| 3| 0.6|
| 3| 1| 70| 4| 0.8|
| 0| 1|100| 5| 1.0|
| 11| 3| 35| 1| 0.25|
| 2| 3| 40| 2| 0.75|
| 5| 3| 40| 2| 0.75|
| 8| 3|100| 4| 1.0|
| 13| 2| 25| 1| 0.2|
| 1| 2| 50| 2| 0.6|
| 7| 2| 50| 2| 0.6|
| 4| 2| 60| 4| 0.8|
| 10| 2| 72| 5| 1.0|
+---+----+---+----+---------+
### 多列排序时,等价sql
select *, rank() over (partition by cate order by key) as rank,
cume_dist() over (partition by cate order by key, id desc) as cume_dist
from tbl
###
scala> val w = Window.partitionBy("cate").orderBy($"key",$"id".desc)
w: org.apache.spark.sql.expressions.WindowSpec = org.apache.spark.sql.expressions.WindowSpec@42d99e91
scala> df.withColumn("rank",rank().over(w)).withColumn("cume_dist", cume_dist().over(w)).show
+---+----+---+----+---------+
| id|cate|key|rank|cume_dist|
+---+----+---+----+---------+
| 6| 1| 30| 1| 0.2|
| 12| 1| 45| 2| 0.4|
| 9| 1| 51| 3| 0.6|
| 3| 1| 70| 4| 0.8|
| 0| 1|100| 5| 1.0|
| 11| 3| 35| 1| 0.25|
| 5| 3| 40| 2| 0.5|
| 2| 3| 40| 3| 0.75|
| 8| 3|100| 4| 1.0|
| 13| 2| 25| 1| 0.2|
| 7| 2| 50| 2| 0.4|
| 1| 2| 50| 3| 0.6|
| 4| 2| 60| 4| 0.8|
| 10| 2| 72| 5| 1.0|
+---+----+---+----+---------+
最后要说明一下的是,paritition和order by都不是必须的,看业务的需要吧。不过,数据量巨大的话,不做partition直接全表分析,所有的数据被拉到一起来处理,会爆炸的!!!
嗯,差不多就这些吧。这么写一遍,感觉自己也更清晰一些了,真好~