Spark DataFrame 使用UDF实现UDAF的一种方法

Background:
当我们使用Spark Dataframe的时候常常需要进行group by操作,然后针对这一个group算出一个结果来。即所谓的聚合操作

然而

Spark提供的 aggregation函数太少,常常不能满足我们的需要,怎么办呢?

Spark 贴心的提供了UDAF(User-defined aggregate function),听起来不错。
但是,这个函数实现起来太复杂,反正我是看的晕晕乎乎,难受的很。反倒是UDF的实现非常简单,无非是UDF针对所有行,UDAF针对一个group中的所有行。

So,两者在某种程度上是一样的。

下面我们就看看如何用UDF实现UDAF的功能

举个例子来说明问题:
我们有一个dataframe是长这样的:

+-------+-------+-------+
|groupid|column1|column2|
+-------+-------+-------+
|   1   |  1    |   7   |
|   1   |  12   |   9   |
|   1   |  30   |   8   |
|   1   |  18   |   1   |
|   1   |  19   |   13  |
|   1   |  15   |   20  |
|   2   |  41   |   2   |
|   2   |  50   |   19  |
|   2   |  16   |   11  |
|   2   |  27   |   5   |
|   3   |  83   |   6   |
|   3   |  91   |   15  |
|   3   |  10   |   8   |

我们想对它group by id,然后对每一个group里的内容进行自定义操作。
比如寻找某一列第三大的数、通过某两列的数据计算出一个参数等等很多user-define的操作。

抽象的步骤看这里:

STEP.1. 对想要操作的列执行 collect_list(),生成新列,此时一个group就是一行。
        +-------+--------------------------+-----------------------+
        |groupid|        column1           |        column2        |
        +-------+--------------------------+-----------------------+
        |   1   |  [1,12,30,18,19,15]  | [7,9,8,1,13,20]   |
        |   2   |      [41,50,16,27]       |      [2,19,11,5]      | 
        |   3   |        [83,91,10]        |      [6,15,8]         |
STEP.2.写一个UDF,传入参数为上边生成的列,相当于传入了一个或多个数组。
 import org.apache.spark.sql.functions._
    def createNewCol = udf((column1: collection.mutable.WrappedArray[Int], column2: collection.mutable.WrappedArray[Int]) => {  // udf function
      var balabala  //各种要用到的自定义变量 
      var resultArray = Array.empty[(Int, Int, Int)]
      for(column1.size):  //遍历计算
          result[i] = 对俩数组column1,column2进行某种计算操作 //一个group中第i行的结果
      resultArray[i]=(column1[i],column2[i],result[i])
      resultArray   //返回值
    })    
STEP.3.UDF中可以对数组做任意操作,你对数组想怎么操作就怎么操作,最后返回一个数组就可以了,长度和你传入的数组相同(显然),数组每个元素的格式是tuple的 (column1.vaule,column2.value, result)因为 column1,column2的值我们后边展开的时候还要用。
STEP.4.执行UDF函数,传入的第一步中生成的列,获得结果列newcolumn,存储UDF的返回值。此时一个group还是一行。
+-------+--------------------------+-----------------------+-------------------------------+
|groupid|        column1           |        column2        |          newcolumn            |
+-------+--------------------------+-----------------------+-------------------------------+
|   1   |  [1,12,30,18,19,15]  | [7,9,8,1,13,20]   | [(1,7,v1.1),(12,9,v1.2)...]   |
|   2   |      [41,50,16,27]       |      [2,19,11,5]      | [(41,2,v2.1),(50,19,v2.2)..]  |
|   3   |        [83,91,10]        |      [6,15,8]         | [(83,91,v3.1),(6,15,v3.2)..]  |
STEP.5. column1,column2可以丢掉了,因为用不到。
+-------+-------------------------------+
|groupid|          newcolumn            |
+-------+-------------------------------+
|   1   | [(1,7,v1.1),(12,9,v1.2)...]   |
|   2   | [(41,2,v2.1),(50,19,v2.2)..]  |
|   3   | [(83,91,v3.1),(6,15,v3.2)..]  |
STEP.6.对结果列执行 explode(col("newcolumn"))操作,相当于把数组撑开来到整个group中。
+-------+----------------------+
|groupid|         new          |
+-------+----------------------+
|   1   | (1,7,value1.1)    |
|   1   | (12,9,value1.2)   |
|   1   | (30,8,value1.3)   |
|   1   | (18,1,value1.4)   |
.....省略
|   2   |  (41,2,value2.1)     |
|   2   |  (50,19,value2.2)    |
|   3   |  (83,91,value3.1)    | ...大面积省略
    
STEP.7.把tuple分开成三列

select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("resultcolumn")) //selecting as separate column

所有代码看这里:



df.groupBy("groupid").agg(collect_list("column1").as("column1"),collect_list("column2").as("column2")) // 把要操作的列转换成数组,作为group的一个列属性。
      .withColumn("newcolumn", createNewCol(col("column1"), col("column2")))  //把存储数组的列传入udf,返回一个新列     
      .drop("column1", "column2") //丢弃两个存储数组的列,因为用不到了                                                        
      .withColumn("new", explode(col("newcolumn"))) //把新计算出来的内容从一行explode到整个group
      .select(col("groupid"), col("new._1").as("rownum"), col("new._2").as("column2"), col("new._3").as("column3"))  //selecting as separate column                                         
      .show(false)

The end

实际案例就不举了,码字太麻烦了。 这里有一个,英文的,来自我的stackoverflow
PS:collect 是 一个shuffle算子,会特别消耗资源,如果出现OOM,别怪我

你可能感兴趣的:(Spark DataFrame 使用UDF实现UDAF的一种方法)