SparkSQL (5): UDF and UDAF

1. The difference between UDF and UDAF

UDF: user-defined function. It maps one input row to one output value (one in, one out).

UDAF: user-defined aggregate function. It folds many input rows into one output value (many in, one out).
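
The contrast is easiest to see in SQL. A minimal sketch, using the two functions registered later in this post (format_double is the UDF, self_avg is the UDAF):

-- UDF: invoked once per row, one result per row
select format_double(sal) from tmp_emp;

-- UDAF: folds all rows of a group, one result per group
select id, self_avg(sal) from tmp_emp group by id;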

2.实现代码

(1) UDAF code:

package _0728sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * A UDAF that computes the average of a Double column
  * (avg = total value / total count).
  */
object Avg_UDAF extends UserDefinedAggregateFunction{
  //(a)
  override def inputSchema: StructType = {
    /**
      * Schema of the UDAF's input arguments (one Double column)
      * iv stands for "input value"
      */
    StructType(Array(
      StructField("iv",DoubleType)
    ))
  }
  //(b)
  override def bufferSchema: StructType = {
    // Schema of the aggregation buffer: avg = totalValue / totalCount
    // tv: total value
    // tc: total count
    StructType(Array(
      StructField("tv",DoubleType),
      StructField("tc",IntegerType)
    ))
  }
  //(c)
  override def dataType: DataType = {
    // Data type of the final result
    DoubleType
  }
  //(d)
  override def deterministic: Boolean = {
    // Whether the function is deterministic: given the same input, must it
    // always return the same result? true means yes; this is almost always true.
    true
  }
  //(e)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Initial values of the aggregation buffer
    buffer.update(0, 0.0) // tv: total value starts at 0.0
    buffer.update(1, 0)   // tc: total count starts at 0

  }
  //(f)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // For each input row of the current group, fold it into the buffer
    // 1. Read the input value
    val iv = input.getDouble(0)
    // 2. Read the current buffer values
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    // 3. Write the updated totals back to the buffer
    buffer.update(0, tv + iv)
    buffer.update(1, tc + 1)
  }
  //(g)
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Called when partial results from two partitions must be combined
    // 1. Read buffer1's values
    val tv1 = buffer1.getDouble(0)
    val tc1 = buffer1.getInt(1)
    // 2. Read buffer2's values
    val tv2 = buffer2.getDouble(0)
    val tc2 = buffer2.getInt(1)
    /*
     3. Write the merged result into buffer1, never buffer2:
        only MutableAggregationBuffer is mutable (it implements update);
        buffer2 is a read-only Row.
     */
    buffer1.update(0, tv1 + tv2)
    buffer1.update(1, tc1 + tc2)
  }
  //(h)
  override def evaluate(buffer: Row): Any = {
    // Final result for the group: average = total value / total count
    val tv = buffer.getDouble(0)
    val tc = buffer.getInt(1)
    tv/tc
  }
}
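
Note: UserDefinedAggregateFunction was deprecated in Spark 3.0 in favor of the typed Aggregator API. Below is a minimal sketch of the same average as an Aggregator, assuming Spark 3.x; the object and field names are mine, not part of the original code:

package _0728sql

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

// Buffer case class: total value and total count, mirroring bufferSchema above
case class AvgBuffer(var tv: Double, var tc: Int)

object Avg_Aggregator extends Aggregator[Double, AvgBuffer, Double] {
  // Initial buffer values (mirrors initialize)
  override def zero: AvgBuffer = AvgBuffer(0.0, 0)
  // Fold one input value into a buffer (mirrors update)
  override def reduce(b: AvgBuffer, iv: Double): AvgBuffer = { b.tv += iv; b.tc += 1; b }
  // Combine two partial buffers (mirrors merge)
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.tv += b2.tv; b1.tc += b2.tc; b1 }
  // Final result for the group (mirrors evaluate)
  override def finish(b: AvgBuffer): Double = b.tv / b.tc
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registered in Spark 3.x with:
// spark.udf.register("self_avg", org.apache.spark.sql.functions.udaf(Avg_Aggregator))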

(2) Driver code:

package _0728sql

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
/**
  * Driver: registers a UDF (format_double) and the UDAF above (self_avg)
  * and exercises them in SQL against a small in-memory table.
  */
object UDFandUDAF extends App{

  // Local Spark configuration for this demo
  val conf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("UDFandUDAF")
  // getOrCreate uses a lock internally, guaranteeing a single SparkContext per JVM
  val sc = SparkContext.getOrCreate(conf)
  // If Hive support is not needed, plain SQLContext is enough; no need for HiveContext
  val sqlContext = new SQLContext(sc)

  //1.UDF
  // UDF that keeps two decimal places of a Double.
  // "format_double" is the SQL-visible name; the anonymous function implements it.
  sqlContext.udf.register("format_double", (value: Double) => {
    import java.math.BigDecimal
    val bd = new BigDecimal(value)
    bd.setScale(2, BigDecimal.ROUND_HALF_UP).doubleValue()
  })
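
  // Caveat (my note, not from the original post): new BigDecimal(double) wraps
  // the exact binary value of the double, while BigDecimal.valueOf(double) goes
  // through Double.toString first. For example, 2.675 is stored as 2.67499999...,
  // so new BigDecimal(2.675).setScale(2, BigDecimal.ROUND_HALF_UP) gives 2.67,
  // whereas BigDecimal.valueOf(2.675).setScale(2, BigDecimal.ROUND_HALF_UP) gives 2.68.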

  // Needed for the .toDF conversion below
  import sqlContext.implicits._
  sc.parallelize(Array(
    (1, 1234),
    (1, 45212),
    (1, 22125),
    (1, 12521),
    (1, 12352),
    (2, 52352),
    (2, 2232),
    (2, 12521),
    (2, 12323),
    (3, 2253),
    (3, 2233),
    (3, 22558),
    (4, 252),
    (4, 235),
    (5, 523)
  )).toDF("id", "sal").registerTempTable("tmp_emp")

  sqlContext.sql(
    """
      |select
      |  id,
      |  AVG(sal) as sal1,
      |  format_double(AVG(sal)) as sal3
      |from tmp_emp
      |group by id
    """.stripMargin).show

  //2.UDAF
  sqlContext.udf.register("self_avg",Avg_UDAF)
  sqlContext.sql(
    """
      |select
      |  id,
      |  AVG(sal) as sal1,
      |  format_double(AVG(sal)) as sal3,
      |  format_double(self_avg(sal)) as sal4
      |from tmp_emp
      |group by id
    """.stripMargin).show

}
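
With the sample data above, sal3 and sal4 should always match, since self_avg computes exactly what the built-in AVG computes. For example, id 3 averages (2253 + 2233 + 22558) / 3 = 9014.666..., which format_double rounds to 9014.67. Note that row order in the output is not deterministic.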
