Jdk1.7 JUC源码增量解析(4)-ForkJoin-ForkJoin任务的fork/join过程
作者:大飞
- 这篇通过分析一个ForkJoin任务的执行过程来分析ForkJoin的相关代码,主要侧重于分裂(fork)/合并(join)过程。
public class SumTask extends RecursiveTask{
private static final int THRESHOLD = 10;
private long start;
private long end;
public SumTask(long n) {
this(1,n);
}
private SumTask(long start, long end) {
this.start = start;
this.end = end;
}
@Override
protected Long compute() {
long sum = 0;
if((end - start) <= THRESHOLD){
for(long l = start; l <= end; l++){
sum += l;
}
}else{
long mid = (start + end) >>> 1;
SumTask left = new SumTask(start, mid);
SumTask right = new SumTask(mid + 1, end);
left.fork();
right.fork();
sum = left.join() + right.join();
}
return sum;
}
private static final long serialVersionUID = 1L;
}
- 这里略过任务提交、执行的一些过程,上篇都分析过了。任务执行过程中,新建了子任务,然后进行fork操作,看下fork的源码:
public final ForkJoinTask fork() {
((ForkJoinWorkerThread) Thread.currentThread())
.pushTask(this);
return this;
}
fork里面其实是将任务放到当前工作线程的任务队列里面了,看下pushTask方法细节:
final void pushTask(ForkJoinTask> t) {
ForkJoinTask>[] q; int s, m;
if ((q = queue) != null) { // ignore if queue removed
//这里首先根据当前的queueTop对队列(数组)长度取模来算出放置任务的下标
//然后再通过下标算出偏移地址,提供给Unsafe使用。
long u = (((s = queueTop) & (m = q.length - 1)) << ASHIFT) + ABASE;
//设置任务。
UNSAFE.putOrderedObject(q, u, t);
//修改queueTop
queueTop = s + 1; // or use putOrderedInt
if ((s -= queueBase) <= 2)
pool.signalWork();
else if (s == m)
growQueue(); //如果队列满了,扩展一下队列容量。
}
}
看下扩展队列调用的growQueue方法:
private void growQueue() {
ForkJoinTask>[] oldQ = queue;
int size = oldQ != null ? oldQ.length << 1 : INITIAL_QUEUE_CAPACITY;
if (size > MAXIMUM_QUEUE_CAPACITY)
throw new RejectedExecutionException("Queue capacity exceeded");
if (size < INITIAL_QUEUE_CAPACITY)
size = INITIAL_QUEUE_CAPACITY;
ForkJoinTask>[] q = queue = new ForkJoinTask>[size];
int mask = size - 1;
int top = queueTop;
int oldMask;
if (oldQ != null && (oldMask = oldQ.length - 1) >= 0) {
for (int b = queueBase; b != top; ++b) {
long u = ((b & oldMask) << ASHIFT) + ABASE;
Object x = UNSAFE.getObjectVolatile(oldQ, u);
if (x != null && UNSAFE.compareAndSwapObject(oldQ, u, x, null))
UNSAFE.putObjectVolatile
(q, ((b & mask) << ASHIFT) + ABASE, x);
}
}
}
- 完事儿了,fork过程就这么简单。fork出子任务后,当前任务的计算可能会需要子任务的结果,需要join子任务:
sum = left.join() + right.join();
看下join的源码:
public final V join() {
if (doJoin() != NORMAL)
return reportResult();
else
return getRawResult();
}
先看下doJoin方法:
private int doJoin() {
Thread t; ForkJoinWorkerThread w; int s; boolean completed;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
if ((s = status) < 0)
return s; //如果当前任务已经完成,直接返回状态。
if ((w = (ForkJoinWorkerThread)t).unpushTask(this)) {
//如果当前任务恰好是当前工作线程的队列顶端的第一个任务
//那么将该任务出队,然后执行。
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
return setCompletion(NORMAL);
}
//否则调用当前工作线程的joinTask方法。
return w.joinTask(this);
}
else
//如果当前线程不是ForkJoin工作线程,那么调用externalAwaitDone
return externalAwaitDone();
}
注意这里调用的是ForkJoinWorkerThread的unpushTash方法,这是另一个版本的pop,看下实现:
final boolean unpushTask(ForkJoinTask> t) {
ForkJoinTask>[] q;
int s;
if ((q = queue) != null && (s = queueTop) != queueBase &&
UNSAFE.compareAndSwapObject
(q, (((q.length - 1) & --s) << ASHIFT) + ABASE, t, null)) {
queueTop = s; // or putOrderedInt
return true;
}
return false;
}
继续doJoin方法,如果unpushTask方法失败,就会调用ForkJoinWorkerThread的joinTask方法,看下这个方法:
private static final int MAX_HELP = 16;
...
final int joinTask(ForkJoinTask> joinMe) {
//记录之前的合并任务。
ForkJoinTask> prevJoin = currentJoin;
//设置当前工作线程的合并任务。
currentJoin = joinMe;
for (int s, retries = MAX_HELP;;) {
if ((s = joinMe.status) < 0) {
//如果合并任务已经完成,恢复之前的合并任务。
currentJoin = prevJoin;
return s; //返回任务状态。
}
if (retries > 0) {
if (queueTop != queueBase) {
/*
* 如果当前任务队列中有任务,尝试从当前队列顶端获取给定任务
* (如果给定任务恰好在当前任务队列顶端的话)或者其他一个已经
* 被取消的任务。
*/
if (!localHelpJoinTask(joinMe))
retries = 0; // cannot help
}
else if (retries == MAX_HELP >>> 1) {
--retries; // check uncommon case
/*
* 这里尝试一种特殊情况:如果给定的任务正好在其他工作线程的
* 队列的底部,那么尝试窃取这个任务并执行。
*/
if (tryDeqAndExec(joinMe) >= 0)
Thread.yield(); // 如果没成功,这里出让一下CPU。
}
else
retries = helpJoinTask(joinMe) ? MAX_HELP : retries - 1;
}
else {
retries = MAX_HELP; // restart if not done
pool.tryAwaitJoin(joinMe);
}
}
}
看下localHelpJoinTask方法:
private boolean localHelpJoinTask(ForkJoinTask> joinMe) {
int s, i; ForkJoinTask>[] q; ForkJoinTask> t;
if ((s = queueTop) != queueBase && (q = queue) != null &&
(i = (q.length - 1) & --s) >= 0 &&
(t = q[i]) != null) {
if (t != joinMe && t.status >= 0)
//如果当前工作线程的任务队列顶端的任务不是给定任务,
//且任务的状态是未取消(这里如果<0,一定是取消的任务),返回false。
return false;
if (UNSAFE.compareAndSwapObject
(q, (i << ASHIFT) + ABASE, t, null)) {
//取出给定任务或者一个被取消的任务。
queueTop = s; // or putOrderedInt
t.doExec();
}
}
return true;
}
再看下joinTask方法中调用的tryDeqAndExec方法:
private int tryDeqAndExec(ForkJoinTask> t) {
int m = pool.scanGuard & SMASK;
ForkJoinWorkerThread[] ws = pool.workers;
//扫描所有工作线程
if (ws != null && ws.length > m && t.status >= 0) {
for (int j = 0; j <= m; ++j) {
ForkJoinTask>[] q; int b, i;
ForkJoinWorkerThread v = ws[j];
if (v != null &&
(b = v.queueBase) != v.queueTop &&
(q = v.queue) != null &&
(i = (q.length - 1) & b) >= 0 &&
q[i] == t) {
//如果有工作线程的任务队列的底部正好是给定任务t。
//尝试窃取t后执行。
long u = (i << ASHIFT) + ABASE;
if (v.queueBase == b &&
UNSAFE.compareAndSwapObject(q, u, t, null)) {
v.queueBase = b + 1;
v.stealHint = poolIndex;
ForkJoinTask> ps = currentSteal;
currentSteal = t;
t.doExec();
currentSteal = ps;
}
break;
}
}
}
return t.status;
}
最后看下joinTask方法中调用的helpJoinTask方法:
private boolean helpJoinTask(ForkJoinTask> joinMe) {
boolean helped = false;
int m = pool.scanGuard & SMASK;
ForkJoinWorkerThread[] ws = pool.workers;
if (ws != null && ws.length > m && joinMe.status >= 0) {
int levels = MAX_HELP; // remaining chain length
ForkJoinTask> task = joinMe; // base of chain
outer:for (ForkJoinWorkerThread thread = this;;) {
// 找到线程thread的窃取者v
ForkJoinWorkerThread v = ws[thread.stealHint & m];
if (v == null || v.currentSteal != task) {
//如果thread没有窃取者或者v当前窃取的任务不是task,扫描工作线程数组。
for (int j = 0; ;) { // search array
if ((v = ws[j]) != null && v.currentSteal == task) {
//如果找到了窃取线程,将其设置为thread的窃取线程。
thread.stealHint = j;
break; // save hint for next time
}
if (++j > m)
break outer; // 没找到的话,直接跳出outer循环。
}
}
// 找到了窃取者v。
for (;;) {
ForkJoinTask>[] q; int b, i;
if (joinMe.status < 0)
break outer; //如果joinMe任务已经完成,跳出outer循环。
if ((b = v.queueBase) == v.queueTop ||
(q = v.queue) == null ||
(i = (q.length-1) & b) < 0)
break; //如果v的队列是空的,跳出当前循环。
long u = (i << ASHIFT) + ABASE;
ForkJoinTask> t = q[i];
if (task.status < 0)
break outer; //如果task任务已经完成,跳出outer循环。
//尝试窃取v的任务队列底部的任务。
if (t != null && v.queueBase == b &&
UNSAFE.compareAndSwapObject(q, u, t, null)) {
//窃取成功后,执行任务。
v.queueBase = b + 1;
v.stealHint = poolIndex;
ForkJoinTask> ps = currentSteal;
currentSteal = t;
t.doExec();
currentSteal = ps;
helped = true;
}
}
// 再去找v的窃取者,注意这里是一个链。
ForkJoinTask> next = v.currentJoin;
if (--levels > 0 && task.status >= 0 &&
next != null && next != task) {
task = next;
thread = v;
}
else
break; // 如果超过最大深度(MAX_HELP) 或者 task已经执行完成 或者 找到了头(next==null) 或者 出现循环 退出。
}
}
return helped;
}
接上面joinTask方法,如果尝试不成功,会调用Pool的tryAwaitJoin方法:
final void tryAwaitJoin(ForkJoinTask> joinMe) {
int s;
Thread.interrupted(); // clear interrupts before checking termination
//如果joinMe未完成
if (joinMe.status >= 0) {
//尝试阻塞等待之前的预操作
if (tryPreBlock()) {
//在joinMe任务上阻塞等待
joinMe.tryAwaitDone(0L);
//被唤醒后的操作
postBlock();
}
else if ((ctl & STOP_BIT) != 0L)
//如果Pool关闭了,取消任务。
joinMe.cancelIgnoringExceptions();
}
}
先看下tryAwaitJoin方法中调用的tryPreBlock方法:
/**
* Tries to increment blockedCount, decrement active count
* (sometimes implicitly) and possibly release or create a
* compensating worker in preparation for blocking. Fails
* on contention or termination.
*
* @return true if the caller can block, else should recheck and retry
*/
private boolean tryPreBlock() {
int b = blockedCount;
//累加等待join任务的计数。
if (UNSAFE.compareAndSwapInt(this, blockedCountOffset, b, b + 1)) {
int pc = parallelism;
do {
ForkJoinWorkerThread[] ws; ForkJoinWorkerThread w;
int e, ac, tc, rc, i;
long c = ctl;
int u = (int)(c >>> 32);
if ((e = (int)c) < 0) {
// 如果Pool关闭了,跳过。
}
else if ((ac = (u >> UAC_SHIFT)) <= 0 && e != 0 &&
(ws = workers) != null &&
(i = ~e & SMASK) < ws.length &&
(w = ws[i]) != null) {
//如果当前活动的工作线程不大于cpu核数,且有线程在等待任务(处于空闲状态)。
//那么唤醒这个工作线程。
long nc = ((long)(w.nextWait & E_MASK) |
(c & (AC_MASK|TC_MASK)));
if (w.eventCount == e &&
UNSAFE.compareAndSwapLong(this, ctlOffset, c, nc)) {
w.eventCount = (e + EC_UNIT) & E_MASK;
if (w.parked)
UNSAFE.unpark(w);
return true;
}
}
else if ((tc = (short)(u >>> UTC_SHIFT)) >= 0 && ac + pc > 1) {
//如果总的工作线程数量不少于cpu核心数量,且至少有一个活动的工作线程。
//尝试在总控信息上将AC递减。
long nc = ((c - AC_UNIT) & AC_MASK) | (c & ~AC_MASK);
if (UNSAFE.compareAndSwapLong(this, ctlOffset, c, nc))
return true;
}
else if (tc + pc < MAX_ID) {
//如果不满足上面条件,这里会增加一个工作线程。
long nc = ((c + TC_UNIT) & TC_MASK) | (c & ~TC_MASK);
if (UNSAFE.compareAndSwapLong(this, ctlOffset, c, nc)) {
addWorker();
return true;
}
}
//如果失败,这里会把刚才对b增加的1给减回去。
} while (!UNSAFE.compareAndSwapInt(this, blockedCountOffset,
b = blockedCount, b - 1));
}
return false;
}
继续看下tryAwaitJoin方法中调用的ForkJoinTask的tryAwaitDone方法:
final void tryAwaitDone(long millis) {
int s;
try {
if (((s = status) > 0 ||
(s == 0 &&
UNSAFE.compareAndSwapInt(this, statusOffset, 0, SIGNAL))) &&
status > 0) {
synchronized (this) {
if (status > 0)
wait(millis);
}
}
} catch (InterruptedException ie) {
// caller must check termination
}
}
在看下tryAwaitJoin方法中调用的postBlock方法:
private void postBlock() {
long c;
do {} while (!UNSAFE.compareAndSwapLong(this, ctlOffset, // 累加活动线程计数
c = ctl, c + AC_UNIT));
int b;
do {} while (!UNSAFE.compareAndSwapInt(this, blockedCountOffset, // 递减等待join任务的计数。
b = blockedCount, b - 1));
}
最后,tryAwaitJoin方法中如果发现Pool关闭,会取消joinMe任务,调用其cancelIgnoringExceptions方法:
final void cancelIgnoringExceptions() {
try {
cancel(false);
} catch (Throwable ignore) {
}
}
private static final int CANCELLED = -2;
public boolean cancel(boolean mayInterruptIfRunning) {
return setCompletion(CANCELLED) == CANCELLED;
}
private int setCompletion(int completion) {
for (int s;;) {
if ((s = status) < 0)
return s;
if (UNSAFE.compareAndSwapInt(this, statusOffset, s, completion)) {
if (s != 0)
synchronized (this) { notifyAll(); }
return completion;
}
}
}
最后回到join方法,如果正常完成会调用getRawResult方法:
public abstract V getRawResult();
ForkJoinTask中的getRawResult方法未实现,交由子类去实现,比如在RecursiveTask中:
V result;
...
public final V getRawResult() {
return result;
}
...
protected final boolean exec() {
result = compute();
return true;
}
如果join方法中,任务非正常结束,会调用reportResult方法:
private V reportResult() {
int s; Throwable ex;
if ((s = status) == CANCELLED)
throw new CancellationException();
if (s == EXCEPTIONAL && (ex = getThrowableException()) != null)
UNSAFE.throwException(ex);
return getRawResult();
}
- 简单总结一下ForkJoinPool中的ForkJoinTask的fork/join流程:
本篇通过一个ForkJoin任务的fork/join过程来分析代码,结合上一篇的话,已经涵盖了ForkJoinTask一个完整执行过程的相关代码了。下篇会做一个收尾工作,将本篇和上篇未涉及到的ForkJoin框架源码分析一下。