Spark TopN问题：DataFrame与RDD两种实现的比较

Spark版本：2.0.2

Scala版本：2.11.8

操作系统：CentOS 6.7

Spark TopN问题本质上就是：按键分组、组内排序、再取每组排名前N的记录。

在shell中输入以下命令启动Spark交互环境：

spark-shell

进入spark-shell后，依次输入以下命令：

//使用dataframe解决spark TopN问题:分组、排序、取TopN
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Sample input: (Hour, Category, TotalValue) tuples, several rows per hour.
val sampleRows = Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))

// Distribute the rows as an RDD, then name the tuple fields as DataFrame columns.
val df = sc.parallelize(sampleRows).toDF("Hour", "Category", "TotalValue")

df.show()
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
+----+--------+----------+
*/
  
// One window per hour, highest TotalValue first.
// The partition column is spelled exactly as declared ("Hour") so this also works
// when spark.sql.caseSensitive is enabled.
// Category is a secondary sort key: row_number() assigns distinct sequential numbers,
// and without a tie-breaker the numbering of rows with equal TotalValue is not
// guaranteed. That would make a Top-N cut through a tie non-repeatable.
val w = Window.partitionBy($"Hour").orderBy($"TotalValue".desc, $"Category")

// Rank rows within each hour once; every Top-N below is a filter on this ranking.
// Note: row_number() was named rowNumber() in Spark 1.x.
val ranked = df.withColumn("rn", row_number().over(w))

// Top 1 per hour
val dfTop1 = ranked.where($"rn" === 1).drop("rn")
// Top 3 per hour
val dfTop3 = ranked.where($"rn" <= 3).drop("rn")

dfTop1.show()
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   1|   cat67|      28.5|
|   3|    cat8|      35.6|
|   2|   cat56|      39.6|
|   0|   cat26|      30.9|
+----+--------+----------+
*/
dfTop3.show()
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   3|    cat8|      35.6|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
+----+--------+----------+
*/


//使用RDD解决spark TopN问题:分组、排序、取TopN

// The same sample rows as a plain RDD of (Hour, Category, TotalValue) tuples.
val rdd1 = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6)))

// Re-key each row by hour and gather all (Category, TotalValue) pairs of an hour together.
val rdd2 = rdd1.map { case (hour, category, value) => (hour, (category, value)) }.groupByKey()
/*
rdd2.collect
res9: Array[(Int, Iterable[(String, Double)])] = Array((0,CompactBuffer((cat26,30.9), (cat13,22.1), (cat95,19.6), (cat105,1.3))), 
                                                       (1,CompactBuffer((cat67,28.5), (cat4,26.8), (cat13,12.6), (cat23,5.3))), 
													   (2,CompactBuffer((cat56,39.6), (cat40,29.7), (cat187,27.9), (cat68,9.8))), 
													   (3,CompactBuffer((cat8,35.6))))

*/
// Number of entries to keep per hour.
val N_value = 3

// For every hour: copy its pairs into a buffer, sort by TotalValue ascending, and keep
// the last N_value entries, i.e. the N_value largest values.
// The kept entries stay in ascending order (smallest of the top N first).
// Hours with N_value or fewer entries keep all of them.
val rdd3 = rdd2.map { case (hour, pairs) =>
  val buf = pairs.toBuffer
  val topAscending = buf.sortBy(_._2).takeRight(N_value)
  (hour, topAscending.toIterable)
}

/*
 rdd3.collect
res8: Array[(Int, Iterable[(String, Double)])] = Array((0,ArrayBuffer((cat95,19.6), (cat13,22.1), (cat26,30.9))), 
                                                       (1,ArrayBuffer((cat13,12.6), (cat4,26.8), (cat67,28.5))), 
													   (2,ArrayBuffer((cat187,27.9), (cat40,29.7), (cat56,39.6))), 
													   (3,ArrayBuffer((cat8,35.6))))
*/

// Flatten each (Hour, [(Category, TotalValue), ...]) group back into one
// (Hour, Category, TotalValue) row per kept entry.
val rdd4 = rdd3.flatMap { case (hour, kept) =>
  kept.map { case (category, value) => (hour, category, value) }
}

rdd4.collect()
/*
res3: Array[(Int, String, Double)] = Array((0,cat95,19.6), (0,cat13,22.1), (0,cat26,30.9), 
                                           (1,cat13,12.6), (1,cat4,26.8), (1,cat67,28.5), 
										   (2,cat187,27.9), (2,cat40,29.7), (2,cat56,39.6), 
										   (3,cat8,35.6))
*/

// Turn the RDD Top-N result back into a DataFrame with the same column names, then print it.
rdd4.toDF("Hour", "Category", "TotalValue").show()
/*
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat95|      19.6|
|   0|   cat13|      22.1|
|   0|   cat26|      30.9|
|   1|   cat13|      12.6|
|   1|    cat4|      26.8|
|   1|   cat67|      28.5|
|   2|  cat187|      27.9|
|   2|   cat40|      29.7|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
+----+--------+----------+
*/

  

参考资料：

http://stackoverflow.com/questions/33878370/spark-dataframe-select-the-first-row-of-each-group
《Spark MLlib机器学习》

 

转载于:https://www.cnblogs.com/Sarah-2017/p/6377822.html

你可能感兴趣的:(spark TopN问题:dataframe和RDD比较)