看到了 Fork/Join 框架,就想弄明白:既然已经可以直接使用多线程,为什么还需要 Fork/Join?下面用几种不同的方式分别计算 1 到 100000000 的和,并比较它们的耗时。
准备工作(以下所有测试共用同一个静态数组 array):
static long[] array;

/**
 * Builds the shared benchmark input once, when the class is initialized:
 * a long array holding 1, 2, ..., 100000000 in ascending order.
 * Prints how long the construction took, in milliseconds.
 */
static {
    StopWatch initTimer = new StopWatch();
    initTimer.start();
    array = LongStream.rangeClosed(1L, 100_000_000L).toArray();
    initTimer.stop();
    long elapsedMillis = initTimer.getTotalTimeMillis();
    System.out.println("初始化array耗时:" + elapsedMillis + "ms");
}
/**
 * Sums every element of {@code array} with a plain sequential index loop,
 * then prints the elapsed time in milliseconds followed by the sum.
 */
public static void simpleFor() {
    StopWatch timer = new StopWatch();
    timer.start();
    long total = 0L;
    for (int idx = 0; idx < array.length; idx++) {
        total += array[idx];
    }
    timer.stop();
    System.out.println("耗时:" + timer.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + total);
}
注意:Java 8 的并行流(parallel stream)默认都运行在同一个公共线程池 ForkJoinPool.commonPool() 中,多个并行流同时执行时会相互争用这个线程池。
/**
 * Sums {@code array} with a parallel {@code LongStream}, then prints the elapsed
 * time in milliseconds followed by the sum.
 *
 * <p>By default, parallel streams execute on the shared {@code ForkJoinPool.commonPool()},
 * so parallel streams running at the same time compete for the same worker threads.
 */
public static void java8ParallelFor() {
    StopWatch watch = new StopWatch();
    watch.start();
    // LongStream.sum() is the primitive specialisation of reduce(0, Long::sum).
    long sum = Arrays.stream(array).parallel().sum();
    watch.stop();
    System.out.println("耗时:" + watch.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + sum);
}
/**
 * Sums {@code array} with a sequential (non-parallel) {@code LongStream}, then prints
 * the elapsed time in milliseconds followed by the sum.
 */
public static void java8For() {
    StopWatch watch = new StopWatch();
    watch.start();
    // LongStream.sum() is the primitive specialisation of reduce(0, Long::sum).
    long sum = Arrays.stream(array).sum();
    watch.stop();
    System.out.println("耗时:" + watch.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + sum);
}
/**
 * Sums {@code array} with {@link CompletableFuture}s and prints the elapsed time in
 * milliseconds followed by the sum.
 *
 * <p>The array is split into one contiguous chunk per available processor. Each chunk is
 * summed asynchronously via {@code CompletableFuture.supplyAsync} on a fixed-size thread
 * pool. The partial sums are then added together.
 *
 * <p>Collecting the results still blocks the calling thread: each partial sum is awaited
 * with {@code get()} in submission order.
 *
 * <p>A chunk that fails has its stack trace printed and contributes nothing to the sum.
 * If the calling thread is interrupted while waiting, its interrupt status is restored and
 * waiting stops, so the printed sum only covers the chunks collected so far. The thread
 * pool is shut down even if submitting a task throws.
 */
public static void threadExecuteForCompletableFuture() {
    // One worker thread per CPU; each worker gets exactly one chunk.
    int parallel = Runtime.getRuntime().availableProcessors();
    List<CompletableFuture<Long>> list = new ArrayList<>(parallel);
    ExecutorService executorService = Executors.newFixedThreadPool(parallel);
    StopWatch watch = new StopWatch();
    long result = 0;
    try {
        watch.start();
        int part = array.length / parallel;
        int remainder = array.length % parallel;
        for (int i = 0; i < parallel; i++) {
            int from = part * i;
            // The last chunk also absorbs the elements left over by the integer division.
            int to = (i == parallel - 1) ? part * (i + 1) + remainder : part * (i + 1);
            list.add(CompletableFuture.supplyAsync(() -> {
                long sum = 0L;
                for (int k = from; k < to; k++) {
                    sum += array[k];
                }
                return sum;
            }, executorService));
        }
        for (Future<Long> longFuture : list) {
            try {
                result += longFuture.get();
            } catch (InterruptedException e) {
                // Re-assert the interrupt so the caller can still observe it, then stop waiting.
                Thread.currentThread().interrupt();
                e.printStackTrace();
                break;
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        }
        watch.stop();
    } finally {
        // Always release the pool's (non-daemon) worker threads.
        executorService.shutdown();
    }
    System.out.println("耗时:" + watch.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + result);
}
/**
 * Sums {@code array} with plain {@link Future}s and prints the elapsed time in
 * milliseconds followed by the sum.
 *
 * <p>The array is split into one contiguous chunk per available processor. Each chunk is
 * summed by a {@link SumTaskThread} on a fixed-size thread pool. The partial sums are then
 * added together in submission order.
 *
 * <p>A chunk that fails has its stack trace printed and contributes nothing to the sum.
 * If the calling thread is interrupted while waiting, its interrupt status is restored and
 * waiting stops, so the printed sum only covers the chunks collected so far. The thread
 * pool is shut down even if submitting a task throws.
 */
public static void threadExecuteFor() {
    int parallel = Runtime.getRuntime().availableProcessors();
    List<Future<Long>> list = new ArrayList<>(parallel);
    ExecutorService executorService = Executors.newFixedThreadPool(parallel);
    StopWatch watch = new StopWatch();
    long result = 0;
    try {
        watch.start();
        int part = array.length / parallel;
        int remainder = array.length % parallel;
        for (int i = 0; i < parallel; i++) {
            int from = part * i;
            // The last chunk also absorbs the elements left over by the integer division.
            int to = (i == parallel - 1) ? part * (i + 1) + remainder : part * (i + 1);
            list.add(executorService.submit(new SumTaskThread(array, from, to)));
        }
        for (Future<Long> longFuture : list) {
            try {
                result += longFuture.get();
            } catch (InterruptedException e) {
                // Re-assert the interrupt so the caller can still observe it, then stop waiting.
                Thread.currentThread().interrupt();
                e.printStackTrace();
                break;
            } catch (ExecutionException e) {
                e.printStackTrace();
            }
        }
        watch.stop();
    } finally {
        // Always release the pool's (non-daemon) worker threads.
        executorService.shutdown();
    }
    System.out.println("耗时:" + watch.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + result);
}
/**
 * A {@link Callable} that sums the elements of an array over the half-open index
 * range [start, end). It is the per-chunk unit of work submitted by threadExecuteFor.
 */
static class SumTaskThread implements Callable<Long> {
    private final long[] numberArray;
    private final int start;
    private final int end;

    /**
     * @param numberArray array to read from (not copied)
     * @param start       first index included in the sum
     * @param end         index one past the last index included in the sum
     */
    public SumTaskThread(long[] numberArray, int start, int end) {
        this.numberArray = numberArray;
        this.start = start;
        this.end = end;
    }

    /**
     * Returns the sum of {@code numberArray[start]} through {@code numberArray[end - 1]};
     * returns 0 when {@code start >= end}.
     */
    @Override
    public Long call() throws Exception {
        long total = 0L;
        int idx = start;
        while (idx < end) {
            total += numberArray[idx++];
        }
        return total;
    }
}
/**
 * Sums {@code array} with a {@link MyForkJoin} task and prints the elapsed time in
 * milliseconds followed by the sum.
 *
 * <p>The task runs on a dedicated ForkJoinPool whose parallelism equals the number of
 * available processors. The pool is shut down even if the task throws.
 */
public static void forkJoinFor() {
    ForkJoinPool forkJoinPool = new ForkJoinPool(Runtime.getRuntime().availableProcessors());
    StopWatch watch = new StopWatch();
    long result;
    try {
        watch.start();
        result = forkJoinPool.invoke(new MyForkJoin(array));
        watch.stop();
    } finally {
        forkJoinPool.shutdown();
    }
    System.out.println("耗时:" + watch.getTotalTimeMillis() + "ms");
    System.out.println("sum = " + result);
}
/**
* forkjoin 具体操作
*/
static class MyForkJoin extends RecursiveTask<Long> {
@Setter
/***任务切分临界点*/
private long MAX = 10000000L;
private long[] numbers;
private int from;
private int to;
public MyForkJoin(long[] numbers) {
this.numbers = numbers;
}
public MyForkJoin(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
public MyForkJoin(long MAX, long[] numbers, int from, int to) {
this.MAX = MAX;
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
//入参校验
if (ArrayUtil.isEmpty(numbers)) {
return 0L;
}
int length = numbers.length;
//是否拆分任务
if (notSplitTask(length)) {
//直接使用for循环计算
return completeSum();
}
//开始拆分任务
int mid= (from + to)/2;
MyForkJoin left = new MyForkJoin(array, from, mid);
MyForkJoin right = new MyForkJoin(array, mid + 1, to);
//并发执行
invokeAll(left, right);
return left.join() + right.join();
}
private long completeSum() {
long sum = 0;
for (long number : numbers) {
sum += number;
}
return sum;
}
//如果不满足这个条件永远拆分下去
private boolean notSplitTask(int length) {
return (to - from) <= MAX? true: false;
}
}
运行结果:
初始化array耗时:423ms
普通for循环 ------------------------------------------
耗时:68ms
sum = 5000000050000000
java 8 parallel循环 ------------------------------
耗时:31ms
sum = 5000000050000000
java 8 循环 ---------------------------------------
耗时:50ms
sum = 5000000050000000
java 8 CompletableFuture 多线程---------------
耗时:31ms
sum = 5000000050000000
threadExecute 多线程--------------------------------
耗时:34ms
sum = 5000000050000000
forkjoin -------------------------------------
耗时:58ms
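sum = 5000000050000000
(注:测得这组数据时,MyForkJoin 的单参数构造器没有设置 to(默认为 0),compute() 因为 to - from = 0 不超过临界点而判断不需要拆分,直接在一个线程里对整个数组求和。所以 forkjoin 这一项的耗时接近普通 for 循环,并不能反映 Fork/Join 拆分任务并行计算的真实效果。)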