spark sql中的udf和udaf实现

 

今天没什么事,突然想起之前写过的sqark中SQL中的UDAF方法,这个还是挺有意思的,难度比蜂房中UDAF高,其中直接体现了火花的分而治之的细想,所以打算今天的博客在加一个火花SQL的UDF和UDAF编写。

直接进入正题。

1.udf函数的
编写.sqlContext.udf.register(“CTOF”,(degreesCelcius:Double)=>((degreesCelcius * 9.0 / 5.0)+ 32.0))
sqlContext.sql(“SELECT city,CTOF(avgLow)AS avgLowF ,CTOF(avgHigh)AS avgHighF FROM citytemps“)。show()
较为简单不做过多描述。

2.udfa函数的编写。


   
   
   
   
  1. import org.apache.spark.sql.Row
  2. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  3. import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
  4. /**
  5. * @author 李春凯
  6. * 在开始前有必要了解一下 StructField.DataType 中支持的类型
  7. * 类型有NullType, DateType, TimestampType, BinaryType,IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType
  8. * 用法 : 在要使用该函数的地方需要使用sqlContext声明 函数名 才能使用
  9. * e.g sqlContext.udf.register("numsAvg", new MyUDAF)
  10. */
  11. class MyUDAF extends UserDefinedAggregateFunction {
  12. /**
  13. * 指定具体的输入数据的类型 支持多个值输入
  14. * 自段名称随意:Users can choose names to identify the input arguments - 这里可以是“name”,或者其他任意串
  15. * 这里支持的格式有 NullType, DateType, TimestampType, BinaryType,
  16. * IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType
  17. *
  18. */
  19. override def inputSchema: StructType = StructType(StructField("docid",StringType)::Nil)
  20. /**
  21. * 在进行聚合操作的时候所要处理的数据的中间结果类型
  22. * 支持多个值进行计算
  23. * 根据具体需求声明参数个数 可以是一个或者是多个 在initialize方法中初始化
  24. */
  25. override def bufferSchema: StructType = StructType(StructField("val",StringType)::Nil)
  26. /**
  27. * 输出类型的定义,可以是对应类型的Array或者是AraayBuffer
  28. * e.g StringType 可以输出 String Array[String] 和 AraayBuffer[String] 但是不能是String[]
  29. */
  30. override def dataType: DataType = StringType
  31. override def deterministic: Boolean = true
  32. /**
  33. * 初始化Buffer
  34. *
  35. * @param buffer 用于接收和存储输入的值
  36. * buffer(0)对应bufferSchema的第一个类型的值
  37. * buffer(1)对应bufferSchema的第二个类型的值
  38. * 以此类推
  39. * buffer(0) 不支持数组,集合和可变集合
  40. * e.g 支持StringType但是不支持Arry[String] ArryBuffer[String] 和 String[]
  41. * --- 该类主要用于参考
  42. * --- 该类实现另外多个列的一个字段组合成了一个列的字段
  43. */
  44. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  45. buffer(0) = ""
  46. }
  47. /**
  48. * 将同久值与新值合并的地方
  49. * 注意: 这里的新参数【buffer】因为在初始化的时候就转换了所以在这里不用再次转换,
  50. * 但是input【inputSchema】是传入的新值类型是IntegerTyper,和inputSchema的bufferSchema的类型不一致 ,
  51. * 所以需要转换
  52. * @param buffer 旧值 如果之前没有值传入吗,这个值就是初始值
  53. * @param input 新值函数每次接收的值,空值不会传进来
  54. */
  55. override def update(buffer: MutableAggregationBuffer, input: Row): Unit ={
  56. val Str1 = buffer.getAs[String](0)
  57. val Str2 = input.getAs[String](0)
  58. if(Str1.trim.equals("")){
  59. buffer(0)=buffer.getAs[String](0)+input.getAs[String](0)
  60. }else{
  61. buffer(0) =buffer.getAs[String](0)+"||"+input.getAs[String](0)
  62. }
  63. }
  64. /**
  65. * 将多个线程中的的内容合并在一个 Buffer1中,
  66. * 这个操作类似于mapreduce中的conbiner
  67. * 也类似于spark中的conbinerbykey函数的分区合并
  68. * @param buffer1 第一个线程/上一个线程【或者是分区,这里我不是很清楚到底有没有分区的概念 】
  69. * @param buffer2 第二个线程或者说是下一个线程
  70. */
  71. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  72. val str1 = buffer1.getString(0)
  73. val str2 = buffer2.getString(0)
  74. if(str1.trim.equals("")){
  75. buffer1(0)=str1+str2
  76. }else{
  77. buffer1(0) = str1+"||"+str2
  78. }
  79. // buffer1(0) = buffer1.getAs[String](0) +"="+buffer2.getAs[String](0)
  80. }
  81. /**
  82. * 指定返回值的内容 这里的返回类型只支持DataType对应的类型 不支持数组,集合和可变集合
  83. * e.g 支持StringType但是不支持Arry[String] ArryBuffer[String] 和 String[]
  84. * @param buffer
  85. * @return
  86. */
  87. override def evaluate(buffer: Row): Any = {
  88. // buffer.getAs[String](0)
  89. buffer.getString(0)
  90. }
  91. }

3.这里我在添加一个多个字段组合成一个字段的UDAF,这个类主要是方便大家对比理解其中的思想。
 


   
   
   
   
  1. package com.ql.UDAFUtil
  2. import org.apache.spark.sql.Row
  3. import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
  4. import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}
  5. /**
  6. * 将多行多个字段组合成一行一个字段
  7. */
  8. class kCollect_Set extends UserDefinedAggregateFunction{
  9. /*
  10. * collect_set(CONCAT(\"http://112.124.126.220:8080/slfh/download?path=\",da.filePath,da.fileName,\"&fileName=\",da.sourceFileName
  11. * String
  12. * |-- fileName: string (nullable = true)
  13. * |-- filePath: string (nullable = true)
  14. * String
  15. * |-- sourceFileName: string (nullable = true)
  16. */
  17. override def inputSchema: StructType = StructType(
  18. StructField("http",StringType)::
  19. StructField("filePath",StringType)::
  20. StructField("filePath",StringType)::
  21. StructField("symbol",StringType)::
  22. StructField("str",StringType)::
  23. StructField("sourceFileName",StringType)::
  24. Nil)
  25. override def bufferSchema: StructType = StructType(
  26. StructField("http_buffer",StringType)::
  27. Nil)
  28. override def dataType: DataType = StringType
  29. override def deterministic: Boolean = true
  30. override def initialize(buffer: MutableAggregationBuffer): Unit = {
  31. buffer(0)=""
  32. }
  33. override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  34. val input1 = input.getAs[String](0)
  35. val input2 = input.getAs[String](1)
  36. val input3 = input.getAs[String](2)
  37. val input4 = input.getAs[String](3)
  38. val input5 = input.getAs[String](4)
  39. val input6 = input.getAs[String](5)
  40. var str = ""
  41. // collect_set(CONCAT(\"http://112.124.126.220:8080/slfh/download?path=\",da.filePath,da.fileName,\"&fileName=\",da.sourceFileName
  42. if(input1!=null){ str = str+input1}
  43. if(input2!=null){ str = str+input2 }
  44. if(input3!=null){ str = str+input3 }
  45. if(input4!=null){ str = str+input4 }
  46. if(input5!=null){ str = str+input5 }
  47. if(input6!=null){ str = str+input6 }
  48. val buf = buffer.getAs[String](0)
  49. if(buf.trim.equals("")){
  50. buffer(0)=buf+str
  51. }else{
  52. buffer(0)=buf+"||"+str
  53. }
  54. }
  55. override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
  56. var str1 = buffer1.getString(0)
  57. var str2 = buffer2.getString(0)
  58. if(str1.trim.equals("")){
  59. buffer1(0)=str1+str2
  60. }else{
  61. buffer1(0)=str1+"||"+str2
  62. }
  63. }
  64. override def evaluate(buffer: Row): Any = {
  65. buffer.getString(0)
  66. }
  67. }

 4.总结spark sql中的udaf类的操作稍微复杂一点,但是并不是不能理解的,在我的第一个中中【myudaf】中有着详细的解释,在第二个实例类中有着多个字段的对比写法,这俩个实例可以但对比观看,查看其中区别和奥妙。在下面的文字中我会简单写出UDAF的番薯执行顺序和功能。
执行顺序,和我写的实例中的函数排列顺序一样
.a)中声明输入的类型,名字不用在意
b)中声明缓冲区中初始值的类型,名字不用在意
c)中声明输出类型
d)定义函数的确定性为真,这个的作用我没有看懂,在什么情况下可以为flase,读者知道的话,希望可以告诉我,我会留下联系方式
.e)初始化缓冲区的第一个值
.f)对分区/线程中的值进行操作.g
)分区/线程间的结果进行合并处理.h
)输出最后的结果
.i)在主函数中定义该类的函数,并在spark_sql中直接使用。
在主函数中定义如下:


   
   
   
   
  1. //生成collect_set函数,针对列级一个字段 整合为一个字段的函数
  2. sqlContext.udf.register("collect_set", new MyUDAF)
  3. //生成collect_sets函数,针对列级多字段 整合为一个字段的函数
  4. sqlContext.udf.register("collect_sets",new Collect_Set)

你可能感兴趣的:(spark)