Fork/Join框架是Java7提供了的一个用于并行执行任务的框架, 是一个把大任务分割成若干个小任务,最终汇总每个小任务结果后得到大任务结果的框架。Fork/Join的运行流程如下图所示:
工作窃取(work-stealing)算法是指某个线程从其他队列里窃取任务来执行。那么为什么需要使用工作窃取算法呢?
假如我们需要做一个比较大的任务,我们可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,于是把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应,比如A线程负责处理A队列里的任务。但是有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行。
工作窃取的运行流程图如下:
工作窃取算法的优点:充分利用线程进行并行计算,并减少了线程间的竞争;
工作窃取算法的缺点:在某些情况下还是存在竞争,比如双端队列里只有一个任务时。并且消耗了更多的系统资源,比如创建多个线程和多个双端队列。
在Java的Fork/Join框架中,它提供了两个类来帮助我们完成任务分割以及执行任务并合并结果:
1、ForkJoinTask:我们要使用ForkJoin框架,必须首先创建一个ForkJoin任务。它提供在任务中执行fork()和join()操作的机制,通常情况下我们不需要直接继承ForkJoinTask类,而只需要继承它的子类,Fork/Join框架提供了以下两个子类:
2、ForkJoinPool :ForkJoinTask需要通过ForkJoinPool来执行,任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务。
ForkJoinPool里面,有两个特别重要的成员如下:
volatile WorkQueue[] workQueues; // main registry
final ForkJoinWorkerThreadFactory factory;
workQueues 用于保存向ForkJoinPool提交的任务,而具体的执行由ForkJoinWorkerThread执行,而ForkJoinWorkerThreadFactory可以用于生产出ForkJoinWorkerThread:
public static interface ForkJoinWorkerThreadFactory {
/**
* Returns a new worker thread operating in the given pool.
*
* @param pool the pool this thread works in
* @return the new worker thread
* @throws NullPointerException if the pool is null
*/
public ForkJoinWorkerThread newThread(ForkJoinPool pool);
}
ForkJoinTask的fork方法实现原理
public final ForkJoinTask fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
若当前线程是ForkJoinWorkerThread线程,则强制类型转换(向下转换)成ForkJoinWorkerThread,然后将任务push到这个线程负责的队列里面去,在ForkJoinWorkerThread类中有一个pool和一个workQueue字段:
// 线程工作的ForkJoinPool
final ForkJoinPool pool; // the pool this thread works in
// 工作窃取队列
final ForkJoinPool.WorkQueue workQueue; // work-stealing mechanics
workQueue的push()方法如下:
final void push(ForkJoinTask> task) {
ForkJoinTask>[] a; ForkJoinPool p;
int b = base, s = top, n;
if ((a = array) != null) { // ignore if queue removed
int m = a.length - 1; // fenced write for task visibility
U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
U.putOrderedInt(this, QTOP, s + 1);
if ((n = s - b) <= 1) {
if ((p = pool) != null)
p.signalWork(p.workQueues, this);
}
else if (n >= m)
growArray();
}
}
该方法的主要功能就是将当前任务存放在ForkJoinTask数组array里。然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务。
Join方法的主要作用是阻塞当前线程并等待获取结果,其源码如下:
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
首先,它调用了doJoin()方法,通过doJoin()方法得到当前任务的状态来判断返回什么结果,任务状态有四种:已完成(NORMAL),被取消(CANCELLED),信号(SIGNAL)和出现异常(EXCEPTIONAL):
若状态不是NORMAL,则通过reportException(int)方法来处理状态:
private void reportException(int s) {
if (s == CANCELLED)
throw new CancellationException();
if (s == EXCEPTIONAL)
rethrow(getThrowableException());
}
让我们再来分析下doJoin()方法的实现代码:
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) :
externalAwaitDone();
}
在doJoin()方法里,首先通过查看任务的状态,看任务是否已经执行完了,如果执行完了,则直接返回任务状态,如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成了,则设置任务状态为NORMAL,如果出现异常,则纪录异常,并将任务状态设置为EXCEPTIONAL。
执行任务是通过doExec()方法来完成的:
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
真正的执行过程是由exec()方法来完成的:
protected abstract boolean exec();
这就是我们需要重写的方法,若是我们的任务继承自RecursiveAction,则我们需要重写RecursiveAction的compute()方法:
public abstract class RecursiveAction extends ForkJoinTask {
private static final long serialVersionUID = 5232453952276485070L;
/**
* The main computation performed by this task.
*/
protected abstract void compute();
/**
* Always returns {@code null}.
*
* @return {@code null} always
*/
public final Void getRawResult() { return null; }
/**
* Requires null completion value.
*/
protected final void setRawResult(Void mustBeNull) { }
/**
* Implements execution conventions for RecursiveActions.
*/
protected final boolean exec() {
compute();
return true;
}
}
若是我们的任务继承自RecursiveTask,则我们同样需要重写RecursiveTask的compute()方法:
public abstract class RecursiveTask extends ForkJoinTask {
private static final long serialVersionUID = 5232453952276485270L;
/**
* The result of the computation.
*/
V result;
/**
* The main computation performed by this task.
* @return the result of the computation
*/
protected abstract V compute();
public final V getRawResult() {
return result;
}
protected final void setRawResult(V value) {
result = value;
}
/**
* Implements execution conventions for RecursiveTask.
*/
protected final boolean exec() {
result = compute();
return true;
}
}
通过上面的分析可知,执行我们的业务代码是在调用了join()之后的,也就是说,fork仅仅是分割任务,只有当我们执行join的时候,我们的任务才会被执行。
public class ForkJoinTest {
static class SumTask extends RecursiveTask {
static final int THRESHOLD = 100;
long[] array;
int start;
int end;
SumTask(long[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
if (end - start <= THRESHOLD) {
// 如果任务足够小,直接计算:
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
System.out.println(String.format("compute %d~%d = %d", start, end, sum));
return sum;
}
// 任务太大,一分为二:
int middle = (end + start) / 2;
System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
SumTask subtask1 = new SumTask(this.array, start, middle);
SumTask subtask2 = new SumTask(this.array, middle, end);
invokeAll(subtask1, subtask2);
// subtask1.fork();
// subtask2.fork();
Long subresult1 = subtask1.join();
Long subresult2 = subtask2.join();
Long result = subresult1 + subresult2;
System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
return result;
}
}
public static void main(String[] args) {
// 创建随机数组成的数组:
long[] array = new long[1000];
fillRandom(array);
// fork/join task:
ForkJoinPool fjp = new ForkJoinPool(4); // 最大并发数4
ForkJoinTask task = new SumTask(array, 0, array.length);
long startTime = System.currentTimeMillis();
Long result = fjp.invoke(task);
long endTime = System.currentTimeMillis();
System.out.println("Fork/join sum: " + result + " in " + (endTime - startTime) + " ms.");
}
private static void fillRandom(long[] array) {
Random random = new Random();
int bound = 100000;
for (int i = 0; i < array.length; i++) {
array[i] = random.nextInt(bound);
}
}
}
运行结果:
split 0~1000 ==> 0~500, 500~1000
split 0~500 ==> 0~250, 250~500
split 500~1000 ==> 500~750, 750~1000
split 250~500 ==> 250~375, 375~500
split 250~375 ==> 250~312, 312~375
split 0~250 ==> 0~125, 125~250
split 0~125 ==> 0~62, 62~125
split 500~750 ==> 500~625, 625~750
split 750~1000 ==> 750~875, 875~1000
split 500~625 ==> 500~562, 562~625
split 750~875 ==> 750~812, 812~875
compute 0~62 = 3321492
compute 750~812 = 3109921
compute 250~312 = 2809208
compute 500~562 = 3113648
compute 62~125 = 3177944
result = 3321492 + 3177944 ==> 6499436
compute 812~875 = 3123212
split 125~250 ==> 125~187, 187~250
result = 3109921 + 3123212 ==> 6233133
compute 562~625 = 3423239
result = 3113648 + 3423239 ==> 6536887
split 875~1000 ==> 875~937, 937~1000
split 625~750 ==> 625~687, 687~750
compute 312~375 = 3161144
result = 2809208 + 3161144 ==> 5970352
split 375~500 ==> 375~437, 437~500
compute 125~187 = 2999804
compute 625~687 = 3002258
compute 875~937 = 2825923
compute 375~437 = 2772717
compute 187~250 = 3647221
result = 2999804 + 3647221 ==> 6647025
result = 6499436 + 6647025 ==> 13146461
compute 937~1000 = 3277843
result = 2825923 + 3277843 ==> 6103766
compute 687~750 = 2819135
result = 6233133 + 6103766 ==> 12336899
result = 3002258 + 2819135 ==> 5821393
result = 6536887 + 5821393 ==> 12358280
result = 12358280 + 12336899 ==> 24695179
compute 437~500 = 3185771
result = 2772717 + 3185771 ==> 5958488
result = 5970352 + 5958488 ==> 11928840
result = 13146461 + 11928840 ==> 25075301
result = 25075301 + 24695179 ==> 49770480
Fork/join sum: 49770480 in 4031 ms.
我们是采用了廖雪峰老师的源代码作为应用示例讲解的,该代码就是通过Fork/Join框架来计算数组的和,计算耗时4031毫秒。通过该代码作为应用示例主要是为了告诉大家,使用Fork/Join模型的正确方式,在源代码中可以看到,SumTask继承自RecursiveTask,重写的compute方法为:
protected Long compute() {
if (end - start <= THRESHOLD) {
// 如果任务足够小,直接计算:
long sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
}
System.out.println(String.format("compute %d~%d = %d", start, end, sum));
return sum;
}
// 任务太大,一分为二:
int middle = (end + start) / 2;
System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
SumTask subtask1 = new SumTask(this.array, start, middle);
SumTask subtask2 = new SumTask(this.array, middle, end);
invokeAll(subtask1, subtask2);
// subtask1.fork();
// subtask2.fork();
Long subresult1 = subtask1.join();
Long subresult2 = subtask2.join();
Long result = subresult1 + subresult2;
System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
return result;
}
compute()方法使用了invokeAll方法来分解任务,而不是它下面的
subtask1.fork();
subtask2.fork();
这两个方法,使用invokeAll方法的主要原因是为了充分利用线程池,在invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的不干活的线程。
若是采用另外一种方式来运行,程序的运行时间为6028毫秒,可以看到,明显比invokeAll方式慢了很多。
方腾飞:《Java并发编程的艺术》
Java的Fork/Join任务,你写对了吗?