Grouped TopN in Spark (Java): three approaches for three different scenarios

The test data below is something I made up for testing. The requirement: group the rows by the letter in the first column, then take the top N values of the second column within each group. Below are the three approaches I used to implement this; each has its own advantages in different scenarios.

 

a 25
b 36
c 24
d 45
e 60
a 33
b 26
c 47
d 43
e 62
a 13
b 16
c 42
d 66
e 31
a 19
b 75
c 61
d 71
e 80
a 85
b 90
c 54
d 48
e 62

 

Approach 1: suitable when the number of top elements needed per group is large. After grouping the data, each group is sorted separately: first collect the set of group keys, then loop over the keys and sort the values of each key's group, and finally take(num) returns the first num elements.

import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;

import scala.Tuple2;

/* Suitable when the number of top elements needed per group is large */

public class GroupTopN {
	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("GroupTopN").setMaster("local");
		JavaSparkContext sc = new JavaSparkContext(conf);
		
		JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt").mapToPair(new PairFunction<String, String, Integer>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, Integer> call(String line) throws Exception {
				return new Tuple2<>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
			}
		}).groupByKey();
		
//		System.out.println(grouppair.count());
		
		List<String> keys = grouppair.map(new Function<Tuple2<String, Iterable<Integer>>, String>() {

			private static final long serialVersionUID = 1L;

			@Override
			public String call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {

				return tuple._1;
			}
		}).collect();
		
		for (int i = 0; i < keys.size(); i++) {
			System.out.println(keys.get(i));
			
			final int key = i;
			JavaPairRDD<String, Integer> result = grouppair.filter(new Function<Tuple2<String, Iterable<Integer>>, Boolean>() {

				private static final long serialVersionUID = 1L;

				@Override
				public Boolean call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
					return tuple._1.equals(keys.get(key));
				}
			}).flatMap(new FlatMapFunction<Tuple2<String, Iterable<Integer>>, Integer>() {

				private static final long serialVersionUID = 1L;
	
				@Override
				public Iterator<Integer> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
	
					return tuple._2.iterator();
				}
			}).mapToPair(new PairFunction<Integer, Integer, String>() {
	
				private static final long serialVersionUID = 1L;
	
				@Override
				public Tuple2<Integer, String> call(Integer in) throws Exception {
		
					return new Tuple2<>(in, keys.get(key));
				}
			}).sortByKey(false).mapToPair(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
	
				private static final long serialVersionUID = 1L;
	
				@Override
				public Tuple2<String, Integer> call(Tuple2<Integer, String> tuple) throws Exception {

					return new Tuple2<>(tuple._2, tuple._1);
				}
			});
		
			
			List<Tuple2<String, Integer>> list = result.take(4);

			for (Tuple2<String, Integer> tuple2 : list) {
				System.out.println(tuple2._1 + " " + tuple2._2);
			}
//			result.foreach(new VoidFunction<Tuple2<String, Integer>>() {
//
//				private static final long serialVersionUID = 1L;
//
//				@Override
//				public void call(Tuple2<String, Integer> tuple) throws Exception {
//					System.out.println(tuple._1 + " " + tuple._2);
//
//				}
//			});
		}
		sc.close();
	}

}

Test output (excerpt):

2018-07-17 10:19:41 INFO  DAGScheduler:54 - Job 0 finished: collect at GroupTopN.java:44, took 0.374648 s
d
2018-07-17 10:19:41 INFO  SparkContext:54 - Starting job: take at

……
GroupTopN.java:88, took 0.080054 s
d 71
d 66
d 48
d 45
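
As the log shows, this approach submits a separate Spark job for every key (one collect for the key list, then one sort-and-take per group), so it only pays off when the per-group top N is large. For reference only, the per-key sort-then-take step could also be sketched with takeOrdered and a reversed comparator; this is an illustrative variant of the same idea, not the code used above, and the helper name is mine:

import java.util.Comparator;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;

public class GroupTopNHelper {
	// Hypothetical helper: the n largest values of one group, in descending order.
	public static List<Integer> topNForKey(JavaPairRDD<String, Iterable<Integer>> grouppair, String key, int n) {
		return grouppair
				.filter(t -> t._1.equals(key))            // keep only the requested group
				.flatMap(t -> t._2.iterator())            // flatten that group's values
				.takeOrdered(n, Comparator.<Integer>reverseOrder()); // n largest, descending
	}
}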

 

Approach 2: more appropriate when the amount of data in each group is small. Put each group's values into a List, sort the list with Collections.sort(List, Comparator), overriding the compare(o1, o2) method so that it sorts in descending order, and then cut the list down to the first N elements.

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;

import scala.Tuple2;

/* More appropriate when the amount of data in each group is small */

public class GroupTopN2 {
	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("GroupTopN2").setMaster("local");
		JavaSparkContext sc = new JavaSparkContext(conf);
		
		JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt").mapToPair(new PairFunction<String, String, Integer>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, Integer> call(String line) throws Exception {
				return new Tuple2<>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
			}
		}).groupByKey();
		
		System.out.println(grouppair.count());
		
		JavaPairRDD<String, Iterable<Integer>> result = grouppair.mapToPair(new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
				
				List<Integer> list = new ArrayList<>();
				Iterator<Integer> it = tuple._2.iterator();
				while (it.hasNext()) {
					int in = it.next();
					list.add(in);
				}
				Collections.sort(list, new Comparator<Integer>() {

					@Override
					public int compare(Integer o1, Integer o2) {
						return -o1.compareTo(o2); // descending order
					}
				});
				// guard against groups with fewer than 4 values, and copy the view so it serializes cleanly
				List<Integer> re = new ArrayList<>(list.subList(0, Math.min(4, list.size())));
				
				return new Tuple2<>(tuple._1, re);
			}
		});
		
		result.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {

			private static final long serialVersionUID = 1L;

			@Override
			public void call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
				System.out.println(tuple._1+" "+tuple._2);
			}
		});
		
		sc.close();
	}

}

Test output:

d [71, 66, 48, 45]
e [80, 62, 62, 60]
a [85, 33, 25, 19]
b [90, 75, 36, 26]
c [61, 54, 47, 42]
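
Taking Spark out of the picture, the core of this approach is just a descending sort followed by truncation. Here is a minimal plain-Java sketch on group a's values from the sample data (25, 33, 13, 19, 85), which prints [85, 33, 25, 19], matching the output above; Collections.reverseOrder() plays the same role as the hand-written comparator:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class SortTruncateDemo {
	public static void main(String[] args) {
		List<Integer> values = new ArrayList<>(Arrays.asList(25, 33, 13, 19, 85)); // group "a"
		int n = 4;
		Collections.sort(values, Collections.reverseOrder());   // descending order
		// guard against groups with fewer than n values
		List<Integer> top = new ArrayList<>(values.subList(0, Math.min(n, values.size())));
		System.out.println(top); // [85, 33, 25, 19]
	}
}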

 

Approach 3: based on the idea of insertion sort. It suits groups that hold a lot of data when only a small top N is required: define an array of size N that holds the current top N and insert each incoming value into its place.

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;

import scala.Tuple2;

/* Insertion-sort idea: suited to large groups when only a small top N is required */

public class GroupTopN3 {
	public static void main(String[] args) {
		SparkConf conf = new SparkConf().setAppName("GroupTopN3").setMaster("local");
		JavaSparkContext sc = new JavaSparkContext(conf);
		
		JavaPairRDD<String, Iterable<Integer>> grouppair = sc.textFile("E:/mr/grouptopn.txt").mapToPair(new PairFunction<String, String, Integer>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, Integer> call(String line) throws Exception {
				return new Tuple2<>(line.split(" ")[0], Integer.parseInt(line.split(" ")[1]));
			}
		}).groupByKey();
		
		System.out.println(grouppair.count());
		
		JavaPairRDD<String, Iterable<Integer>> result = grouppair.mapToPair(new PairFunction<Tuple2<String, Iterable<Integer>>, String, Iterable<Integer>>() {

			private static final long serialVersionUID = 1L;

			@Override
			public Tuple2<String, Iterable<Integer>> call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
				
				List<Integer> list = new ArrayList<>();
				// arrin holds the current top 4, largest first; it starts at all zeros,
				// so this assumes the incoming values are positive
				int[] arrin = new int[4];
				Iterator<Integer> it = tuple._2.iterator();
				while (it.hasNext()) {
					int in = it.next();
					for (int i = 0; i < arrin.length; i++) {
						if (in>arrin[i]) {
							for(int j =arrin.length-1;j>i;j--){
								arrin[j] = arrin[j-1];
							}
							arrin[i] = in;
							break;
						}
					}
				}
				
				for (int i = 0; i < arrin.length; i++) {
					list.add(arrin[i]);
				}
				
				return new Tuple2<>(tuple._1, list);
			}
		});
		
		result.foreach(new VoidFunction<Tuple2<String, Iterable<Integer>>>() {

			private static final long serialVersionUID = 1L;

			@Override
			public void call(Tuple2<String, Iterable<Integer>> tuple) throws Exception {
				System.out.println(tuple._1+" "+tuple._2);
			}
		});
		
		sc.close();
	}
}

Test output:

d [71, 66, 48, 45]
e [80, 62, 62, 60]
a [85, 33, 25, 19]
b [90, 75, 36, 26]
c [61, 54, 47, 42]
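
The fixed-array insertion can also be checked on its own. Below is a minimal plain-Java sketch that runs the same insertion logic over group a's values (25, 33, 13, 19, 85) and prints [85, 33, 25, 19]. It inherits the assumptions of the code above: the array is seeded with zeros, so it only behaves correctly for positive values, and a group with fewer than N values would be padded with zeros:

import java.util.Arrays;

public class ArrayTopNDemo {
	public static void main(String[] args) {
		int[] values = {25, 33, 13, 19, 85}; // group "a" from the sample data
		int[] top = new int[4];              // current top 4, largest first, seeded with 0

		for (int in : values) {
			for (int i = 0; i < top.length; i++) {
				if (in > top[i]) {
					// shift the smaller entries down one slot, then insert at position i
					for (int j = top.length - 1; j > i; j--) {
						top[j] = top[j - 1];
					}
					top[i] = in;
					break;
				}
			}
		}
		System.out.println(Arrays.toString(top)); // [85, 33, 25, 19]
	}
}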
