fork/join框架是ExecutorService
接口的一种具体实现,目的是为了帮助你更好地利用多处理器带来的好处。它是为那些能够被递归地拆解成子任务的工作类型量身设计的。其目的在于能够使用所有可用的运算能力来提升你的应用的性能。
类似于ExecutorService
接口的其他实现,fork/join框架会将任务分发给线程池中的工作线程。fork/join框架的独特之处在与它使用工作窃取(work-stealing)算法。完成自己的工作而处于空闲的工作线程能够从其他仍然处于忙碌(busy)状态的工作线程处窃取等待执行的任务。
fork/join框架的核心是ForkJoinPool
类,它是对AbstractExecutorService
类的扩展。ForkJoinPool
实现了工作偷取算法,并可以执行ForkJoinTask
任务。
使用fork/join框架的第一步是编写执行一部分工作的代码。你的代码结构看起来应该与下面所示的伪代码类似:
if (当前这个任务工作量足够小)
直接完成这个任务
else
将这个任务或这部分工作分解成两个部分
分别触发(invoke)这两个子任务的执行,并等待结果
你需要将这段代码包裹在一个ForkJoinTask
的子类中。不过,通常情况下会使用一种更为具体的的类型,或者是RecursiveTask
(会返回一个结果),或者是RecursiveAction
。
当你的ForkJoinTask
子类准备好了,创建一个代表所有需要完成工作的对象,然后将其作为参数传递给一个ForkJoinPool
实例的invoke()
方法即可。
work-stealing
isDone()
方法)。所有的任务都会 无阻塞 的完成。工作窃取算法的优点
工作窃取算法的缺点
ForkJoinPool 类图
sleep()
等会造成线程长时间阻塞的情况时,最好配合使用 ManagedBlocker。问题
解决方法
public class ExecutorServiceCalculator implements Calculator {
private int parallism;
private ExecutorService pool;
public ExecutorServiceCalculator() {
parallism = Runtime.getRuntime().availableProcessors(); // CPU的核心数
pool = Executors.newFixedThreadPool(parallism);
}
private static class SumTask implements Callable {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
public Long call() throws Exception {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
}
}
@Override
public long sumUp(long[] numbers) {
List> results = new ArrayList<>();
// 把任务分解为 n 份,交给 n 个线程处理
int part = numbers.length / parallism;
for (int i = 0; i < parallism; i++) {
int from = i * part;
int to = (i == parallism - 1) ? numbers.length - 1 : (i + 1) * part - 1;
results.add(pool.submit(new SumTask(numbers, from, to)));
}
// 把每个线程的结果相加,得到最终结果
long total = 0L;
for (Future f : results) {
try {
total += f.get();
} catch (Exception ignore) {}
}
return total;
}
}
public class ForkJoinCalculator implements Calculator {
private ForkJoinPool pool;
private static class SumTask extends RecursiveTask {
private long[] numbers;
private int from;
private int to;
public SumTask(long[] numbers, int from, int to) {
this.numbers = numbers;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
// 当需要计算的数字小于6时,直接计算结果
if (to - from < 6) {
long total = 0;
for (int i = from; i <= to; i++) {
total += numbers[i];
}
return total;
// 否则,把任务一分为二,递归计算
} else {
int middle = (from + to) / 2;
SumTask taskLeft = new SumTask(numbers, from, middle);
SumTask taskRight = new SumTask(numbers, middle+1, to);
taskLeft.fork();
taskRight.fork();
return taskLeft.join() + taskRight.join();
}
}
}
public ForkJoinCalculator() {
// 也可以使用公用的 ForkJoinPool:
// pool = ForkJoinPool.commonPool()
pool = new ForkJoinPool();
}
@Override
public long sumUp(long[] numbers) {
return pool.invoke(new SumTask(numbers, 0, numbers.length-1));
}
}
compute()
函数中,代码中没有显式地把任务分配给线程,只是分解了任务,而把具体的任务到线程的映射都交给了 ForkJoinPool 来完成。ForkJoinPool
fork()
)时,会放入工作队列的队尾,并且工作线程在处理自己的工作队列时,使用的是 LIFO 方式,也就是说每次从队尾取出任务来执行。join()
时,如果需要 Join 的任务尚未完成,则会先处理其他任务,并等待其完成。ForkJoinPool
WorkQueue
ForkJoinWorkThread
work-stealing
ForkJoinTask
fork 方法
fork()
做的工作只有一件事,既是把任务推入当前工作线程的工作队列里。public final ForkJoinTask fork() {
Thread t;
if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
((ForkJoinWorkerThread)t).workQueue.push(this);
else
ForkJoinPool.common.externalPush(this);
return this;
}
join 方法
join()
的工作则复杂得多,也是它可以使得线程免于被阻塞的原因。
join()
的线程是否是 ForkJoinThread 线程。如果不是(例如 main 线程),则阻塞当前线程,等待任务完成。如果是,则不阻塞。join 方法的流程
submit()
和 fork()
其实没有本质区别,只是提交对象变成了 submitting queue 而已(还有一些同步,初始化的操作)。submitting queue 和其他 work queue 一样,是工作线程 " 窃取 " 的对象,因此当其中的任务被一个工作线程成功窃取时,就意味着提交的任务真正开始进入执行阶段。// 低位和高位掩码
private static final long SP_MASK = 0xffffffffL;
private static final long UC_MASK = ~SP_MASK;
// 活跃线程数
private static final int AC_SHIFT = 48;
private static final long AC_UNIT = 0x0001L << AC_SHIFT; //活跃线程数增量
private static final long AC_MASK = 0xffffL << AC_SHIFT; //活跃线程数掩码
// 工作线程数
private static final int TC_SHIFT = 32;
private static final long TC_UNIT = 0x0001L << TC_SHIFT; //工作线程数增量
private static final long TC_MASK = 0xffffL << TC_SHIFT; //掩码
private static final long ADD_WORKER = 0x0001L << (TC_SHIFT + 15); // 创建工作线程标志
// 池状态
private static final int RSLOCK = 1;
private static final int RSIGNAL = 1 << 1;
private static final int STARTED = 1 << 2;
private static final int STOP = 1 << 29;
private static final int TERMINATED = 1 << 30;
private static final int SHUTDOWN = 1 << 31;
// 实例字段
volatile long ctl; // 主控制参数
volatile int runState; // 运行状态锁
final int config; // 并行度|模式
int indexSeed; // 用于生成工作线程索引
volatile WorkQueue[] workQueues; // 主对象注册信息,workQueue
final ForkJoinWorkerThreadFactory factory;// 线程工厂
final UncaughtExceptionHandler ueh; // 每个工作线程的异常信息
final String workerNamePrefix; // 用于创建工作线程的名称
volatile AtomicLong stealCounter; // 偷取任务总数,也可作为同步监视器
/** 静态初始化字段 */
//线程工厂
public static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory;
//启动或杀死线程的方法调用者的权限
private static final RuntimePermission modifyThreadPermission;
// 公共静态pool
static final ForkJoinPool common;
//并行度,对应内部common池
static final int commonParallelism;
//备用线程数,在tryCompensate中使用
private static int commonMaxSpares;
//创建workerNamePrefix(工作线程名称前缀)时的序号
private static int poolNumberSequence;
//线程阻塞等待新的任务的超时值(以纳秒为单位),默认2秒
private static final long IDLE_TIMEOUT = 2000L * 1000L * 1000L; // 2sec
//空闲超时时间,防止timer未命中
private static final long TIMEOUT_SLOP = 20L * 1000L * 1000L; // 20ms
//默认备用线程数
private static final int DEFAULT_COMMON_MAX_SPARES = 256;
//阻塞前自旋的次数,用在在awaitRunStateLock和awaitWork中
private static final int SPINS = 0;
//indexSeed的增量
private static final int SEED_INCREMENT = 0x9e3779b9;
ForkJoinPool 对象
// parallelism 定义并行级别
public static ExecutorService newWorkStealingPool(int parallelism);
// 默认并行级别为 JVM 可用的处理器个数
// Runtime.getRuntime().availableProcessors()
public static ExecutorService newWorkStealingPool();
// 类静态代码块中会调用makeCommonPool方法初始化一个commonPool
public static ForkJoinPool commonPool() {
// assert common != null : "static init error";
return common;
}
makeCommonPool()
,最终调用 ForkJoinPool 的构造函数。private ForkJoinPool(int parallelism,
ForkJoinWorkerThreadFactory factory,
UncaughtExceptionHandler handler,
int mode,
String workerNamePrefix) {
this.workerNamePrefix = workerNamePrefix;
this.factory = factory;
this.ueh = handler;
this.config = (parallelism & SMASK) | mode;
long np = (long)(-parallelism); // offset ctl counts
this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
}
参数说明
类型及其修饰符 | 变量名 | 作用 |
---|---|---|
volatile long | ctl | 主控制参数,分为 4 个区域保存(每 16 位为 1 个区域)。 |
volatile int | runState | 保存线程池的 运行状态。 |
final int | config | 保存线程池的 最大线程数量 及其 是否采用了公平模式。 |
int | indexSeed | 用于在构造 WorkQueue 时计算插入到 workQueues 的下标。 |
volatile WorkQueue[] | workQueues | 线程池持有的工作线程(即执行任务的线程)。 |
final ForkJoinWorkerThreadFactory | factory | 该线程池指定的线程工厂,用于生产 ForkJoinWorkerThread 对象。 |
final String | workerNamePrefix | 该线程池中工作线程的名称前缀。 |
volatile AtomicLong | stealCounter | 该线程池中所有的 WorkQueue 总共被窃取的任务数量。 |
ctl 变量说明
区域 | 属性 | 说明 |
---|---|---|
1 | AC | 正在运行工作线程数减去目标并行度,高 16 位。(49-64 位) |
2 | TC | 总工作线程数减去目标并行度,中高 16 位。(33-48 位) |
3 | SS | 栈顶等待线程的版本计数和状态,中低 16 位。(17-32 位) |
4 | ID | 栈顶 WorkQueue 在池中的索引(poolIndex),低 16 位。(1-16 位) |
sp=(int)ctl
来检查工作线程状态。当 sp 为 0 时说明此刻 没有已经启动但是空闲的线程。线程池状态(runState )说明
// runState bits: SHUTDOWN must be negative, others arbitrary powers of two
private static final int RSLOCK = 1;
private static final int RSIGNAL = 1 << 1;
private static final int STARTED = 1 << 2;
private static final int STOP = 1 << 29;
private static final int TERMINATED = 1 << 30;
private static final int SHUTDOWN = 1 << 31;
private int lockRunState() {
int rs;
return ((((rs = runState) & RSLOCK) != 0 ||
!U.compareAndSwapInt(this, RUNSTATE, rs, rs |= RSLOCK)) ?
awaitRunStateLock() : rs);
}
config 变量
static final int SMASK = 0xffff; // short bits == max index
static final int LIFO_QUEUE = 0;
static final int FIFO_QUEUE = 1 << 16;
this.config = (parallelism & SMASK) | mode;
WorkQueue 对象
WorkQueue
类型及其修饰符 | 变量名 | 作用 |
---|---|---|
volatile int | scanState | 保存这个 WorkQueue 的类型,线程是否繁忙(仅限 ACTIVE 类型)。 |
int | stackPred | 记录前驱 worker 的下标。 |
int | nsteals | 该 WorkQueue 被窃取的任务的总数。 |
int | hint | 用于窃取线程计算下次窃取的 workQueues 数组的下标。 |
int | config | 前 16 位(低 16 位)保存该 WorkQueue 在 workQueues 数组的下标,第 17 位(高 16 位)保存属于 LIFO 还是 FIFO 模式。 |
volatile int | qlock | 一个简单的锁,0 表示为加锁,1 表示已加锁,小于 0 表示当前 WorkQueue 已停止。 |
ForkJoinTask>[] | array | 任务队列,保存 ForkJoinTask 任务对象。 |
volatile int | base | bash 与 workQueues 数组长度取模的值窃取线程下次从 workQueues 数组取出任务的下标。 |
int | top | top 与 workQueues 数组长度取模的值即为下次将任务对象插入到 workQueues 数组的下标。 |
final ForkJoinPool | pool | 该 WorkQueue 对应的线程池。 |
final ForkJoinWorkerThread | owner | 该 WorkQueue 对应的工作线程对象(ACTIVE 类型的 WorkQueue 不会为 null)。 |
volatile Thread | parker | 当 currentThread 被 park(等待)时,用来保存这个线程对象来后续 unpark。 |
ForkJoinTask> | currentJoin | 调用 join 方法时等待结果的任务对象。 |
ForkJoinTask> | currentSteal | 保存正在执行的从别的 WorkQueue 窃取过来的任务。 |
WorkQueue 当前状态(scanState)
static final int SCANNING = 1; // false when running tasks
static final int INACTIVE = 1 << 31; // must be negative
static final int SS_SEQ = 1 << 16;
WorkQueue 锁标识(qlock)
config
static final int MODE_MASK = 0xffff << 16; // top half of int
int mode = config & MODE_MASK;
w.config = i | mode;
top 和 base
int top
不需要使用 volatile。static final int INITIAL_QUEUE_CAPACITY = 1 << 13;
base = top = INITIAL_QUEUE_CAPACITY >>> 1;
stackPred