思路一:1到20万相加,可分段相加,1到10000,10001到20000,20001到30000…,可以分成20个片段相加,然后把各片段的结果相加得到总结果。
思路二:1到20万相加,一分为二,判断首尾值,若首尾间隔小于给定的阈值,则从首加到尾;否则递归至满足条件,然后把各结果相加得到总结果。
想要得到总结果需要考虑两个问题:
带着这些问题继续往下走:
大家首先想到的肯定是ThreadPool,池化技术也正好符合这样的场景,创建一个corePoolSize为5的线程池,各小任务提交到线程池,由池中的线程执行小任务,最后通过submit()方法返回计算结果,由此得到总结果。
实例代码:
public class ThreadPoolTest {
static ExecutorService executorService = Executors.newFixedThreadPool(5);
private static class ThreadPoolExample implements Callable{
private int from;
private int to;
public ThreadPoolExample(int from, int to) {
this.from = from;
this.to = to;
}
@Override
public Long call(){
long total = 0;
for (int i = from; i <= to; i++) {
total += i;
}
return total;
}
}
private static Long addAll(int from, int to){
List> futureList = new ArrayList>();
// int step = 10000;
int step = (from + to) / 5;
for (int i = 0; i < 5; i++) {
int fromTemp = i * step;
int toTemp = (i == step -1) ? to : (i+1)*step -1;
futureList.add(executorService.submit(new ThreadPoolExample(fromTemp, toTemp)));
}
Long total = 0L;
for (Future future : futureList ){
try {
total += future.get();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (ExecutionException e) {
e.printStackTrace();
}
}
return total;
}
public static void main(String[] args){
System.out.println(ThreadPoolTest.addAll(0, 200000));
executorService.shutdown();
}
}
下面给出forkjoin框架的实例代码:
public class ForkJoinFramWork {
private static class CountAdd extends RecursiveTask{
private int from;
private int to;
public CountAdd(int from, int to) {
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
if (to - from < 500){
long sum = 0 ;
for (int i = from; i <= to; i++)
sum += i;
return sum;
}else{
int middle = (from + to) / 2;
CountAdd countAdd1 = new CountAdd(from, middle);
CountAdd countAdd2 = new CountAdd(middle+1, to);
countAdd1.fork();
countAdd2.fork();
return countAdd1.join() + countAdd2.join();
}
}
}
public static void main(String[] args){
ForkJoinPool forkJoinPool = new ForkJoinPool();
long startTime = System.nanoTime();
long result = forkJoinPool.invoke(new CountAdd(0, 1000));
long endTiem = System.nanoTime();
System.out.println("forkJoin: " + result + ", time:" + (endTiem - startTime) + "毫秒");
}
对比以上两种实现方式可以发现:
>forkjoin实现代码更简洁;
>forkjoin不需要显示的把任务通过循环分配给线程,只需要调用invoke()方法就可以返回结果;
因此,forkjoin在实现多线程时具有一定的优势,也就值得进一步深入了解。
追溯源码发现compute()方法是父类RecursiveTask
很明显,compute()方法是暴露给实现类的抽象方法,用于实现“分治”任务的逻辑;“分治”逻辑的具体实现则依靠fork()和join()方法。
了解forkjoin框架,需要进一步了解invoke(),fork()和join()方法。
Invoke()方法是ForkJoinPool线程池的方法【ForkJoinPool与ThreadPoolExecutor之间有什么区别?】,用于执行提交给线程池的大任务,并返回大任务的结果,返回类型与小任务返回类型相同;源码如下:【如果不需要返回呢?】
public T invoke(ForkJoinTask task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
return task.join();
}
externalPush()内部尝试创建一个新的队列并把任务提交给线程池的队列,一般情况下正常执行task.join()。
join()方法是属于ForkJoinTask类的方法:
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
join()方法是final类型的,保证了子类不可重写父类的方法,很好的保证了ForkJoin框架的核心功能。if里面是用于判断当前线程工作的状态,状态分别有三种:NORMAL/CANCELLED/EXCETIONAL,如果为cancelled或者exceptional,则分别抛出不同的异常,否则执行getRawResult(),并返回结果;
getRawResult()是ForkJoinTask类的抽象方法,由RecurSiveTask类和RecurSiveAction类分别实现;其中RecursiveAction类的getRawResult()方法返回null;RecurSiveTask类的getRawResult()方法返回compute()方法的计算结果result;
public abstract class RecursiveTask extends ForkJoinTask {
private static final long serialVersionUID = 5232453952276485270L;
V result;
protected abstract V compute();
public final V getRawResult() {
return result;
}
protected final void setRawResult(V value) {
result = value;
}
protected final boolean exec() {
result = compute();
return true;
}
}
public abstract class RecursiveAction extends ForkJoinTask {
private static final long serialVersionUID = 5232453952276485070L;
protected abstract void compute();
public final Void getRawResult() { return null; }
protected final void setRawResult(Void mustBeNull) { }
protected final boolean exec() {
compute();
return true;
}
}
所以ForkJoinTask主要有两种实现,返回结果和不返回结果。
由此可以看出compute()方法是实现的主要逻辑;
在本实例中compute()方法代码如下:
@Override
protected Long compute() {
if (to - from < 500){
long sum = 0 ;
for (int i = from; i <= to; i++)
sum += i;
return sum;
}else{
int middle = (from + to) / 2;
CountAdd countAdd1 = new CountAdd(from, middle);
CountAdd countAdd2 = new CountAdd(middle+1, to);
countAdd1.fork();
countAdd2.fork();
return countAdd1.join() + countAdd2.join();
}
概括就是:对参数不断迭代,直到间隔小于500才会开始计算从头到尾的值并返回计算的结果。看到这里想必有种醍醐灌顶的感觉,没错compute()方法关键在于迭代;如何维持每代之间的关系,则是通过调用ForkJoinTask类的fork()方法,join()方法是返回compute()方法计算的结果;
fork()方法:
public final ForkJoinTask fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
判断当前线程是否是ForkJoinWorkerThread,如果是则把当前任务放到该线程的工作队列里。否则交给FrokJoinPool处理,externalPush()方法在上面已经讨论过;这里要理解一下ForkJoinWorkerThread类;该类有两个构造函数,一个较为关键的构造函数如下:
protected ForkJoinWorkerThread(ForkJoinPool pool) {
// Use a placeholder until a useful name can be set in registerWorker
super("aForkJoinWorkerThread");
this.pool = pool;
this.workQueue = pool.registerWorker(this);
}
首先:该线程是属于ForkJoinPool,且该线程带有一个workQueue;工作队列存放任务。
可知,迭代的任务放于workQueue中,那workQueue如何保证迭代的顺序?根据相关论文可知,workQueue实现是FILo,即先进后出,这样能保证后进去的小任务能最先被线程消耗,从而保证结果的准确性,最后通过join()方法可以得到迭代的和。
fork()和join()方法分别承担着处理迭代的任务、合并迭代的结果。具体的ForkJoin框架理论可以看一下参考博客,我只是从实际需求角度对forkJoin框架进行分析,原理部分个人理解不是很透彻,还需要进一步思考。
参考博客:
http://blog.dyngr.com/blog/2016/09/15/java-forkjoinpool-internals/
https://www.jianshu.com/p/44b09f52a225
https://www.jianshu.com/p/f777abb7b251