spark版本:spark 2.0.2
scala版本:2.11.8
服务器版本:CentOS 6.7
spark TopN问题,其实就是分组、排序、组内取值问题。
在shell下输入
spark-shell
进入spark后输入以下命令:
// Solve the Spark TopN problem (group, sort within each group, take the top N)
// -- first with the DataFrame API, then with the plain RDD API.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val df = sc.parallelize(Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8),  (1, "cat13", 12.6), (1, "cat23", 5.3),
  (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
  (3, "cat8", 35.6))).toDF("Hour", "Category", "TotalValue")
df.show
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
+----+--------+----------+
*/

// One partition per hour, ordered by TotalValue descending, so row_number() == 1
// is the largest value in the group.
// FIX: use $"Hour" to match the column's declared case. The original $"hour"
// only resolved because spark.sql.caseSensitive defaults to false; it breaks
// when case-sensitive resolution is enabled.
val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

// Top 1 per group.
// NOTE: row_number() was named rowNumber() in Spark 1.x.
val dfTop1 = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

// Top 3 per group.
val dfTop3 = df.withColumn("rn", row_number.over(w)).where($"rn" <= 3).drop("rn")

dfTop1.show
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   1|   cat67|      28.5|
|   3|    cat8|      35.6|
|   2|   cat56|      39.6|
|   0|   cat26|      30.9|
+----+--------+----------+
*/
dfTop3.show
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   3|    cat8|      35.6|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
+----+--------+----------+
*/

// Solve the same TopN problem with the RDD API: group by key, sort each
// group's values, keep the N largest.
val rdd1 = sc.parallelize(Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8),  (1, "cat13", 12.6), (1, "cat23", 5.3),
  (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
  (3, "cat8", 35.6)))

// Key by hour; each value is a (category, total) pair.
// NOTE(review): groupByKey materializes every group in memory; for large data
// prefer aggregateByKey with a bounded per-key top-N accumulator.
val rdd2 = rdd1.map(x => (x._1, (x._2, x._3))).groupByKey()
/*
rdd2.collect
res9: Array[(Int, Iterable[(String, Double)])] =
Array((0,CompactBuffer((cat26,30.9), (cat13,22.1), (cat95,19.6), (cat105,1.3))), (1,CompactBuffer((cat67,28.5), (cat4,26.8), (cat13,12.6), (cat23,5.3))), (2,CompactBuffer((cat56,39.6), (cat40,29.7), (cat187,27.9), (cat68,9.8))), (3,CompactBuffer((cat8,35.6))))
*/

val N_value = 3

// For each group: sort ascending by value and keep the last (largest) N_value
// entries. FIX: takeRight replaces the original in-place Buffer.remove(0, k)
// mutation -- same elements, same ascending order, no mutable state.
val rdd3 = rdd2.map { case (hour, vals) =>
  (hour, vals.toBuffer.sortBy((p: (String, Double)) => p._2).takeRight(N_value).toIterable)
}
/*
rdd3.collect
res8: Array[(Int, Iterable[(String, Double)])] = Array((0,ArrayBuffer((cat95,19.6), (cat13,22.1), (cat26,30.9))), (1,ArrayBuffer((cat13,12.6), (cat4,26.8), (cat67,28.5))), (2,ArrayBuffer((cat187,27.9), (cat40,29.7), (cat56,39.6))), (3,ArrayBuffer((cat8,35.6))))
*/

// Flatten back to (Hour, Category, TotalValue) triples.
// FIX: loop variable renamed -- the original `for (w <- y)` shadowed the
// window `w` defined above in the same shell session.
val rdd4 = rdd3.flatMap { case (hour, vals) =>
  vals.map { case (cat, tv) => (hour, cat, tv) }
}
rdd4.collect
/*
res3: Array[(Int, String, Double)] = Array((0,cat95,19.6), (0,cat13,22.1), (0,cat26,30.9), (1,cat13,12.6), (1,cat4,26.8), (1,cat67,28.5), (2,cat187,27.9), (2,cat40,29.7), (2,cat56,39.6), (3,cat8,35.6))
*/
rdd4.toDF("Hour", "Category", "TotalValue").show
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat95|      19.6|
|   0|   cat13|      22.1|
|   0|   cat26|      30.9|
|   1|   cat13|      12.6|
|   1|    cat4|      26.8|
|   1|   cat67|      28.5|
|   2|  cat187|      27.9|
|   2|   cat40|      29.7|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
+----+--------+----------+
*/
参考资料:
http://stackoverflow.com/questions/33878370/spark-dataframe-select-the-first-row-of-each-group
《Spark MLlib机器学习》