Spark UDAF: a custom aggregate function (UserDefinedAggregateFunction) for a conditional distinct count

Requirement: group by menu item (pid) and count the number of distinct orders that have no discount amount.
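For comparison, the same figure can be computed in plain Spark SQL with a conditional distinct count, no UDAF needed. A minimal sketch against the tab_tmp temp view registered in the example further below (count(distinct ...) ignores the nulls produced by if):

spark.sql(
  """
    |select pid,
    |count(distinct if(yhPrice <= 0, orderid, null)) as zk_order_num
    |from tab_tmp group by pid
  """.stripMargin).show()

The UDAF below packages the same "filter, then count distinct" logic into a single reusable function.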

package cd.custom.jde.job.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * created by roy 2020-02-12
  * Counts distinct order ids, keeping only orders that satisfy the discount condition.
  */
class CountDistinctAndIf extends UserDefinedAggregateFunction {

  // input columns: the order id and the discount amount
  override def inputSchema: StructType = {
    new StructType().add("orderid", StringType, nullable = true)
      .add("price", DoubleType, nullable = true)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // guard against null inputs, then keep only rows whose discount amount is <= 0
    if (!input.isNullAt(0) && !input.isNullAt(1) && input.getDouble(1) <= 0) {
      // add the new row's order id to the buffered set of distinct order ids
      buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
    }
  }

  // the buffer holds the distinct order ids seen so far
  override def bufferSchema: StructType = {
    new StructType().add("items", ArrayType(StringType, true), nullable = true)
  }

  // merge two partial buffers by unioning their order-id sets
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  // start each group with an empty set of order ids
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  // the same input rows always produce the same result
  override def deterministic: Boolean = true

  // the final result is the number of distinct qualifying orders
  override def evaluate(buffer: Row): Any = {
    buffer.getSeq[String](0).length
  }

  override def dataType: DataType = IntegerType
}
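Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of the typed Aggregator API registered through functions.udaf. A minimal sketch of the same logic in that style, assuming Spark 3.x (the case class OrderInput and the object CountDistinctAndIfAgg are illustrative names, not part of the original code):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// sketch assuming Spark 3.x; OrderInput / CountDistinctAndIfAgg are illustrative names
case class OrderInput(orderid: String, price: Double)

object CountDistinctAndIfAgg extends Aggregator[OrderInput, Set[String], Int] {
  // empty buffer: no order ids seen yet
  def zero: Set[String] = Set.empty[String]
  // keep only orders whose discount amount is <= 0
  def reduce(b: Set[String], a: OrderInput): Set[String] =
    if (a.price <= 0) b + a.orderid else b
  // union the partial sets of order ids
  def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++ b2
  // final result: number of distinct qualifying orders
  def finish(reduction: Set[String]): Int = reduction.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// registration for SQL use, e.g.:
// spark.udf.register("cCountDistinct2", udaf(CountDistinctAndIfAgg, Encoders.product[OrderInput]))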

Example application:

package spark.udf

import cd.custom.jde.job.udf.CountDistinctAndIf
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object MyOrderTest {

  Logger.getRootLogger.setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {

    val data = Seq(
      Row("a", "a100", 0.0, "300"),
      Row("a", "a100", 7.0, "300"),
      Row("a", "a101", 6.0, "300"),
      Row("a", "a101", 5.0, "301"),
      Row("a", "a100", 0.0, "300")
    )
    val schema = new StructType()
      .add("storeid", StringType)
      .add("orderid", StringType)
      .add("yhPrice", DoubleType)
      .add("pid", StringType)
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
    df.show()
    df.createOrReplaceTempView("tab_tmp")

    // register the UDAF so it can be called from SQL
    val cCountDistinct2 = new CountDistinctAndIf
    spark.sqlContext.udf.register("cCountDistinct2", cCountDistinct2)
    spark.sql(
      """
        |select pid,count(1) pid_num,
        |sum(if(yhPrice<=0,1,0)) as zk_all_order_num, -- rows with no discount amount (duplicates included)
        |cCountDistinct2(orderid,yhPrice) as zk_order_num -- distinct orders with no discount amount
        |from tab_tmp group by pid
      """.stripMargin).show()

  }

 /* Expected output:
  +---+-------+----------------+------------+
  |pid|pid_num|zk_all_order_num|zk_order_num|
  +---+-------+----------------+------------+
  |300|      4|               2|           1|
  |301|      1|               0|           0|
  +---+-------+----------------+------------+
  For pid 300, order a100 appears twice with yhPrice = 0.0, so zk_all_order_num
  counts 2 rows while zk_order_num counts that order only once. */
}
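The same UDAF can also be applied through the DataFrame API instead of SQL, since a UserDefinedAggregateFunction instance can be used directly as a column expression. A minimal sketch reusing the df and cCountDistinct2 defined above:

import org.apache.spark.sql.functions.col

// group by pid and apply the UDAF to the orderid / yhPrice columns
df.groupBy("pid")
  .agg(cCountDistinct2(col("orderid"), col("yhPrice")).as("zk_order_num"))
  .show()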

 
