对于一些耗时任务来说,单线程处理往往非常低效,会导致多核CPU闲置。因此在Java层面需要把任务拆分并分配给多个线程,以提高处理效率;如果子任务的结果之间存在依赖,再对结果进行归并。Java中的ForkJoinPool就可以实现这一功能,它本质上是一个线程池。我们首先需要编写一个类来描述任务,包括拆分的粒度和归并的逻辑等:
/**
 * Sums every integer in the closed range [startValue, endValue] with fork/join.
 *
 * <p>While the range is wider than {@link #TASK_LENGTH} it is split in two disjoint halves that
 * are computed in parallel and merged; narrower ranges are summed directly. The sum is an
 * {@code int} and silently overflows past {@link Integer#MAX_VALUE}.
 */
public class MyLongTimeTask extends RecursiveTask<Integer> {
    /** Ranges whose width (endValue - startValue) exceeds this are split further. */
    private static final int TASK_LENGTH = 200;
    private final int startValue;
    private final int endValue;

    /**
     * @param startValue first value of the range (inclusive)
     * @param endValue   last value of the range (inclusive)
     */
    public MyLongTimeTask(Integer startValue, Integer endValue) {
        this.startValue = startValue;
        this.endValue = endValue;
    }

    /** Splits the range while it is too wide, otherwise sums it in the current thread. */
    @Override
    protected Integer compute() {
        if (endValue - startValue > TASK_LENGTH) {
            return splitTask();
        }
        return startCompute();
    }

    /**
     * The range is still too wide: split it in half (binary split), run both halves and merge.
     *
     * <p>Parent and child tasks form a join tree: a parent only merges its children's results
     * and never runs {@link #startCompute()} itself; only leaf tasks do.
     */
    private Integer splitTask() {
        // Midpoint computed without the overflow of (startValue + endValue).
        int middle = startValue + (endValue - startValue) / 2;
        // Both bounds are inclusive, so the right half must start at middle + 1;
        // starting it at middle would count middle twice.
        MyLongTimeTask subTask1 = new MyLongTimeTask(startValue, middle);
        MyLongTimeTask subTask2 = new MyLongTimeTask(middle + 1, endValue);
        // invokeAll forks one subtask (queues it for stealing) and computes the other in this
        // thread, instead of forking both and then idling in join().
        invokeAll(subTask1, subTask2);
        return subTask1.join() + subTask2.join();
    }

    /** The range is small enough: add up every number in it directly. */
    private Integer startCompute() {
        int totalValue = 0;
        for (int index = startValue; index <= endValue; index++) {
            totalValue += index;
        }
        return totalValue;
    }
}
/**
 * Prints every integer in the half-open range [startValue, endValue), splitting the work across
 * ForkJoinPool workers while the range is wider than {@link #ACTION_LENGTH}.
 *
 * <p>A {@link RecursiveAction} has no result to merge, but the parent still waits for its
 * subtasks, so a caller of {@code pool.invoke(...)} only returns once every number was printed.
 */
public class MyLongTimeAction extends RecursiveAction {
    /** Ranges whose width (endValue - startValue) exceeds this are split further. */
    private static final int ACTION_LENGTH = 200;
    private final int startValue;
    private final int endValue;

    /**
     * @param startValue first value to print (inclusive)
     * @param endValue   end of the range (exclusive)
     */
    public MyLongTimeAction(Integer startValue, Integer endValue) {
        this.startValue = startValue;
        this.endValue = endValue;
    }

    /** Either splits the range or prints it; never both, so no value is printed twice. */
    @Override
    protected void compute() {
        if (endValue - startValue > ACTION_LENGTH) {
            splitAction();
        } else {
            startAction();
        }
    }

    /**
     * Splits into [start, middle) and [middle, end) and runs both halves.
     *
     * <p>invokeAll forks one half, runs the other in this thread and returns once both are done.
     * Merely forking them would let compute() (and therefore pool.invoke()) return while the
     * halves are still queued.
     */
    private void splitAction() {
        int middle = startValue + (endValue - startValue) / 2;
        MyLongTimeAction left = new MyLongTimeAction(startValue, middle);
        MyLongTimeAction right = new MyLongTimeAction(middle, endValue);
        invokeAll(left, right);
    }

    /** Leaf work: print each value together with the name of the thread that handled it. */
    private void startAction() {
        for (int i = startValue; i < endValue; i++) {
            System.out.println(Thread.currentThread().getName() + "i的值" + i);
        }
    }
}
/** Demo: submit a task asynchronously and block on its Future for the result. */
public class ForkJoinTest01 {
    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // submit() hands a top-level task to the pool asynchronously and returns it as a
            // ForkJoinTask (a Future); get() blocks until the result is available.
            ForkJoinTask<Integer> future = pool.submit(new MyLongTimeTask(1, 100));
            Integer integer = future.get();
            System.out.println(integer);
        } finally {
            // This pool was created just for the demo; release its worker threads.
            pool.shutdown();
        }
    }
}
/** Demo: hand a task to the pool with execute() and then wait for its result. */
public class ForkJoinTest02 {
    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            MyLongTimeTask task = new MyLongTimeTask(1, 100);
            // execute() comes from the Executor interface that ForkJoinPool implements. It only
            // schedules the task and returns nothing, so keep a reference to the task itself.
            pool.execute(task);
            // Without waiting, main() would return immediately and the JVM could exit before
            // the task ever ran (ForkJoinPool worker threads are daemon threads).
            System.out.println(task.join());
        } finally {
            pool.shutdown();
        }
    }
}
/** Demo: run an action synchronously with invoke(). */
public class ForkJoinTest03 {
    public static void main(String[] args) throws Exception {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // invoke() runs the task and blocks until it completes. It only covers work the
            // top-level action itself waits for; subtasks it forks without joining may still
            // be running when invoke() returns.
            pool.invoke(new MyLongTimeAction(1, 50000));
            System.out.println("执行完成 1");
        } finally {
            // This pool was created just for the demo; release its worker threads.
            pool.shutdown();
        }
    }
}
工作窃取(work-stealing)算法:任务放在双端队列中,每个工作线程从自己队列的一端(top)取任务执行;当自己的队列为空时,再从其他线程队列的另一端(base)窃取任务,从而减少线程空闲。
默认构造方法根据CPU处理器个数(且不超过MAX_CAP)来决定池的并行度,即工作线程数量:
/**
 * Default constructor (JDK excerpt).
 * Parallelism = number of available processors, capped at MAX_CAP, with the default
 * worker-thread factory. NOTE(review): the trailing (null, false) arguments are presumably
 * the uncaught-exception handler and asyncMode; confirm against the JDK version in use.
 */
public ForkJoinPool() {
this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
defaultForkJoinWorkerThreadFactory, null, false);
}
提交任务
/**
 * Submits a task for asynchronous execution (JDK excerpt).
 * Rejects a null task with NullPointerException, then hands it to externalPush.
 */
public void execute(ForkJoinTask<?> task) {
if (task == null)
throw new NullPointerException();
externalPush(task);
}
将任务加入队列并尝试添加worker
/**
 * Pushes a task submitted from outside the pool onto a shared work queue (JDK excerpt).
 * Fast path: pick a queue slot using the calling thread's random probe, CAS-lock that queue,
 * store the task at index top and advance top. If the fast path cannot be taken, the task is
 * handed to externalSubmit instead. Reasons include: queues not created yet, empty slot,
 * probe 0, runState not positive, lock contended, or no room left in the array.
 */
final void externalPush(ForkJoinTask<?> task) {
WorkQueue[] ws; WorkQueue q; int m;
int r = ThreadLocalRandom.getProbe();
int rs = runState;
/**
 * Lock the selected queue (QLOCK 0 -> 1) and add the task to it.
 */
if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
(q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
U.compareAndSwapInt(q, QLOCK, 0, 1)) {
ForkJoinTask<?>[] a; int am, n, s;
// n = current size (top - base); continue only if it is below array length - 1.
if ((a = q.array) != null &&
(am = a.length - 1) > (n = (s = q.top) - q.base)) {
int j = ((am & s) << ASHIFT) + ABASE; // raw memory offset of array slot (s & am)
U.putOrderedObject(a, j, task);// store the task
U.putOrderedInt(q, QTOP, s + 1); // advance the top index
U.putIntVolatile(q, QLOCK, 0);// unlock
if (n <= 1)
signalWork(ws, q); // queue held at most one task before this push: try to create or wake a worker
return;
}
/**
 * No room (or no array): unlock (QLOCK 1 -> 0) and fall back to externalSubmit.
 */
U.compareAndSwapInt(q, QLOCK, 1, 0);
}
externalSubmit(task);
}
/**
 * Gets a worker running for newly queued work (JDK 8 source, restored in full).
 * Loops while ctl is negative (per the JDK's comment: too few active workers).
 *  - sp == 0 (no idle worker recorded in the low 32 bits of ctl): call tryAddWorker if the
 *    ADD_WORKER bit is set, then stop.
 *  - Otherwise take the idle worker at index sp & SMASK. CAS ctl, adding AC_UNIT and replacing
 *    the idle-stack top with v.stackPred. Then set its scanState and unpark it if it is parked.
 * The loop also stops if the queues are missing, the idle slot is out of range or empty, or
 * the given queue q has become empty.
 *
 * @param ws the pool's work queues
 * @param q  the queue that just received work, or null
 */
final void signalWork(ForkJoinPool.WorkQueue[] ws, ForkJoinPool.WorkQueue q) {
    long c; int sp, i; ForkJoinPool.WorkQueue v; Thread p;
    while ((c = ctl) < 0L) {                       // too few active
        if ((sp = (int) c) == 0) {                 // no idle workers
            if ((c & ADD_WORKER) != 0L)            // too few workers
                tryAddWorker(c);                   // try to add a worker
            break;                                 // without this the loop spins on ctl
        }
        if (ws == null)                            // unstarted/terminated
            break;
        if (ws.length <= (i = sp & SMASK))         // terminated
            break;
        if ((v = ws[i]) == null)                   // terminating
            break;
        int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
        int d = sp - v.scanState;                  // screen CAS
        long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
        if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
            v.scanState = vs;                      // activate v
            if ((p = v.parker) != null)
                U.unpark(p);
            break;
        }
        if (q != null && q.base == q.top)          // no more work
            break;
    }
}
/**
 * Tries to add one worker thread (JDK excerpt).
 * nc is c with the fields under AC_MASK and TC_MASK each incremented by one unit
 * (presumably the active and total worker counts; confirm). The steps are:
 *  - Only if ctl still equals c, take the runState lock.
 *  - CAS ctl from c to nc unless the STOP bit is set, then release the lock.
 *  - If the CAS succeeded, create the worker thread.
 * Retries while ctl still has ADD_WORKER set and its low 32 bits are 0.
 */
private void tryAddWorker(long c) {
boolean add = false;
do {
long nc = ((AC_MASK & (c + AC_UNIT)) |
(TC_MASK & (c + TC_UNIT)));
if (ctl == c) {
int rs, stop; // check if terminating
if ((stop = (rs = lockRunState()) & STOP) == 0)
add = U.compareAndSwapLong(this, CTL, c, nc);
unlockRunState(rs, rs & ~RSLOCK);
if (stop != 0)
break;
if (add) {
createWorker();// try to create and start the worker thread
break;
}
}
} while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}
/**
 * Creates and starts a ForkJoinWorkerThread using the pool's factory (JDK excerpt).
 * @return true if a thread was created and started; false if the factory is null, returned
 *         null, or anything threw. In that case deregisterWorker(wt, ex) is called with the
 *         possibly-null thread and exception for cleanup.
 */
private boolean createWorker() {
ForkJoinWorkerThreadFactory fac = factory;
Throwable ex = null;
ForkJoinWorkerThread wt = null;
try {
if (fac != null && (wt = fac.newThread(this)) != null) {
wt.start();// start the thread; its run() enters pool.runWorker
return true;
}
} catch (Throwable rex) {
ex = rex;
}
deregisterWorker(wt, ex);
return false;
}
运行worker
/**
 * Worker thread of a ForkJoinPool (JDK excerpt; only run() is shown).
 */
public class ForkJoinWorkerThread extends Thread {
/**
 * Thread body: calls onStart(), then the pool's worker loop, then onTermination().
 * It always deregisters from the pool at the end, passing along the first exception thrown
 * by runWorker or onTermination (null if none).
 */
public void run() {
if (workQueue.array == null) { // only run once
Throwable exception = null;
try {
onStart();
pool.runWorker(workQueue);
} catch (Throwable ex) {
exception = ex;
} finally {
try {
onTermination(exception);
} catch (Throwable ex) {
if (exception == null)
exception = ex;
} finally {
pool.deregisterWorker(this, exception);
}
}
}
}
}
扫描任务并调用
/**
 * Top-level loop of a worker thread (JDK excerpt).
 * Repeatedly asks scan(w, r) for a task and runs it. When none is found, awaitWork decides
 * whether to keep waiting or to leave the loop, which ends the thread. r is a xorshift
 * pseudo-random value re-mixed every iteration and passed to scan/awaitWork (presumably to
 * vary which queues are probed; confirm).
 */
final void runWorker(WorkQueue w) {
w.growArray(); // allocate queue
int seed = w.hint; // initially holds randomization hint
int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
for (ForkJoinTask<?> t;;) {
if ((t = scan(w, r)) != null)
w.runTask(t);
else if (!awaitWork(w, r))
break;
r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
}
}
/**
 * Runs a task obtained by scan, then this queue's own local tasks (JDK excerpt, WorkQueue).
 * While running, the SCANNING bit is cleared and the task is recorded in currentSteal. The
 * steal counter is flushed to the pool when it overflows, and the owner thread's
 * afterTopLevelExec hook is called at the end.
 */
final void runTask(ForkJoinTask<?> task) {
if (task != null) {
scanState &= ~SCANNING; // mark as busy
(currentSteal = task).doExec();
U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
execLocalTasks();
ForkJoinWorkerThread thread = owner;
if (++nsteals < 0) // collect on overflow
transferStealCount(pool);
scanState |= SCANNING;
if (thread != null)
thread.afterTopLevelExec();
}
}
/**
 * ForkJoinTask (JDK excerpt; only doExec is shown).
 */
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
/**
 * Runs the task body if it has not completed yet (status >= 0).
 * An exception from exec() completes the task exceptionally; exec() returning true
 * completes it with status NORMAL.
 * @return the task's status after this call (negative once completed)
 */
final int doExec() {
int s; boolean completed;
if ((s = status) >= 0) {
try {
completed = exec();
} catch (Throwable rex) {
return setExceptionalCompletion(rex);
}
if (completed)
s = setCompletion(NORMAL);
}
return s;
}
}
/**
 * RecursiveTask (JDK excerpt): adapts compute() to ForkJoinTask.exec().
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
// Stores compute()'s return value in result and reports completion (true),
// so doExec marks the task NORMAL.
protected final boolean exec() {
result = compute();
return true;
}
}
join树形归并
/**
 * ForkJoinTask (JDK excerpt; only join is shown).
 */
public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
/**
 * Waits (via doJoin) until the task is done and returns its result.
 * If the completion status is not NORMAL, reportException is invoked to surface the
 * failure; otherwise getRawResult() is returned.
 */
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
}
private int doJoin() {
int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
return (s = status) < 0 ? s :
/*
fork操作是原线程而非工作线程,不会被识别为ForkJoinWorkerThread,线程执行前会先fork
所以就这样externalAwaitDone下去;
*/
((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
(w = (wt = (ForkJoinWorkerThread)t).workQueue).
tryUnpush(this) && (s = doExec()) < 0 ? s :
wt.pool.awaitJoin(w, this, 0L) //帮他完成compute的任务,工作窃取算法;
externalAwaitDone();
}
/*
 * Helps complete the given task while waiting for it to finish (JDK excerpt).
 * Records task as w.currentJoin, then loops until task.status < 0:
 *  - CountedCompleter tasks: helpComplete(w, cc, 0).
 *  - Otherwise, if w's own queue is empty (base == top) or tryRemoveAndExec(task) returns
 *    true, call helpStealer(w, task). By its name, this presumably helps the worker that
 *    stole task; confirm.
 *  - If still not done and tryCompensate(w) succeeds, block in task.internalWait(ms), then
 *    add AC_UNIT back to ctl.
 * With deadline == 0 the wait is untimed (ms = 0). Otherwise the loop ends once the deadline
 * (System.nanoTime based) has passed, and waits shorter than 1 ms are rounded up to 1 ms.
 * The previous currentJoin is restored before returning the last observed status.
 */
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
int s = 0;
if (task != null && w != null) {
ForkJoinTask<?> prevJoin = w.currentJoin;
U.putOrderedObject(w, QCURRENTJOIN, task);
CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
(CountedCompleter<?>)task : null;
for (;;) {
if ((s = task.status) < 0)
break;
if (cc != null)
helpComplete(w, cc, 0);
else if (w.base == w.top || w.tryRemoveAndExec(task))
helpStealer(w, task);
if ((s = task.status) < 0)
break;
long ms, ns;
if (deadline == 0L)
ms = 0L;
else if ((ns = deadline - System.nanoTime()) <= 0L)
break;
else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
ms = 1L;
if (tryCompensate(w)) {
task.internalWait(ms);
U.getAndAddLong(this, CTL, AC_UNIT);
}
}
U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
}
return s;
}
/*
 * Returns the result (JDK excerpt of ForkJoinTask.join).
 * Waits via doJoin; if the completion status is not NORMAL, reportException surfaces the
 * failure, otherwise getRawResult() is returned.
 */
public final V join() {
int s;
if ((s = doJoin() & DONE_MASK) != NORMAL)
reportException(s);
return getRawResult();
}
/*
 * Where the result is stored (JDK excerpt of RecursiveTask).
 * exec() saves compute()'s return value in result and returns true (normal completion);
 * join() later reads it back through getRawResult().
 */
public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
protected final boolean exec() {
result = compute();
return true;
}
}