First, we build a class score dataset with three columns: the exam subject (category), the student's name (name), and the student's score (score):
val df = spark.createDataFrame(Seq(
("A", "Tom", 78),
("B", "James", 47),
("A", "Jim", 43),
("C", "James", 89),
("A", "Lee", 93),
("C", "Jim", 65),
("A", "James", 10),
("C", "Lee", 39),
("B", "Tom", 99),
("C", "Tom", 53),
("B", "Lee", 100),
("B", "Jim", 100)
)).toDF("category", "name", "score")
df.show(false)
Output:
+--------+-----+-----+
|category|name |score|
+--------+-----+-----+
|A       |Tom  |78   |
|B       |James|47   |
|A       |Jim  |43   |
|C       |James|89   |
|A       |Lee  |93   |
|C       |Jim  |65   |
|A       |James|10   |
|C       |Lee  |39   |
|B       |Tom  |99   |
|C       |Tom  |53   |
|B       |Lee  |100  |
|B       |Jim  |100  |
+--------+-----+-----+
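Spark SQL has supported window (analytic) functions since version 1.4. We can use the window function row_number to sort the rows within each group and then take the TopN elements from each partition. row_number operates on one partition and assigns every record in that partition a sequence number that starts at 1 and increases by 1, so the outer query can pick the rows it wants by filtering on that number.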
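To make the row_number semantics concrete, here is a minimal sketch on local Scala collections, with no Spark involved. The names `RowNumberSketch`, `Record`, `withRowNumber` and `topNPerGroup` are made up for this illustration; the sketch only mimics "partition by category, order by score descending, number from 1, then filter" and is not how Spark executes the window.

```scala
object RowNumberSketch {
  // Hypothetical record type mirroring the (category, name, score) columns.
  final case class Record(category: String, name: String, score: Int)

  // Mimics row_number() over (partition by category order by score desc):
  // pairs every record with its 1-based position inside its category.
  def withRowNumber(rows: Seq[Record]): Seq[(Record, Int)] =
    rows
      .groupBy(_.category)                          // partition by category
      .toSeq
      .flatMap { case (_, group) =>
        group
          .sortBy(_.score)(Ordering[Int].reverse)   // order by score desc (stable sort)
          .zipWithIndex
          .map { case (r, i) => (r, i + 1) }        // sequence number starts at 1
      }

  // Plays the role of the outer query: keep rows whose number is <= n.
  def topNPerGroup(rows: Seq[Record], n: Int): Seq[Record] =
    withRowNumber(rows).collect { case (r, rn) if rn <= n => r }
}
```

Calling `RowNumberSketch.topNPerGroup(rows, 3)` on the category A rows of the sample data keeps Lee, Tom and Jim and drops James, whose score is the fourth highest.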
① Take the TopN with a window function
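import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

val N = 3
// One window per category, rows ordered by score from high to low
val window = Window.partitionBy(col("category")).orderBy(col("score").desc)
// Number the rows inside each window and keep the first N
val top3DF = df.withColumn("topn", row_number().over(window)).where(col("topn") <= N)
top3DF.show(false)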
Output:
+--------+-----+-----+----+
|category|name |score|topn|
+--------+-----+-----+----+
|B       |Lee  |100  |1   |
|B       |Jim  |100  |2   |
|B       |Tom  |99   |3   |
|C       |James|89   |1   |
|C       |Jim  |65   |2   |
|C       |Tom  |53   |3   |
|A       |Lee  |93   |1   |
|A       |Tom  |78   |2   |
|A       |Jim  |43   |3   |
+--------+-----+-----+----+
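In category B, Lee and Jim both scored 100, yet row_number still gives them different numbers, and which of the two comes first is not determined by the ordering on score alone. Spark also provides the window functions `rank` and `dense_rank`, which give tied rows the same number. The sketch below models the three numbering schemes on a plain Scala list, without Spark, to show how they differ. `TieHandlingSketch` and `numberings` are hypothetical names used only here.

```scala
object TieHandlingSketch {
  // For scores already sorted in descending order, returns
  // (score, rowNumber, rank, denseRank) for every position.
  def numberings(sortedDesc: Seq[Int]): Seq[(Int, Int, Int, Int)] = {
    val distinctDesc = sortedDesc.distinct
    sortedDesc.zipWithIndex.map { case (s, i) =>
      val rowNumber = i + 1                        // always unique, ties split arbitrarily
      val rank      = sortedDesc.indexOf(s) + 1    // ties share a number, gaps follow
      val denseRank = distinctDesc.indexOf(s) + 1  // ties share a number, no gaps
      (s, rowNumber, rank, denseRank)
    }
  }
}
```

For the category B scores `Seq(100, 100, 99, 47)` this yields row numbers 1, 2, 3, 4, ranks 1, 1, 3, 4 and dense ranks 1, 1, 2, 3. Filtering on a rank or dense-rank column instead of the row number can therefore return more than N rows per category when scores tie.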
② Alternatively, use a SQL query
df.createOrReplaceTempView("grade")
// The alias rn is used so that the column does not share its name with the built-in rank() function
val sql =
  """select category, name, score from (
    |  select category, name, score, row_number() over (partition by category order by score desc) as rn from grade
    |) g where g.rn <= 3""".stripMargin
val top3DFBySQL = spark.sql(sql)
top3DFBySQL.show(false)
Getting the TopN elements with the native RDD API takes three main steps: group, sort and truncate, then flatten.
// Take the TopN with the RDD API
import spark.implicits._ // needed for rdd.toDF outside spark-shell
// step 1: group the (name, score) pairs by category
val groupRDD = df.rdd.map(x => (x.getString(0), (x.getString(1), x.getInt(2)))).groupByKey()
// step 2: sort each group by score from high to low and keep the first N
val N = 3
val sortedRDD = groupRDD.map { case (category, rows) =>
  (category, rows.toSeq.sortBy(_._2)(Ordering[Int].reverse).take(N))
}
// step 3: flatten back into (category, name, score) records
val flatRDD = sortedRDD.flatMap { case (category, rows) =>
  rows.map { case (name, score) => (category, name, score) }
}
flatRDD.toDF("category", "name", "score").show(false)
Output:
+--------+-----+-----+
|category|name |score|
+--------+-----+-----+
|B       |Lee  |100  |
|B       |Jim  |100  |
|B       |Tom  |99   |
|C       |James|89   |
|C       |Jim  |65   |
|C       |Tom  |53   |
|A       |Lee  |93   |
|A       |Tom  |78   |
|A       |Jim  |43   |
+--------+-----+-----+
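The RDD version above relies on groupByKey, which gathers every row of a category in one place before sorting. When groups are large and N is small, a common alternative is to keep at most N rows while aggregating. The following is a minimal sketch of the two merge steps such an approach needs, written against plain Scala lists so it runs without Spark. `BoundedTopNSketch`, `addOne` and `mergeTops` are hypothetical names, not part of the original article or of the Spark API.

```scala
object BoundedTopNSketch {
  type Row = (String, Int) // (name, score)

  // Adds one row to a partial result and keeps at most n rows,
  // sorted by score from high to low.
  def addOne(n: Int)(top: List[Row], row: Row): List[Row] =
    (row :: top).sortBy(_._2)(Ordering[Int].reverse).take(n)

  // Merges two partial results into one of at most n rows.
  def mergeTops(n: Int)(a: List[Row], b: List[Row]): List[Row] =
    (a ++ b).sortBy(_._2)(Ordering[Int].reverse).take(n)
}
```

Folding the category B rows through `addOne(3)` leaves the scores 100, 100 and 99. On an RDD keyed by category, the same two functions have the shapes that `aggregateByKey` expects for its per-partition and cross-partition steps, so they could be wired in roughly as `aggregateByKey(List.empty[Row])(addOne(N), mergeTops(N))`; that wiring is an untested assumption here.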
Reposted from: https://blog.csdn.net/Xiejingfa/article/details/79831938