[Python Notes] Spark SQL Window Functions

1 Usage in spark.sql

1.1 Sample data

from datetime import datetime
from pyspark.sql.types import *


# `spark` is assumed to be an existing SparkSession
schema = StructType() \
    .add('name', StringType(), True) \
    .add('create_time', TimestampType(), True) \
    .add('department', StringType(), True) \
    .add('salary', IntegerType(), True)
df = spark.createDataFrame([
    ("Tom", datetime.strptime("2020-01-01 00:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4500),
    ("Georgi", datetime.strptime("2020-01-02 12:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4200),
    ("Kyoichi", datetime.strptime("2020-02-02 12:10:00", "%Y-%m-%d %H:%M:%S"), "Sales", 3000),    
    ("Berni", datetime.strptime("2020-01-10 11:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4700),
    ("Berni", datetime.strptime("2020-01-07 11:01:00", "%Y-%m-%d %H:%M:%S"), "Sales", None),    
    ("Guoxiang", datetime.strptime("2020-01-08 12:11:00", "%Y-%m-%d %H:%M:%S"), "Sales", 4200),   
    ("Parto", datetime.strptime("2020-02-20 12:01:00", "%Y-%m-%d %H:%M:%S"), "Finance", 2700),
    ("Anneke", datetime.strptime("2020-01-02 08:20:00", "%Y-%m-%d %H:%M:%S"), "Finance", 3300),
    ("Sumant", datetime.strptime("2020-01-30 12:01:05", "%Y-%m-%d %H:%M:%S"), "Finance", 3900),
    ("Jeff", datetime.strptime("2020-01-02 12:01:00", "%Y-%m-%d %H:%M:%S"), "Marketing", 3100),
    ("Patricio", datetime.strptime("2020-01-05 12:18:00", "%Y-%m-%d %H:%M:%S"), "Marketing", 2500)
], schema=schema)
df.createOrReplaceTempView('salary')
df.show()
+--------+-------------------+----------+------+
|    name|        create_time|department|salary|
+--------+-------------------+----------+------+
|     Tom|2020-01-01 00:01:00|     Sales|  4500|
|  Georgi|2020-01-02 12:01:00|     Sales|  4200|
| Kyoichi|2020-02-02 12:10:00|     Sales|  3000|
|   Berni|2020-01-10 11:01:00|     Sales|  4700|
|   Berni|2020-01-07 11:01:00|     Sales|  null|
|Guoxiang|2020-01-08 12:11:00|     Sales|  4200|
|   Parto|2020-02-20 12:01:00|   Finance|  2700|
|  Anneke|2020-01-02 08:20:00|   Finance|  3300|
|  Sumant|2020-01-30 12:01:05|   Finance|  3900|
|    Jeff|2020-01-02 12:01:00| Marketing|  3100|
|Patricio|2020-01-05 12:18:00| Marketing|  2500|
+--------+-------------------+----------+------+

1.2 Window functions

ranking functions

SQL            DataFrame API   Description
row_number     rowNumber       Unique sequential number within the partition, from 1 to n
rank           rank            Ranking; tied values get the same rank, and the ranks after a tie are skipped
dense_rank     denseRank       Ranking; tied values get the same rank, with no gaps after ties
percent_rank   percentRank     Computed as (rank within partition - 1) / (rows in partition - 1); 0 if the partition has only one row
ntile          ntile           Sorts the partition and splits it into n buckets; returns the current row's bucket number (starting at 1)
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,rank() over(partition by department order by salary) as rank
    ,dense_rank() over(partition by department order by salary) as dense_rank
    ,percent_rank() over(partition by department order by salary) as percent_rank
    ,ntile(2) over(partition by department order by salary) as ntile
FROM salary
""").toPandas()
name department salary index rank dense_rank percent_rank ntile
0 Patricio Marketing 2500.0 1 1 1 0.0 1
1 Jeff Marketing 3100.0 2 2 2 1.0 2
2 Berni Sales NaN 1 1 1 0.0 1
3 Kyoichi Sales 3000.0 2 2 2 0.2 1
4 Georgi Sales 4200.0 3 3 3 0.4 1
5 Guoxiang Sales 4200.0 4 3 3 0.4 2
6 Tom Sales 4500.0 5 5 4 0.8 2
7 Berni Sales 4700.0 6 6 5 1.0 2
8 Parto Finance 2700.0 1 1 1 0.0 1
9 Anneke Finance 3300.0 2 2 2 0.5 1
10 Sumant Finance 3900.0 3 3 3 1.0 2
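
A common use of row_number() is picking the top-N rows per group. For example, the following sketch keeps only the highest-paid employee in each department (in descending order Spark sorts NULL salaries last, so the NULL row never wins):

spark.sql("""
SELECT name, department, salary
FROM (
    SELECT
        name
        ,department
        ,salary
        ,row_number() over(partition by department order by salary desc) as rn
    FROM salary
) AS t
WHERE rn = 1
""").toPandas()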

analytic functions

SQL            DataFrame API   Description
cume_dist      cumeDist        Cumulative distribution: rows with a value less than or equal to the current value / total rows in the partition
lag            lag             lag(input, [offset[, default]]): value of input from offset rows before the current row
lead           lead            The opposite of lag: value of input from offset rows after the current row
first_value    first_value     First value of the ordered partition, up to the current row
last_value     last_value      Last value of the ordered partition, up to the current row
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,cume_dist() over(partition by department order by salary) as cume_dist
    ,lag(salary, 1) over(partition by department order by salary) as lag -- value from the previous row
    ,lead(salary, 1) over(partition by department order by salary) as lead -- value from the next row
    ,lag(salary, 0) over(partition by department order by salary) as lag_0
    ,lead(salary, 0) over(partition by department order by salary) as lead_0
    ,first_value(salary) over(partition by department order by salary) as first_value
    ,last_value(salary) over(partition by department order by salary) as last_value 
FROM salary
""").toPandas()
name department salary index cume_dist lag lead lag_0 lead_0 first_value last_value
0 Patricio Marketing 2500.0 1 0.500000 NaN 3100.0 2500.0 2500.0 2500.0 2500.0
1 Jeff Marketing 3100.0 2 1.000000 2500.0 NaN 3100.0 3100.0 2500.0 3100.0
2 Berni Sales NaN 1 0.166667 NaN 3000.0 NaN NaN NaN NaN
3 Kyoichi Sales 3000.0 2 0.333333 NaN 4200.0 3000.0 3000.0 NaN 3000.0
4 Georgi Sales 4200.0 3 0.666667 3000.0 4200.0 4200.0 4200.0 NaN 4200.0
5 Guoxiang Sales 4200.0 4 0.666667 4200.0 4500.0 4200.0 4200.0 NaN 4200.0
6 Tom Sales 4500.0 5 0.833333 4200.0 4700.0 4500.0 4500.0 NaN 4500.0
7 Berni Sales 4700.0 6 1.000000 4500.0 NaN 4700.0 4700.0 NaN 4700.0
8 Parto Finance 2700.0 1 0.333333 NaN 3300.0 2700.0 2700.0 2700.0 2700.0
9 Anneke Finance 3300.0 2 0.666667 2700.0 3900.0 3300.0 3300.0 2700.0 3300.0
10 Sumant Finance 3900.0 3 1.000000 3300.0 NaN 3900.0 3900.0 2700.0 3900.0
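
lag and lead also accept an optional third argument, a default value that is returned instead of NULL when the offset falls outside the partition. A small sketch on the same view:

spark.sql("""
SELECT
    name
    ,department
    ,salary
    ,lag(salary, 1, 0) over(partition by department order by salary) as lag_default_0
    ,lead(salary, 1, 0) over(partition by department order by salary) as lead_default_0
FROM salary
""").toPandas()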

aggregate functions

These simply apply the ordinary aggregate functions over a window.

SQL   Description
avg   Average
sum   Sum
min   Minimum
max   Maximum
spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,sum(salary) over(partition by department order by salary) as sum
    ,avg(salary) over(partition by department order by salary) as avg
    ,min(salary) over(partition by department order by salary) as min
    ,max(salary) over(partition by department order by salary) as max
FROM salary
""").toPandas()
name department salary index sum avg min max
0 Patricio Marketing 2500.0 1 2500.0 2500.0 2500.0 2500.0
1 Jeff Marketing 3100.0 2 5600.0 2800.0 2500.0 3100.0
2 Berni Sales NaN 1 NaN NaN NaN NaN
3 Kyoichi Sales 3000.0 2 3000.0 3000.0 3000.0 3000.0
4 Georgi Sales 4200.0 3 11400.0 3800.0 3000.0 4200.0
5 Guoxiang Sales 4200.0 4 11400.0 3800.0 3000.0 4200.0
6 Tom Sales 4500.0 5 15900.0 3975.0 3000.0 4500.0
7 Berni Sales 4700.0 6 20600.0 4120.0 3000.0 4700.0
8 Parto Finance 2700.0 1 2700.0 2700.0 2700.0 2700.0
9 Anneke Finance 3300.0 2 6000.0 3000.0 2700.0 3300.0
10 Sumant Finance 3900.0 3 9900.0 3300.0 2700.0 3900.0
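
Because ORDER BY implies a default frame of RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, the aggregates above are running values. To aggregate over the whole partition, drop the ORDER BY or spell out the frame explicitly; a sketch:

spark.sql("""
SELECT
    name
    ,department
    ,salary
    ,sum(salary) over(partition by department) as dept_sum
    ,sum(salary) over(partition by department order by salary
                      rows between unbounded preceding and unbounded following) as dept_sum_full
FROM salary
""").toPandas()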

1.3 Window frame clause

The ROWS/RANGE frame clause controls the size and boundaries of the window. There are two frame types (ROWS, RANGE):

  • ROWS: a physical frame; rows are selected by their position (row index) after sorting
  • RANGE: a logical frame; rows are selected by their value relative to the current row

Syntax: OVER (PARTITION BY … ORDER BY … frame_type BETWEEN start AND end)

The frame boundaries (start and end) can take the following five values:

  • CURRENT ROW: the current row
  • UNBOUNDED PRECEDING: the first row of the partition
  • UNBOUNDED FOLLOWING: the last row of the partition
  • n PRECEDING: n rows before the current row
  • n FOLLOWING: n rows after the current row
spark.sql("""
SELECT
    name 
    ,department
    ,create_time
    ,row_number() over(partition by department order by create_time) as index
    ,row_number() over(partition by department order by (case when salary is not null then create_time end)) as index_ignore_null
    ,salary    
    ,collect_list(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salarys
    ,last(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salary1
    ,lag(salary, 1) over(partition by department order by create_time) as before_salary2
    ,lead(salary, 1) over(partition by department order by create_time) as after_salary   
FROM salary
ORDER BY department, index
""").toPandas()
name department create_time index index_ignore_null salary before_salarys before_salary1 before_salary2 after_salary
0 Anneke Finance 2020-01-02 08:20:00 1 1 3300.0 [] NaN NaN 3900.0
1 Sumant Finance 2020-01-30 12:01:05 2 2 3900.0 [3300] 3300.0 3300.0 2700.0
2 Parto Finance 2020-02-20 12:01:00 3 3 2700.0 [3300, 3900] 3900.0 3900.0 NaN
3 Jeff Marketing 2020-01-02 12:01:00 1 1 3100.0 [] NaN NaN 2500.0
4 Patricio Marketing 2020-01-05 12:18:00 2 2 2500.0 [3100] 3100.0 3100.0 NaN
5 Tom Sales 2020-01-01 00:01:00 1 2 4500.0 [] NaN NaN 4200.0
6 Georgi Sales 2020-01-02 12:01:00 2 3 4200.0 [4500] 4500.0 4500.0 NaN
7 Berni Sales 2020-01-07 11:01:00 3 1 NaN [4500, 4200] 4200.0 4200.0 4200.0
8 Guoxiang Sales 2020-01-08 12:11:00 4 4 4200.0 [4500, 4200] NaN NaN 4700.0
9 Berni Sales 2020-01-10 11:01:00 5 5 4700.0 [4500, 4200, 4200] 4200.0 4200.0 3000.0
10 Kyoichi Sales 2020-02-02 12:10:00 6 6 3000.0 [4500, 4200, 4200, 4700] 4700.0 4700.0 NaN
# Within each department, the salary of the previous hire whose salary is not null
spark.sql("""
SELECT
    name
    ,department
    ,create_time
    ,index
    ,salary
    ,before_salarys[size(before_salarys)-1] as before_salary
FROM(
    SELECT
        name 
        ,department
        ,create_time
        ,row_number() over(partition by department order by create_time) as index
        ,salary    
        ,collect_list(salary) over(partition by department order by create_time rows between UNBOUNDED PRECEDING AND 1 PRECEDING) as before_salarys 
    FROM salary
    ORDER BY department, index
) AS base
""").toPandas()
name department create_time index salary before_salary
0 Anneke Finance 2020-01-02 08:20:00 1 3300.0 NaN
1 Sumant Finance 2020-01-30 12:01:05 2 3900.0 3300.0
2 Parto Finance 2020-02-20 12:01:00 3 2700.0 3900.0
3 Jeff Marketing 2020-01-02 12:01:00 1 3100.0 NaN
4 Patricio Marketing 2020-01-05 12:18:00 2 2500.0 3100.0
5 Tom Sales 2020-01-01 00:01:00 1 4500.0 NaN
6 Georgi Sales 2020-01-02 12:01:00 2 4200.0 4500.0
7 Berni Sales 2020-01-07 11:01:00 3 NaN 4200.0
8 Guoxiang Sales 2020-01-08 12:11:00 4 4200.0 4200.0
9 Berni Sales 2020-01-10 11:01:00 5 4700.0 4200.0
10 Kyoichi Sales 2020-02-02 12:10:00 6 3000.0 4700.0
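
The difference between ROWS and RANGE shows up when the ORDER BY column has ties: with RANGE all peer rows (here, the two 4200 salaries in Sales) fall into the same frame, while with ROWS they do not. A sketch comparing the two running sums:

spark.sql("""
SELECT
    name
    ,department
    ,salary
    ,sum(salary) over(partition by department order by salary
                      rows between unbounded preceding and current row) as running_sum_rows
    ,sum(salary) over(partition by department order by salary
                      range between unbounded preceding and current row) as running_sum_range
FROM salary
""").toPandas()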

1.4 Combined usage

spark.sql("""
SELECT
    name 
    ,department
    ,salary
    ,row_number() over(partition by department order by salary) as index
    ,salary - (min(salary) over(partition by department order by salary)) as salary_diff -- how much above the department's minimum salary
    ,min(salary) over() as min_salary_0 -- overall minimum salary
    ,first_value(salary) over(order by salary) as min_salary_1
    
    ,max(salary) over(order by salary) as current_max_salary_0 -- running maximum salary up to the current row
    ,last_value(salary) over(order by salary) as current_max_salary_1 
    
    ,max(salary) over(partition by department order by salary rows between 1 FOLLOWING and 1 FOLLOWING) as next_salary_0 -- next record ordered by salary
    ,lead(salary) over(partition by department order by salary) as next_salary_1
FROM salary
WHERE salary is not null
""").toPandas()
name department salary index salary_diff min_salary_0 min_salary_1 current_max_salary_0 current_max_salary_1 next_salary_0 next_salary_1
0 Patricio Marketing 2500 1 0 2500 2500 2500 2500 3100.0 3100.0
1 Parto Finance 2700 1 0 2500 2500 2700 2700 3300.0 3300.0
2 Kyoichi Sales 3000 1 0 2500 2500 3000 3000 4200.0 4200.0
3 Jeff Marketing 3100 2 600 2500 2500 3100 3100 NaN NaN
4 Anneke Finance 3300 2 600 2500 2500 3300 3300 3900.0 3900.0
5 Sumant Finance 3900 3 1200 2500 2500 3900 3900 NaN NaN
6 Georgi Sales 4200 2 1200 2500 2500 4200 4200 4200.0 4200.0
7 Guoxiang Sales 4200 3 1200 2500 2500 4200 4200 4500.0 4500.0
8 Tom Sales 4500 4 1500 2500 2500 4500 4500 4700.0 4700.0
9 Berni Sales 4700 5 1700 2500 2500 4700 4700 NaN NaN

Reference: SparkSQL | 窗口函数

2 pyspark

Window functions fall into three categories: ranking functions, analytic functions, and aggregate functions.

  • ranking functions: row_number(), rank(), dense_rank(), percent_rank(), ntile();
  • analytic functions: cume_dist(), lag(), lead();
  • aggregate functions: sum(), first(), last(), max(), min(), mean(), stddev(), etc.

2.1 Ranking functions

First, suppose our data looks like this:

# from pyspark.sql import SparkSession
# spark = SparkSession.builder.appName('Window functions').getOrCreate()
employee_salary = [
    ("Ali", "Sales", 8000),
    ("Bob", "Sales", 7000),
    ("Cindy", "Sales", 7500),
    ("Davd", "Finance", 10000),
    ("Elena", "Sales", 8000),
    ("Fancy", "Finance", 12000),
    ("George", "Finance", 11000),
    ("Haffman", "Marketing", 7000),
    ("Ilaja", "Marketing", 8000),
    ("Joey", "Sales", 9000)]
 
columns= ["name", "department", "salary"]
df = spark.createDataFrame(data = employee_salary, schema = columns)
df.show(truncate=False)
+-------+----------+------+
|name   |department|salary|
+-------+----------+------+
|Ali    |Sales     |8000  |
|Bob    |Sales     |7000  |
|Cindy  |Sales     |7500  |
|Davd   |Finance   |10000 |
|Elena  |Sales     |8000  |
|Fancy  |Finance   |12000 |
|George |Finance   |11000 |
|Haffman|Marketing |7000  |
|Ilaja  |Marketing |8000  |
|Joey   |Sales     |9000  |
+-------+----------+------+

row_number()

row_number() adds a sequential row number within each partition, ordered by the specified column; the numbering starts at 1 and increases by one row at a time in the sort order.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("row_number", F.row_number().over(windowSpec)).show(truncate=False)

Partitioning by department and ordering by salary from high to low gives the following result:

+-------+----------+------+----------+
|name   |department|salary|row_number|
+-------+----------+------+----------+
|Joey   |Sales     |9000  |1         |
|Ali    |Sales     |8000  |2         |
|Elena  |Sales     |8000  |3         |
|Cindy  |Sales     |7500  |4         |
|Bob    |Sales     |7000  |5         |
|Fancy  |Finance   |12000 |1         |
|George |Finance   |11000 |2         |
|Davd   |Finance   |10000 |3         |
|Ilaja  |Marketing |8000  |1         |
|Haffman|Marketing |7000  |2         |
+-------+----------+------+----------+

Notice that rows with the same salary still get different row numbers: row_number() numbers rows purely by position and ignores the actual values. This leads us to another ranking function, rank().

rank()

rank() assigns a rank within each partition ordered by the specified column; tied values receive the same rank, and the ranks that follow a tie are skipped accordingly.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("rank",F.rank().over(windowSpec)).show(truncate=False)

Partitioning by department and ordering salary from high to low within each group, the result is:

+-------+----------+------+----+
|name   |department|salary|rank|
+-------+----------+------+----+
|Joey   |Sales     |9000  |1   |
|Ali    |Sales     |8000  |2   |
|Elena  |Sales     |8000  |2   |
|Cindy  |Sales     |7500  |4   |
|Bob    |Sales     |7000  |5   |
|Fancy  |Finance   |12000 |1   |
|George |Finance   |11000 |2   |
|Davd   |Finance   |10000 |3   |
|Ilaja  |Marketing |8000  |1   |
|Haffman|Marketing |7000  |2   |
+-------+----------+------+----+

In the result above, the two rows with 8000 both get rank 2, and the next rank jumps to 4. Which brings us to another ranking function, dense_rank().

dense_rank()

dense_rank() also ranks the rows of a partition ordered by the specified column, but without gaps: tied values share the same rank and the next distinct value gets the next consecutive rank. Consider the following code:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("dense_rank",F.dense_rank().over(windowSpec)).show()

Partitioning by department and ordering salary from high to low within each group, the result is:

+-------+----------+------+----------+
|   name|department|salary|dense_rank|
+-------+----------+------+----------+
|   Joey|     Sales|  9000|         1|
|    Ali|     Sales|  8000|         2|
|  Elena|     Sales|  8000|         2|
|  Cindy|     Sales|  7500|         3|
|    Bob|     Sales|  7000|         4|
|  Fancy|   Finance| 12000|         1|
| George|   Finance| 11000|         2|
|   Davd|   Finance| 10000|         3|
|  Ilaja| Marketing|  8000|         1|
|Haffman| Marketing|  7000|         2|
+-------+----------+------+----------+

percent_rank()

In some business scenarios we need the relative (percentile) rank of each value.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("percent_rank",F.percent_rank().over(windowSpec)).show()

Partition by department, order each person's salary within the group, and add the relative rank with percent_rank(); the result is:

+-------+----------+------+------------+
|   name|department|salary|percent_rank|
+-------+----------+------+------------+
|   Joey|     Sales|  9000|         0.0|
|    Ali|     Sales|  8000|        0.25|
|  Elena|     Sales|  8000|        0.25|
|  Cindy|     Sales|  7500|        0.75|
|    Bob|     Sales|  7000|         1.0|
|  Fancy|   Finance| 12000|         0.0|
| George|   Finance| 11000|         0.5|
|   Davd|   Finance| 10000|         1.0|
|  Ilaja| Marketing|  8000|         0.0|
|Haffman| Marketing|  7000|         1.0|
+-------+----------+------+------------+

The result can be read as a normalization of rank(): percent_rank = (rank - 1) / (rows in partition - 1), which yields values between 0 and 1. percent_rank() behaves the same as the SQL PERCENT_RANK function.
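
The formula can be checked directly against the window defined above; note that for a single-row partition the manual expression would divide by zero and give null, whereas percent_rank() returns 0.0 (a sketch):

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec = Window.partitionBy("department").orderBy(F.desc("salary"))
windowAll = Window.partitionBy("department")

# manual (rank - 1) / (partition size - 1) should match percent_rank()
df.withColumn("percent_rank", F.percent_rank().over(windowSpec)) \
  .withColumn("manual_percent_rank",
              (F.rank().over(windowSpec) - 1) / (F.count(F.lit(1)).over(windowAll) - 1)) \
  .show()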

ntile()

ntile(n) splits the rows of each partition into n buckets in sort order and assigns each row its bucket number. For example, with n = 2 the partition is split into two parts, the first numbered 1 and the second numbered 2. The buckets are as equal in size as possible; when the rows cannot be divided evenly, the earlier buckets receive the extra rows.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("ntile",F.ntile(2).over(windowSpec)).show()

Partition by department, order by salary within each group, then use ntile(2) to split each group into two buckets. The result is:

+-------+----------+------+-----+
|   name|department|salary|ntile|
+-------+----------+------+-----+
|   Joey|     Sales|  9000|    1|
|    Ali|     Sales|  8000|    1|
|  Elena|     Sales|  8000|    1|
|  Cindy|     Sales|  7500|    2|
|    Bob|     Sales|  7000|    2|
|  Fancy|   Finance| 12000|    1|
| George|   Finance| 11000|    1|
|   Davd|   Finance| 10000|    2|
|  Ilaja| Marketing|  8000|    1|
|Haffman| Marketing|  7000|    2|
+-------+----------+------+-----+

2.2 Analytic functions

cume_dist()

cume_dist() returns the cumulative distribution of the values within a partition.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("cume_dist",F.cume_dist().over(windowSpec)).show()

Partition by department, order by salary, then cume_dist() gives the cumulative distribution; the result is:

+-------+----------+------+------------------+
|   name|department|salary|         cume_dist|
+-------+----------+------+------------------+
|   Joey|     Sales|  9000|               0.2|
|    Ali|     Sales|  8000|               0.6|
|  Elena|     Sales|  8000|               0.6|
|  Cindy|     Sales|  7500|               0.8|
|    Bob|     Sales|  7000|               1.0|
|  Fancy|   Finance| 12000|0.3333333333333333|
| George|   Finance| 11000|0.6666666666666666|
|   Davd|   Finance| 10000|               1.0|
|  Ilaja| Marketing|  8000|               0.5|
|Haffman| Marketing|  7000|               1.0|
+-------+----------+------+------------------+

This looks very similar to percent_rank(), and it is indeed another kind of normalized result: cume_dist is the number of rows up to and including the current row and its ties, divided by the total number of rows in the partition. For the Sales group, ordered by salary descending, the positions of the last tied row are 1, 3, 3, 4 and 5, which divided by the group size 5 gives 0.2, 0.6, 0.6, 0.8 and 1.0. In business terms: 20% of the group earns 9000 or more, 60% earns 8000 or more, 80% earns 7500 or more, and 100% earns 7000 or more, which makes the numbers much easier to interpret.
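
To make the definition concrete, cume_dist() can be reproduced by counting the rows in the default ordered frame (which includes the current row's ties) and dividing by the partition size; a sketch:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec = Window.partitionBy("department").orderBy(F.desc("salary"))
windowAll = Window.partitionBy("department")

# the default frame of an ordered window is RANGE UNBOUNDED PRECEDING .. CURRENT ROW,
# which includes ties, so this count reproduces the numerator of cume_dist()
df.withColumn("cume_dist", F.cume_dist().over(windowSpec)) \
  .withColumn("manual_cume_dist",
              F.count(F.lit(1)).over(windowSpec) / F.count(F.lit(1)).over(windowAll)) \
  .show()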

lag()

lag() returns, for each row of a partition ordered by the specified column, the value from the previous row; in plain words, after sorting, it looks one row (or offset rows) back.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("lag",F.lag("salary",1).over(windowSpec)).show()

Partition by department, order by salary within each group, and fetch the value preceding each salary; the result is:

+-------+----------+------+-----+
|   name|department|salary|  lag|
+-------+----------+------+-----+
|   Joey|     Sales|  9000| null|
|    Ali|     Sales|  8000| 9000|
|  Elena|     Sales|  8000| 8000|
|  Cindy|     Sales|  7500| 8000|
|    Bob|     Sales|  7000| 7500|
|  Fancy|   Finance| 12000| null|
| George|   Finance| 11000|12000|
|   Davd|   Finance| 10000|11000|
|  Ilaja| Marketing|  8000| null|
|Haffman| Marketing|  7000| 8000|
+-------+----------+------+-----+

The counterpart of lag() that fetches the next value is lead().

lead()

lead() returns the value following the current one in the sorted partition; the code is as follows:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
df.withColumn("lead",F.lead("salary",1).over(windowSpec)).show()

Partition by department, order by salary within each group, and use lead to fetch the value following each salary; the result is:

+-------+----------+------+-----+
|   name|department|salary| lead|
+-------+----------+------+-----+
|   Joey|     Sales|  9000| 8000|
|    Ali|     Sales|  8000| 8000|
|  Elena|     Sales|  8000| 7500|
|  Cindy|     Sales|  7500| 7000|
|    Bob|     Sales|  7000| null|
|  Fancy|   Finance| 12000|11000|
| George|   Finance| 11000|10000|
|   Davd|   Finance| 10000| null|
|  Ilaja| Marketing|  8000| 7000|
|Haffman| Marketing|  7000| null|
+-------+----------+------+-----+

In a real business scenario, suppose we have monthly sales figures: we may want to know how a given month compares with the previous or the next month, and lag and lead make exactly that kind of analysis easy, as sketched below.
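
As a sketch of that idea (the monthly_sales DataFrame and its shop / month / amount columns are hypothetical and not defined above), a month-over-month comparison could look like this:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

w = Window.partitionBy("shop").orderBy("month")

mom = (monthly_sales
       .withColumn("prev_amount", F.lag("amount", 1).over(w))   # previous month's amount
       .withColumn("next_amount", F.lead("amount", 1).over(w))  # next month's amount
       .withColumn("mom_diff", F.col("amount") - F.col("prev_amount")))
mom.show()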

2.3 Aggregate Functions

Common aggregate functions include avg, sum, min, max, count, approx_count_distinct(), and so on. The following code uses several of them at once:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
windowSpecAgg  = Window.partitionBy("department")

df.withColumn("row", F.row_number().over(windowSpec)) \
  .withColumn("avg", F.avg("salary").over(windowSpecAgg)) \
  .withColumn("sum", F.sum("salary").over(windowSpecAgg)) \
  .withColumn("min", F.min("salary").over(windowSpecAgg)) \
  .withColumn("max", F.max("salary").over(windowSpecAgg)) \
  .withColumn("count", F.count("salary").over(windowSpecAgg)) \
  .withColumn("distinct_count", F.approx_count_distinct("salary").over(windowSpecAgg)) \
  .show()
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+
|   name|department|salary|row|    avg|  sum|  min|  max|count|distinct_count|
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+
|   Joey|     Sales|  9000|  1| 7900.0|39500| 7000| 9000|    5|             4|
|    Ali|     Sales|  8000|  2| 7900.0|39500| 7000| 9000|    5|             4|
|  Elena|     Sales|  8000|  3| 7900.0|39500| 7000| 9000|    5|             4|
|  Cindy|     Sales|  7500|  4| 7900.0|39500| 7000| 9000|    5|             4|
|    Bob|     Sales|  7000|  5| 7900.0|39500| 7000| 9000|    5|             4|
|  Fancy|   Finance| 12000|  1|11000.0|33000|10000|12000|    3|             3|
| George|   Finance| 11000|  2|11000.0|33000|10000|12000|    3|             3|
|   Davd|   Finance| 10000|  3|11000.0|33000|10000|12000|    3|             3|
|  Ilaja| Marketing|  8000|  1| 7500.0|15000| 7000| 8000|    2|             2|
|Haffman| Marketing|  7000|  2| 7500.0|15000| 7000| 8000|    2|             2|
+-------+----------+------+---+-------+-----+-----+-----+-----+--------------+

Note that approx_count_distinct() can be used as a window function, whereas in a groupBy you would normally use countDistinct() to count the distinct values exactly; approx_count_distinct() returns an approximate count, so use it with care. From the result, the statistics are computed per department over salary. If we only want to keep the per-department summary and drop the individual rows, we can use the following code:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

windowSpec  = Window.partitionBy("department").orderBy(F.desc("salary"))
windowSpecAgg  = Window.partitionBy("department")

df.withColumn("row", F.row_number().over(windowSpec)) \
  .withColumn("avg", F.avg("salary").over(windowSpecAgg)) \
  .withColumn("sum", F.sum("salary").over(windowSpecAgg)) \
  .withColumn("min", F.min("salary").over(windowSpecAgg)) \
  .withColumn("max", F.max("salary").over(windowSpecAgg)) \
  .withColumn("count", F.count("salary").over(windowSpecAgg)) \
  .withColumn("distinct_count", F.approx_count_distinct("salary").over(windowSpecAgg)) \
  .where(F.col("row")==1).select("department","avg","sum","min","max","count","distinct_count") \
  .show()
+----------+-------+-----+-----+-----+-----+--------------+
|department|    avg|  sum|  min|  max|count|distinct_count|
+----------+-------+-----+-----+-----+-----+--------------+
|     Sales| 7900.0|39500| 7000| 9000|    5|             4|
|   Finance|11000.0|33000|10000|12000|    3|             3|
| Marketing| 7500.0|15000| 7000| 8000|    2|             2|
+----------+-------+-----+-----+-----+-----+--------------+
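
For comparison, the same per-department summary can be produced with groupBy, where the exact countDistinct() is available instead of the approximate window version; a sketch:

import pyspark.sql.functions as F

df.groupBy("department").agg(
    F.avg("salary").alias("avg"),
    F.sum("salary").alias("sum"),
    F.min("salary").alias("min"),
    F.max("salary").alias("max"),
    F.count("salary").alias("count"),
    F.countDistinct("salary").alias("distinct_count")
).show()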

Reference: PySpark–Window Functions

2.4 Window specifications

Window specifications are very useful in practice; for more background on Window, see Window不同分组窗的使用.
Suppose we have the following data:

from pyspark.sql import Row
from pyspark.sql.window import Window
from pyspark.sql.functions import mean, col

row = Row("name", "date", "score")
# sc is the SparkContext, e.g. sc = spark.sparkContext
rdd = sc.parallelize([
    row("Ali", "2020-01-01", 10.0),
    row("Ali", "2020-01-02", 15.0),
    row("Ali", "2020-01-03", 20.0),
    row("Ali", "2020-01-04", 25.0),
    row("Ali", "2020-01-05", 30.0),
    row("Bob", "2020-01-01", 15.0),
    row("Bob", "2020-01-02", 20.0),
    row("Bob", "2020-01-03", 30.0)
])
df = rdd.toDF().withColumn("date", col("date").cast("date"))

To compute each person's average score while keeping all the other columns, we can use a window like this:

w1 = Window().partitionBy(col("name"))
df.withColumn("mean1", mean("score").over(w1)).show()
+----+----------+-----+------------------+
|name|      date|score|             mean1|
+----+----------+-----+------------------+
| Bob|2020-01-01| 15.0|21.666666666666668|
| Bob|2020-01-02| 20.0|21.666666666666668|
| Bob|2020-01-03| 30.0|21.666666666666668|
| Ali|2020-01-02| 15.0|              20.0|
| Ali|2020-01-05| 30.0|              20.0|
| Ali|2020-01-01| 10.0|              20.0|
| Ali|2020-01-03| 20.0|              20.0|
| Ali|2020-01-04| 25.0|              20.0|
+----+----------+-----+------------------+

In the result, the new column mean1 is the average of all scores within each person's group. You can of course compute the maximum, minimum, variance, or other statistics in the same way, as sketched below.
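
For example, other aggregates over the same per-name window w1 (a sketch):

from pyspark.sql.functions import max as max_, min as min_, stddev

df.withColumn("max1", max_("score").over(w1)) \
  .withColumn("min1", min_("score").over(w1)) \
  .withColumn("std1", stddev("score").over(w1)) \
  .show()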

Now let's look at a few variations of the window specification:

days = lambda i: i * 86400  # convert a number of days to seconds

w1 = Window().partitionBy(col("name"))
w2 = Window().partitionBy(col("name")).orderBy("date")
w3 = Window().partitionBy(col("name")).orderBy((col("date").cast("timestamp").cast("bigint")/3600/24)).rangeBetween(-4, 0)
w4 = Window().partitionBy(col("name")).orderBy("date").rowsBetween(Window.currentRow, 1)

  • w1 is the plain partition by name;
  • w2 partitions by name and orders each group's rows by date, earliest first;
  • w3 builds on w2 and adds a range restriction: from 4 days before the current date up to the current date;
  • w4 builds on w2 and adds a row restriction: from the current row to the next row.

Still a bit fuzzy? No worries, let's look at the statistics produced by these window specifications (note that the output below also contains Bob rows for 2020-01-04 through 2020-01-06, which are not in the DataFrame defined above):

df.withColumn("mean1", mean("score").over(w1))\
  .withColumn("mean2", mean("score").over(w2))\
  .withColumn("mean3", mean("score").over(w3))\
  .withColumn("mean4", mean("score").over(w4))\
  .show()
+----+----------+-----+-----+------------------+------------------+-----+
|name|      date|score|mean1|             mean2|             mean3|mean4|
+----+----------+-----+-----+------------------+------------------+-----+
| Bob|2020-01-01| 15.0| 30.0|              15.0|              15.0| 17.5|
| Bob|2020-01-02| 20.0| 30.0|              17.5|              17.5| 25.0|
| Bob|2020-01-03| 30.0| 30.0|21.666666666666668|21.666666666666668| 32.5|
| Bob|2020-01-04| 35.0| 30.0|              25.0|              25.0| 37.5|
| Bob|2020-01-05| 40.0| 30.0|              28.0|              28.0| 40.0|
| Bob|2020-01-06| 40.0| 30.0|              30.0|              33.0| 40.0|
| Ali|2020-01-01| 10.0| 20.0|              10.0|              10.0| 12.5|
| Ali|2020-01-02| 15.0| 20.0|              12.5|              12.5| 17.5|
| Ali|2020-01-03| 20.0| 20.0|              15.0|              15.0| 22.5|
| Ali|2020-01-04| 25.0| 20.0|              17.5|              17.5| 27.5|
| Ali|2020-01-05| 30.0| 20.0|              20.0|              20.0| 30.0|
+----+----------+-----+-----+------------------+------------------+-----+

  • mean1 is simply the average of all scores within each name group.
  • mean2 is more interesting: because the window is ordered by date within each name group, the average is taken over the current row and all preceding rows; note that the last mean2 value in each group equals mean1.
  • mean3 is the average over the frame from 4 days before the current date up to the current row; for example, Bob's last mean3 value of 33 is computed from 2020-01-02 onwards.
  • mean4 averages only the current row and the next row; if there is no next row, it is just the row's own value.

  • Window.unboundedPreceding: all preceding rows
  • Window.unboundedFollowing: all following rows
  • Window.currentRow: the current row

For numeric boundaries, the sign gives the direction (negative looks backward, positive forward); with rowsBetween the magnitude is a number of rows, while with rangeBetween it is a distance measured in the ORDER BY value, as in the 7-day example below.
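
For instance, the days helper defined earlier can be combined with rangeBetween to build a rolling 7-day window (a sketch): order by the date cast to epoch seconds and keep the rows whose date lies within the previous six days up to the current day.

w7 = Window().partitionBy(col("name")) \
             .orderBy(col("date").cast("timestamp").cast("bigint")) \
             .rangeBetween(-days(6), 0)

df.withColumn("mean_7d", mean("score").over(w7)).show()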


Summary

  1. A Window with only partitionBy aggregates over all rows of the partition;
  2. Adding orderBy to a Window makes statistics run, by default, from the first row of the partition up to the current row;
  3. rangeBetween combined with orderBy restricts the frame to a value range, e.g. for statistics over the last week;
  4. rowsBetween restricts the frame to a fixed number of rows before/after the current row.
