它同ThreadPoolExecutor一样,也实现了Executor和ExecutorService接口。它使用了一个无限队列来保存需要执行的任务,而线程的数量则是通过构造函数传入,如果没有向构造函数中传入希望的线程数量,那么当前计算机可用的CPU数量会被设置为线程数量作为默认值。
ForkJoinPool主要用来使用分治法(Divide-and-Conquer Algorithm)来解决问题。典型的应用比如快速排序算法。
这里的要点在于,ForkJoinPool需要使用相对少的线程(默认系统自带cpu核数)来处理大量的任务。
比如要对1000万个数据进行排序,那么会将这个任务分割成两个500万的排序任务和一个针对这两组500万数据的合并任务。以此类推,对于500万的数据也会做出同样的分割处理,到最后会设置一个阈值来规定当数据规模到多少时,停止这样的分割处理。比如,当元素的数量小于10时,会停止分割,转而使用插入排序对它们进行排序。
那么到最后,所有的任务加起来会有大概2000000+个。问题的关键在于,对于一个任务而言,只有当它所有的子任务完成之后,它才能够被执行。
所以当使用ThreadPoolExecutor时,使用分治法会存在问题,因为ThreadPoolExecutor中的线程无法像任务队列中再添加一个任务并且在等待该任务完成之后再继续执行。而使用ForkJoinPool时,就能够让其中的线程创建新的任务,并挂起当前的任务,此时线程就能够从队列中选择子任务执行。
以上程序的关键是fork()和join()方法。在ForkJoinPool使用的线程中,会使用一个内部队列来对需要执行的任务以及子任务进行操作来保证它们的执行顺序。
那么使用ThreadPoolExecutor或者ForkJoinPool,会有什么性能的差异呢?
首先,使用ForkJoinPool能够使用数量有限的线程来完成非常多的具有父子关系的任务,比如使用4个线程来完成超过200万个任务。但是,使用ThreadPoolExecutor时,是不可能完成的,因为ThreadPoolExecutor中的Thread无法选择优先执行子任务,需要完成200万个具有父子关系的任务时,也需要200万个线程,显然这是不可行的。
ps:ForkJoinPool在执行过程中,会创建大量的子任务,导致GC进行垃圾回收,这些是需要注意的。
具体Demo如下:
1.ForkJoinPoolAction
package chap07.ForkJoinDemo.ForkJoin;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
public class ForkJoinPoolAction {
//ForkJoinPool的优势在于,可以充分利用多cpu,多核cpu的优势,
// 把一个任务拆分成多个“小任务”,把多个“小任务”放到多个处理器核心上并行执行;
// 当多个“小任务”执行完成之后,再将这些执行结果合并起来即可
//创建了ForkJoinPool实例之后,就可以调用ForkJoinPool的submit(ForkJoinTask task)
// 或invoke(ForkJoinTask task)方法来执行指定任务了
//其中ForkJoinTask代表一个可以并行、合并的任务。ForkJoinTask是一个抽象类,
// 它还有两个抽象子类:RecusiveAction和RecusiveTask。
// 其中RecusiveTask代表有返回值的任务,而RecusiveAction代表没有返回值的任务。
public static void main(String[] args) throws Exception {
//需求:简单打印1-300的数字,程序将一个大任务拆分成多个小任务,并将任务交给ForkJoinPool来执行
PrintTask task = new PrintTask(0, 3000);
//创建线程池
ForkJoinPool pool = new ForkJoinPool();
//将task提交至线程池
pool.submit(task);
//线程阻塞,等待所有任务完成
pool.awaitTermination(2, TimeUnit.SECONDS);
pool.shutdown();
}
}
package chap07.ForkJoinDemo.ForkJoin;
import java.util.concurrent.RecursiveAction;
/**
* 继承RecursiveAction来实现可分解的任务,无返回结果
*/
public class PrintTask extends RecursiveAction {
private static final int THRESHLOD = 50;//最多只能打印50个数
private int start;
private int end;
@Override
protected void compute() {
if (end - start < THRESHLOD) {
for (int i = start; i < end; i++) {
System.out.println(Thread.currentThread().getName() + "的i值:" + i);
}
}else{
//递归切分
int mid=(start+end)/2;
PrintTask leftTask = new PrintTask(start, mid);
PrintTask rightTask = new PrintTask(mid, end);
//并行执行两个任务
leftTask.fork();
rightTask.fork();
}
}
public PrintTask(int start, int end) {
this.start = start;
this.end = end;
}
public int getStart() {
return start;
}
}
输出结果:
ForkJoinPool-1-worker-12的i值:2385
ForkJoinPool-1-worker-12的i值:2386
ForkJoinPool-1-worker-12的i值:2387
ForkJoinPool-1-worker-12的i值:2388
ForkJoinPool-1-worker-12的i值:2389
....
2.ForkJoinPoolTask
package chap07.ForkJoinDemo.ForkJoin;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
/**
* RescursiveTask存在返回值
*/
public class ForkJoinPoolTask {
public static void main(String[] args) throws Exception {
//需求:对于长度为100的元素进行累加
int[] nums = new int[100];
Random random = new Random();
int total=0;
//初始化100个数组元素
long start = System.nanoTime();
for (int i = 0; i < 100; i++) {
int temp = random.nextInt(20);
nums[i]=temp;
total += nums[i];
}
long end = System.nanoTime();
System.out.println("用时:"+(end-start)+"mesecs,初始化数组总和:" + total);
long startTask=System.nanoTime();
//创建Taks
SumTask task = new SumTask(nums,0,nums.length);
//创建线程池
ForkJoinPool pool = new ForkJoinPool();
//提交任务,存在返回值
ForkJoinTask future = pool.submit(task);
//显示结果
long endTask = System.nanoTime();
System.out.println("用时:"+(endTask-startTask)+"mesecs,多线程执行结果:" + future.get());
//关闭多线程
pool.shutdown();
}
}
package chap07.ForkJoinDemo.ForkJoin;
import java.util.concurrent.RecursiveTask;
public class SumTask extends RecursiveTask {
private static final int THRESHOLD=20;//每个小任务,最多只累加20个数
private int nums[];
private int start;
private int end;
public SumTask(int[] nums, int start, int end) {
super();
this.nums = nums;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
int sum=0;
if (end - start < THRESHOLD) {
for (int i = start; i < end; i++) {
sum += nums[i];
}
return sum;
}else{
//当分块超过阈值时,则需要对数据进行拆分
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(nums, start, mid);
SumTask rightTask = new SumTask(nums, mid, end);
//并行执行两个小任务
leftTask.fork();
rightTask.fork();
//把两个小任务累加合并
return leftTask.join() + rightTask.join();
}
}
}
输出结果:
用时:43392mesecs,初始化数组总和:863
用时:2823759mesecs,多线程执行结果:863
package chap07.ForkJoinDemo;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.LongStream;
//ForkJoinTask是一个抽象类,它还有两个抽象子类:
// RecusiveAction和RecusiveTask。其中RecusiveTask代表有返回值的任务,
// 而RecusiveAction代表没有返回值的任务
public class ForkJoinSumCalculator extends java.util.concurrent.RecursiveTask {
private final long[] numbers;
private final int start;
private final int end;
//切分阈值
public static final long THRESHOLD=10_000;
public ForkJoinSumCalculator(long[] numbers) {
//调用多参构造
this(numbers, 0, numbers.length);
}
public ForkJoinSumCalculator(long[] numbers,int start,int end) {
this.numbers = numbers;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
int length=end-start;
//如果小于阈值则顺序计算结果
if (length < THRESHOLD) {
return computeSequentiall();
}
//如果不小于,则采用类似分治递归分块求解
ForkJoinSumCalculator leftTask = new ForkJoinSumCalculator(numbers, start, start + length / 2);
//利用另一个ForkJoinPool线程异步执行新创建的子任务
leftTask.fork();
//创建一个任务为数组的后一半求和
ForkJoinSumCalculator rightTask = new ForkJoinSumCalculator(numbers, start + length / 2, end);
//同步执行第二个子任务,有可能允许进一步递归划分
Long rightResult = rightTask.compute();
//读取第一个子任务的结果,如果尚未完成就等待
Long leftResult = leftTask.join();
//合并结果,并返回
return rightResult+leftResult;
}
private Long computeSequentiall() {
long sum=0;
for (int i = start; i < end; i++) {
sum += numbers[i];
}
return sum;
}
//用分支/合并框架执行并行求和
//测试函数
public static long forkJoinSum(long n) {
//该步骤明显慢于并行流的版本,因为必须先把整个数字流都放进一个数组中,之后才能在ForkJoinSumCalculator任务中使用它
long[] numbers = LongStream.rangeClosed(1, n).toArray();
ForkJoinTask task = new ForkJoinSumCalculator(numbers);
//开启多线程
return new ForkJoinPool().invoke(task);
}
}
package chap07.ForkJoinDemo;
import java.util.function.Function;
public class Main {
public static void main(String[] args) {
//用分支/合并框架执行并行求和
int n=1_000_000;
System.out.println("ForkJoin sum done in: " + measureSumPerf(ForkJoinSumCalculator::forkJoinSum, 10_000_000) + "msecs");
}
//定义测试函数
public static long measureSumPerf(Function adder, long n) {
long fastest = Long.MAX_VALUE;
//迭代10次
for (int i = 0; i < 10; i++) {
long start=System.nanoTime();
long sum = adder.apply(n);
long duration=(System.nanoTime()-start)/1_000_000;
System.out.println("Result: " + sum);
//取最小值
if (duration < fastest) {
fastest = duration;
}
}
return fastest;
}
}
输出结果:
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
Result: 50000005000000
ForkJoin sum done in: 56msecs