I made up a small test data set, shown below. The requirement is to group the records by the letter in the first column and then take the top N values of the second column within each group. Below are the three approaches I used to implement this; each has its advantages in different scenarios.
a 25 b 36 c 24 d 45 e 60 a 33 b 26 c 47 d 43 e 62 a 13 b 16 c 42 d 66 e 31 a 19 b 75 c 61 d 71 e 80 a 85 b 90 c 54 d 48 e 62
(in the actual input file each "letter value" pair sits on its own line, which is why the code reads the key with line.split(" ")[0] and the value with line.split(" ")[1])
Approach 1: suitable when the number of top elements required per group is large. The data is grouped first and every group is then sorted on its own: collect the set of group keys to the driver, loop over the keys, and for each key filter out that group and sort its values in descending order; take(num) then returns the top num values of the group. Note that each key triggers its own Spark job (one take per key).
import java.util.Iterator;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
/* Suitable when the number of top elements required per group is large. */
public class GroupTopN {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GroupTopN").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Parse every "letter value" line into a (letter, value) pair and group by the letter.
        JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt")
                .mapToPair(new PairFunction<String, String, Integer>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Integer> call(String line) throws Exception {
                        return new Tuple2<String, Integer>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
                    }
                }).groupByKey();
        // System.out.println(linepair.count());
        // Collect the distinct group keys to the driver.
        List<String> keys = grouppair.map(new Function<Tuple2<String, Iterable<Integer>>, String>() {
            private static final long serialVersionUID = 1L;
            @Override
            public String call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                return tuple._1;
            }
        }).collect();
        for (int i = 0; i < keys.size(); i++) {
            System.out.println(keys.get(i));
            final int key = i;
            // Keep only the current group, flatten its values, sort them in descending order,
            // then swap back to (group, value) pairs.
            JavaPairRDD<String, Integer> result = grouppair
                    .filter(new Function<Tuple2<String, Iterable<Integer>>, Boolean>() {
                        private static final long serialVersionUID = 1L;
                        @Override
                        public Boolean call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                            return tuple._1.equals(keys.get(key));
                        }
                    }).flatMap(new FlatMapFunction<Tuple2<String, Iterable<Integer>>, Integer>() {
                        private static final long serialVersionUID = 1L;
                        @Override
                        public Iterator<Integer> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                            return tuple._2.iterator();
                        }
                    }).mapToPair(new PairFunction<Integer, Integer, String>() {
                        private static final long serialVersionUID = 1L;
                        @Override
                        public Tuple2<Integer, String> call(Integer in) throws Exception {
                            return new Tuple2<Integer, String>(in, keys.get(key));
                        }
                    }).sortByKey(false).mapToPair(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
                        private static final long serialVersionUID = 1L;
                        @Override
                        public Tuple2<String, Integer> call(Tuple2<Integer, String> tuple) throws Exception {
                            return new Tuple2<String, Integer>(tuple._2, tuple._1);
                        }
                    });
            // take(4) on the descending-sorted RDD returns the top 4 values of this group.
            List<Tuple2<String, Integer>> list = result.take(4);
            for (Tuple2<String, Integer> tuple2 : list) {
                System.out.println(tuple2._1 + " " + tuple2._2);
            }
            // result.foreach(new VoidFunction<Tuple2<String, Integer>>() {
            //     private static final long serialVersionUID = 1L;
            //     @Override
            //     public void call(Tuple2<String, Integer> tuple) throws Exception {
            //         System.out.println(tuple._1 + " " + tuple._2);
            //     }
            // });
        }
        sc.close();
    }
}
Test output:
2018-07-17 10:19:41 INFO DAGScheduler:54 - Job 0 finished: collect at GroupTopN.java:44, took 0.374648 s ……
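For comparison, the same filter, sort and take logic reads more compactly with Java 8 lambdas. The following is only a sketch under the same assumptions as above (an existing JavaSparkContext sc, the local file E:/mr/grouptopn.txt, top 4 per group); it is not the code that produced the log line above.

// Sketch (Java 8 lambdas): for each key, filter that group, sort its values descending, take the first 4.
JavaPairRDD<String, Iterable<Integer>> grouped = sc.textFile("E:/mr/grouptopn.txt")
        .mapToPair(line -> new Tuple2<String, Integer>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1])))
        .groupByKey();
List<String> keys = grouped.keys().collect();                   // keys are already unique after groupByKey
for (String k : keys) {
    List<Integer> top = grouped.filter(t -> t._1.equals(k))     // keep only the current group
            .flatMap(t -> t._2.iterator())                      // flatten its values
            .mapToPair(v -> new Tuple2<Integer, String>(v, k))  // value first, so sortByKey orders by value
            .sortByKey(false)                                   // descending
            .keys()
            .take(4);                                           // top 4 of this group
    System.out.println(k + " " + top);
}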
Approach 2: more appropriate when the amount of data inside each group is small. All of a group's values are copied into a list, Collections.sort(List, Comparator) sorts the list with a descending comparator, and subList(0, N) keeps the first N values.
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
/* Appropriate when the amount of data inside each group is small. */
public class GroupTopN2 {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GroupTopN2").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Parse every "letter value" line into a (letter, value) pair and group by the letter.
        JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt")
                .mapToPair(new PairFunction<String, String, Integer>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Integer> call(String line) throws Exception {
                        return new Tuple2<String, Integer>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
                    }
                }).groupByKey();
        System.out.println(grouppair.count());
        // For every group, copy its values into a list, sort in descending order, keep the first 4.
        JavaPairRDD<String, Iterable<Integer>> result = grouppair
                .mapToPair(new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                        List<Integer> list = new ArrayList<>();
                        Iterator<Integer> it = tuple._2.iterator();
                        while (it.hasNext()) {
                            int in = it.next();
                            list.add(in);
                        }
                        // Reversed comparator sorts the list in descending order.
                        Collections.sort(list, new Comparator<Integer>() {
                            @Override
                            public int compare(Integer o1, Integer o2) {
                                return -o1.compareTo(o2);
                            }
                        });
                        // Keep the top 4; subList would throw if a group held fewer than 4 values.
                        List<Integer> re = list.subList(0, 4);
                        return new Tuple2<String, Iterable<Integer>>(tuple._1, re);
                    }
                });
        result.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {
            private static final long serialVersionUID = 1L;
            @Override
            public void call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                System.out.println(tuple._1 + " " + tuple._2);
            }
        });
        sc.close();
    }
}
Test output:
d [71, 66, 48, 45]
e [80, 62, 62, 60]
a [85, 33, 25, 19]
b [90, 75, 36, 26]
c [61, 54, 47, 42]
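As a variation on the same idea, the in-group sort can also be written with mapValues plus a bound-checked subList, so that a group holding fewer than 4 values does not throw an IndexOutOfBoundsException. This is only a sketch, assuming grouppair is the JavaPairRDD<String, Iterable<Integer>> produced by groupByKey above; it is not the code that produced the output shown.

// Sketch: sort each group's values in memory and keep at most the first 4 (N assumed to be 4).
JavaPairRDD<String, Iterable<Integer>> topPerGroup = grouppair.mapValues(values -> {
    List<Integer> list = new ArrayList<>();
    for (Integer v : values) {
        list.add(v);
    }
    list.sort(Comparator.reverseOrder());                                     // descending order
    return new ArrayList<Integer>(list.subList(0, Math.min(4, list.size()))); // guard against small groups
});
topPerGroup.foreach(t -> System.out.println(t._1 + " " + t._2));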
Approach 3: uses the idea of insertion sort and suits the case where groups are large but the requested top N is small. A fixed array of size N holds the current top N values: each incoming value is inserted at its ordered position, the smaller entries shift right, and the smallest one drops off the end.
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import scala.Tuple2;
/* Uses the idea of insertion sort; suited to large groups when the requested top N is small. */
public class GroupTopN3 {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("GroupTopN3").setMaster("local");
        JavaSparkContext sc = new JavaSparkContext(conf);
        // Parse every "letter value" line into a (letter, value) pair and group by the letter.
        JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt")
                .mapToPair(new PairFunction<String, String, Integer>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Integer> call(String line) throws Exception {
                        return new Tuple2<String, Integer>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
                    }
                }).groupByKey();
        System.out.println(grouppair.count());
        // For every group, maintain a fixed-size array holding the 4 largest values seen so far.
        JavaPairRDD<String, Iterable<Integer>> result = grouppair
                .mapToPair(new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {
                    private static final long serialVersionUID = 1L;
                    @Override
                    public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                        List<Integer> list = new ArrayList<>();
                        int[] arrin = new int[4]; // current top 4, kept in descending order
                        Iterator<Integer> it = tuple._2.iterator();
                        while (it.hasNext()) {
                            int in = it.next();
                            // Find the insertion slot, shift the smaller entries one place right
                            // (the smallest falls off the end), then insert the new value.
                            for (int i = 0; i < arrin.length; i++) {
                                if (in > arrin[i]) {
                                    for (int j = arrin.length - 1; j > i; j--) {
                                        arrin[j] = arrin[j - 1];
                                    }
                                    arrin[i] = in;
                                    break;
                                }
                            }
                        }
                        for (int i = 0; i < arrin.length; i++) {
                            list.add(arrin[i]);
                        }
                        return new Tuple2<String, Iterable<Integer>>(tuple._1, list);
                    }
                });
        result.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {
            private static final long serialVersionUID = 1L;
            @Override
            public void call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
                System.out.println(tuple._1 + " " + tuple._2);
            }
        });
        sc.close();
    }
}
Test output:
d [71, 66, 48, 45]
e [80, 62, 62, 60]
a [85, 33, 25, 19]
b [90, 75, 36, 26]
c [61, 54, 47, 42]
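Instead of the hand-rolled fixed-size array, a bounded min-heap (java.util.PriorityQueue) keeps the top N of each group with the same small memory footprint, and it also handles negative values and groups holding fewer than 4 elements, which the zero-initialized int array does not. The sketch below swaps in that data structure; it assumes grouppair from groupByKey above, N = 4, and java.util.PriorityQueue plus java.util.Comparator in addition to the imports listed earlier.

// Sketch: keep the 4 largest values of each group in a bounded min-heap.
JavaPairRDD<String, Iterable<Integer>> heapTop = grouppair.mapValues(values -> {
    PriorityQueue<Integer> heap = new PriorityQueue<>(); // min-heap: the head is the smallest value kept
    for (Integer v : values) {
        heap.offer(v);
        if (heap.size() > 4) {
            heap.poll();                                 // evict the smallest once more than 4 are kept
        }
    }
    List<Integer> top = new ArrayList<>(heap);
    top.sort(Comparator.reverseOrder());                 // report in descending order
    return top;
});
heapTop.foreach(t -> System.out.println(t._1 + " " + t._2));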