Requirement: group by dish (pid) and, for each dish, count the number of distinct orders whose discount amount is zero (yhPrice <= 0).
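For reference, this particular count does not strictly require a custom aggregate: a conditional distinct count in plain Spark SQL should give the same result. A minimal sketch, assuming the tab_tmp view and the orderid/yhPrice/pid columns registered in the example further below:

// UDAF-free equivalent: the CASE expression maps discounted rows
// (yhPrice > 0) to NULL, which count(distinct ...) ignores.
spark.sql(
  """
    |select pid,
    |count(distinct case when yhPrice <= 0 then orderid end) as zk_order_num
    |from tab_tmp group by pid
  """.stripMargin).show()

A UDAF remains the general-purpose tool once the per-group logic no longer fits a single SQL expression, which is what the class below demonstrates.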
package cd.custom.jde.job.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Created by roy, 2020-02-12.
 * Deduplicates orders and counts only those whose discount amount
 * (the price argument) is zero or less, i.e. orders with no discount.
 */
class CountDistinctAndIf extends UserDefinedAggregateFunction {

  // Input: one (orderid, price) pair per row
  override def inputSchema: StructType = new StructType()
    .add("orderid", StringType, nullable = true)
    .add("price", DoubleType, nullable = true)

  // Buffer: the distinct order ids seen so far, stored as an array
  override def bufferSchema: StructType = new StructType()
    .add("items", ArrayType(StringType, containsNull = true), nullable = true)

  // Output: the number of distinct qualifying orders
  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq[String]()
  }

  // Add the incoming orderid to the buffer, but only when the discount
  // amount is <= 0; the null guard avoids an NPE on nullable input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(1) && input.getDouble(1) <= 0) {
      buffer(0) = (buffer.getSeq[String](0).toSet + input.getString(0)).toSeq
    }
  }

  // Merge two partial buffers by unioning their order-id sets
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = (buffer1.getSeq[String](0).toSet ++ buffer2.getSeq[String](0).toSet).toSeq
  }

  override def evaluate(buffer: Row): Any = buffer.getSeq[String](0).length
}
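Note: UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of org.apache.spark.sql.expressions.Aggregator. A minimal sketch of the same logic on that API (Spark 3.x assumed; the object name CountDistinctDiscounted is made up for illustration):

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// IN = (orderid, yhPrice), BUF = set of distinct order ids, OUT = count
object CountDistinctDiscounted extends Aggregator[(String, Double), Set[String], Int] {
  def zero: Set[String] = Set.empty
  def reduce(b: Set[String], in: (String, Double)): Set[String] =
    if (in._2 <= 0) b + in._1 else b // keep only no-discount rows
  def merge(b1: Set[String], b2: Set[String]): Set[String] = b1 ++ b2
  def finish(b: Set[String]): Int = b.size
  def bufferEncoder: Encoder[Set[String]] = Encoders.kryo[Set[String]]
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

// Registration for SQL use would then look like:
// spark.udf.register("cCountDistinct2",
//   udaf(CountDistinctDiscounted, Encoders.product[(String, Double)]))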
Example usage:
package spark.udf

import cd.custom.jde.job.udf.CountDistinctAndIf
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object MyOrderTest {
  Logger.getRootLogger.setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {
    // Sample rows: (storeid, orderid, discount amount, dish id).
    // Order a100 appears three times under pid 300, twice with yhPrice = 0.
    val data = Seq(
      Row("a", "a100", 0.0, "300"),
      Row("a", "a100", 7.0, "300"),
      Row("a", "a101", 6.0, "300"),
      Row("a", "a101", 5.0, "301"),
      Row("a", "a100", 0.0, "300")
    )
    val schema = new StructType()
      .add("storeid", StringType)
      .add("orderid", StringType)
      .add("yhPrice", DoubleType)
      .add("pid", StringType)

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
    df.show()
    df.createOrReplaceTempView("tab_tmp")

    // Register the UDAF so it can be called from SQL
    val cCountDistinct2 = new CountDistinctAndIf
    spark.sqlContext.udf.register("cCountDistinct2", cCountDistinct2)

    spark.sql(
      """
        |select pid, count(1) as pid_num,
        |sum(if(yhPrice <= 0, 1, 0)) as zk_all_order_num,
        |cCountDistinct2(orderid, yhPrice) as zk_order_num
        |from tab_tmp group by pid
      """.stripMargin).show()
  }
  /* Expected output:
  +---+-------+----------------+------------+
  |pid|pid_num|zk_all_order_num|zk_order_num|
  +---+-------+----------------+------------+
  |300|      4|               2|           1|
  |301|      1|               0|           0|
  +---+-------+----------------+------------+
  For pid 300, zk_all_order_num counts the two yhPrice = 0 rows,
  while zk_order_num deduplicates them to the single order a100. */
}