如何针对一个key-value的RDD按照key进行分组然后对value进行排序。典型的应用场景如我有一个淘宝交易流水的文件,第一列是店铺名称,第二列是交易货物,第三列是交易价格。我想知道每个店铺交易价格最高的十个货物是什么。
常见的做法可能是
spark.sparkContext.textFile("path").map(line=>{
val lineArr=line.split("\t")
(lineArr(0),(lineArr(1),lineArr(2).toDouble))
}).groupByKey().map(kv=>{
val (store,iter)=kv
val top10: List[(String, Double)] = iter.toList.sortBy(e=> (-1) * e._2).take(10)
(store,top10)
})
这样的做法针对数据量小的时候我们是可以处理的,但是一旦数据量上来以后,在将迭代器转成内存中的List并排序的时候很有可能会发生内存溢出。有没有更好的方案呢?
那当然是有的!就是利用spark的二次排序功能,就是将排序的过程放到shuffle去做。因为spark在shuffle的过程中做了很多优化,比如内存不够数据将回落到磁盘上等(具体可以参考spark SortMergeShuffle),所以发生内存溢出的风险将大大降低。那具体来说怎么做的?
首先呢,我们需要介绍一下repartitionAndSortWithinPartitions这个算子,该算子的使用需要提供一个partitioner参数。顾名思义,该算子就是按照用户提供的partiioner将rdd重新分区,并且分区内的数据是有序的,这个顺序也可以让用户来指定。然后我们将rdd变形一下,形成的数据是RDD[((商铺,价格),货物)],这个时候rdd的key是一个Tuple2。如果我们按照商铺进行分区,并且按照商铺+价格进行排序的话,神奇的事情发生了。
首先商铺一下的数据肯定回落在一个分区内
又因为我们按照了商铺+价格进行排序了,所以商铺相同的数据肯定会紧挨着,然后按照价格进行排序!
这样以来,如果我们想实现上述的需求,就可以遍历一个partition的迭代器,如果碰见了一个商铺,就取前十个数据即可。是不是很简单啦。
好啦废话不多数我们直接看代码!
进行分组排序所需要的四个步骤
- 自定义分区类将相同key的record分到同一个记录里面
- 定义一个隐式ordering。该ordering将会对key和value进行排序
- 使用repartitionAndSortWithinPartitions算子
- 对上一步生成的rdd进行mapPartition操作。将相同key的record放到一个Iterator中
下面代码展示了如何对 (r,AAA(r))
进行分组排序的例子。
case class AAA(num:Long)
object SecondSortTest{
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setAppName("user push task").setMaster("local[3]")
val sc = new SparkContext(conf)
val sourceRdd=sc.range(0,100).map(r=>{
((r,AAA(r)),1)
})
implicit val ordering:Ordering[AAA]=Ordering.by[AAA,Long](s=>s.num)
groupByKeyAndsortBySecondaryKey(sourceRdd,4).collect().foreach(println)
}
def groupByKeyAndsortBySecondaryKey[K: Ordering : ClassTag,
S: Ordering : ClassTag,
V: ClassTag](pairRDD: RDD[((K, S), V)], partitions: Int) = {
val colValuePartitioner = new PrimarykeyParititioner[Double, Int](partitions)
implicit val ordering: Ordering[(K, S)] = Ordering.Tuple2
pairRDD.repartitionAndSortWithinPartitions(colValuePartitioner).mapPartitions(iter => {
groupSorted(iter).toIterator
}).map { case (key, buf) => (key, buf) }
}
def groupSorted[K, S, V](it: Iterator[((K, S), V)]) = {
val res = List[(K, ArrayBuffer[(S, V)])]()
it.foldLeft(res)((list, next) => list match {
case Nil =>
val ((firstKey, secondKey), value) = next
List((firstKey, ArrayBuffer((secondKey, value))))
case head :: rest =>
val (curKey, valueBuf) = head
val ((firstKey, secondKey), value) = next
if (!firstKey.equals(curKey)) {
(firstKey, ArrayBuffer((secondKey, value))) :: list
} else {
valueBuf.append((secondKey, value))
list
}
})
}
}
结束语:试想一下,在spark-sql和hive中都有row_number函数,那他们是怎么用rdd来实现的呢?感兴趣的读者可以期待下一篇的内容