JDK1.8源码分析:ForkJoin任务递归分解与并行计算框架的设计与用法

概述

  • ForkJoin框架是在JDK1.7推出的,支持将一个大任务递归拆分成多个小任务,然后交给线程池的线程执行任务的并行处理,最后可以获取所有这些任务的执行结果并汇总。
  • 这个框架的主要设计目的就是实现任务的自动化递归拆分,然后交给线程池的线程处理,而不需要在应用代码中实现任务的递归拆分和汇总,简化了多线程环境下,递归拆解任务进行并行执行的编程难度。简单来说就是对递归算法的多线程实现,即支持每个拆分出来的最小单元使用一个线程来处理,实现并行计算。
  • 在任务角度而言,一个大任务拆分成多个小任务,小任务可以继续拆分为更小的任务,依次类推。这些拆分出来的线程被分配到多个工作队列当中等待线程处理,每个工作队列会分配一个工作线程来处理。
  • 在线程池角度而言,ForkJoinPool线程池内的线程会关联这些工作队列,即负责处理关联的工作队列中的任务。同时线程之间支持任务的窃取,即空闲线程可以从其他线程的工作队列窃取任务来处理,从而提高线程的效率。
  • 在线程的任务窃取方面,为了减少线程竞争,工作队列为基于双端队列实现,即工作队列的拥有者线程从队列的头部获取任务,从该队列窃取任务的线程从队列的尾部窃取任务。

核心类

ForkJoinPool线程池

  • ForkJoinPool也是Executor接口的一个实现类:

    • 在ForkJoinPool内部维护了一个工作队列集合workQueues,也就是ForkJoinPool内部的工作线程池,其中集合内的每个工作队列的类型为WorkQueue,在WorkQueue内部包含了一个任务列表,即元素类型为ForkJoinTask的任务列表,同时关联了一个工作线程,工作线程在ForkJoinWorkerThread定义,默认由该工作线程处理该任务列表中的任务,同时由于支持线程的工作窃取,故在实现当中也允许其他空闲的工作线程从该任务列表窃取任务来执行。
  • ForkJoinPool的定义:

    @sun.misc.Contended
    public class ForkJoinPool extends AbstractExecutorService {
    
        public static final ForkJoinWorkerThreadFactory
            defaultForkJoinWorkerThreadFactory;
    
        static final ForkJoinPool common;
        // 默认等于系统处理器数量
        static final int commonParallelism;
    
        private static int commonMaxSpares;
    
    
        // Instance fields
        volatile long ctl;                   // main pool control
        volatile int runState;               // lockable status
        final int config;                    // parallelism, mode
        int indexSeed;                       // to generate worker index
        // 工作队列集合
        volatile WorkQueue[] workQueues;     // main registry
        // 线程创建工厂
        final ForkJoinWorkerThreadFactory factory;
        final UncaughtExceptionHandler ueh;  // per-worker UEH
        final String workerNamePrefix;       // to create worker name string
        volatile AtomicLong stealCounter;    // also used as sync monitor
        
        // 工作队列
        static final class WorkQueue {
        
            ...
            
        }
    }
    
工作队列WorkQueue(工作线程池)
  • 工作队列WorkQueue作为ForkJoinPool的一个内部类实现,定义如下:主要包含一个任务数组来存放任务,一个工作线程来处理任务数组中的任务。

    static final class WorkQueue {
        // 任务数组的初始大小
        static final int INITIAL_QUEUE_CAPACITY = 1 << 13;
        // 任务数组的最大大小
        static final int MAXIMUM_QUEUE_CAPACITY = 1 << 26; // 64M
    
        // Instance fields
        volatile int scanState;    // versioned, <0: inactive; odd:scanning
        int stackPred;             // pool stack (ctl) predecessor
        int nsteals;               // number of steals
        int hint;                  // randomization and stealer index hint
        int config;                // pool index and mode
        volatile int qlock;        // 1: locked, < 0: terminate; else 0
        volatile int base;         // index of next slot for poll
        int top;                   // index of next slot for push
        
        // 任务数组,延迟初始化,大小为2的N次方
        ForkJoinTask<?>[] array;   // the elements (initially unallocated)
        // 线程池引用
        final ForkJoinPool pool;   // the containing pool (may be null)
        // 该队列关联的工作线程
        final ForkJoinWorkerThread owner; // owning thread or null if shared
        volatile Thread parker;    // == owner during call to park; else null
        volatile ForkJoinTask<?> currentJoin;  // task being joined in awaitJoin
        volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer
        
        WorkQueue(ForkJoinPool pool, ForkJoinWorkerThread owner) {
            this.pool = pool;
            this.owner = owner;
            // Place indices in the center of array (that is not yet allocated)
            base = top = INITIAL_QUEUE_CAPACITY >>> 1;
        }
        
        ...
        
        // 将任务添加到工作队列内部的任务数组的尾部
        final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a; ForkJoinPool p;
            int b = base, s = top, n;
            if ((a = array) != null) {    // ignore if queue removed
                int m = a.length - 1;     // fenced write for task visibility
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
                if ((n = s - b) <= 1) {
                    if ((p = pool) != null)
                        // 通知ForkJoinPool有新任务提交了,如可能当前存在不够的worker,
                        // 则ForkJoinPool会创建新的工作线程来处理任务。
                        p.signalWork(p.workQueues, this);
                }
                else if (n >= m)
                    growArray();
            }
        }
    }
    
工作线程的创建
  • 由以上分析可知,ForkJoinPool的工作线程池就是workQueues,即工作队列集合,即每个工作队列就对应一个工作线程,具体为每个工作队列包含一个工作线程ForkJoinWorkerThread,一个任务数组。所以工作线程的创建其实就是对应工作队列的创建,底层实现方法如下:在tryAddWorker中调用createWorker创建一个工作线程,并在工作线程的构造函数中回调注册到ForkJoinPool的线程池workQueues中。

    // 由ForkJoinPool在当前工作线程不够时,调用该方法创建工作线程
    private void tryAddWorker(long c) {
        boolean add = false;
        do {
            long nc = ((AC_MASK & (c + AC_UNIT)) |
                       (TC_MASK & (c + TC_UNIT)));
            if (ctl == c) {
                int rs, stop;                 // check if terminating
                if ((stop = (rs = lockRunState()) & STOP) == 0)
                    add = U.compareAndSwapLong(this, CTL, c, nc);
                unlockRunState(rs, rs & ~RSLOCK);
                if (stop != 0)
                    break;
                if (add) {
                    // 创建一个工作线程ForkJoinWorkerThread,并在其构造函数回调中,注册到线程池ForkJoinPool
                    createWorker();
                    break;
                }
            }
        } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
    }
    
    // 创建一个ForkJoinWorkerThread工作线程,并启动该线程
    private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            // 创建工作线程
            // 将ForkJoinPool自身引用作为构造参数
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }
    
ForkJoinWorkerThread工作线程
  • 由上面的分析可知,工作线程注册到ForkJoinPool的线程池是在工作线程ForkJoinWorkerThread的构造函数的回调中进行的,如下为ForkJoinWorkerThread的定义:

    public class ForkJoinWorkerThread extends Thread {
    
        // 工作线程池引用
        final ForkJoinPool pool;                // the pool this thread works in
        // 工作队列引用
        final ForkJoinPool.WorkQueue workQueue; // work-stealing mechanics
    
        protected ForkJoinWorkerThread(ForkJoinPool pool) {
            // Use a placeholder until a useful name can be set in registerWorker
            super("aForkJoinWorkerThread");
            this.pool = pool;
            
            // 回调注册到ForkJoinPool的线程池当中
            
            this.workQueue = pool.registerWorker(this);
        }
        
        ...
        
    }
    
  • ForkJoinPool的registerWorker的实现如下:创建一个WorkQueue对象来包装该工作线程,然后添加到ForkJoinPool的工作队列集合workQueues中,即注册到线程池中。

    // 将工作线程注册到线程池ForkJoinPool中
    // 其中会创建一个工作队列WorkQueue来包装这个工作线程
    final WorkQueue registerWorker(ForkJoinWorkerThread wt) {
        UncaughtExceptionHandler handler;
        wt.setDaemon(true);                           // configure thread
        if ((handler = ueh) != null)
            wt.setUncaughtExceptionHandler(handler);
        // 为该工作线程创建一个工作队列,在该工作队列内部保存该工作线程需要处理的任务
        WorkQueue w = new WorkQueue(this, wt);
        int i = 0;                                    // assign a pool index
        int mode = config & MODE_MASK;
        // 加锁
        int rs = lockRunState();
        try {
            WorkQueue[] ws; int n;                    // skip if no array
            if ((ws = workQueues) != null && (n = ws.length) > 0) {
                int s = indexSeed += SEED_INCREMENT;  // unlikely to collide
                int m = n - 1;
                i = ((s << 1) | 1) & m;               // odd-numbered indices
                if (ws[i] != null) {                  // collision
                    int probes = 0;                   // step by approx half n
                    int step = (n <= 4) ? 2 : ((n >>> 1) & EVENMASK) + 2;
                    while (ws[i = (i + step) & m] != null) {
                        if (++probes >= n) {
                            workQueues = ws = Arrays.copyOf(ws, n <<= 1);
                            m = n - 1;
                            probes = 0;
                        }
                    }
                }
                w.hint = s;                           // use as random seed
                w.config = i | mode;
                w.scanState = i;                      // publication fence
    
                // 将该新的工作队列放到工作队列集合
                ws[i] = w;
            }
        } finally {
            unlockRunState(rs, rs & ~RSLOCK);
        }
        wt.setName(workerNamePrefix.concat(Integer.toString(i >>> 1)));
        return w;
    }
    
ForkJoinPool构造函数
  • ForkJoinPool的所有构造函数如下:主要是在应用代码中根据需要创建对应的工作线程池ForkJoinPool实例。

    // 默认构造函数,默认parallelism等于系统处理器数量
    public ForkJoinPool() {
        this(Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors()),
             defaultForkJoinWorkerThreadFactory, null, false);
    }
    
    // 指定并行性parallelism
    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }
    
    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO_QUEUE : LIFO_QUEUE,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
        checkPermission();
    }
    
    private ForkJoinPool(int parallelism,
                     ForkJoinWorkerThreadFactory factory,
                     UncaughtExceptionHandler handler,
                     int mode,
                     String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.config = (parallelism & SMASK) | mode;
        long np = (long)(-parallelism); // offset ctl counts
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }
    
任务提交
  • ForkJoinPool提供了以下方法来进行任务提交,其中任务可以是普通任务和递归任务。对于返回结果,主要包括需要返回结果,不需要返回结果的版本,其中如果是递归分解任务,则返回的是最终所有子任务汇总的结果,类似于递归函数的返回值。

    // 返回结果为任务最终的计算结果,即所有递归产生的小任务的计算结果的汇总
    public <T> T invoke(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task.join();
    }
    
    // 不需要计算结果,递归执行该任务即可
    public void execute(ForkJoinTask<?> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
    }
    
    // 单个普通任务
    public void execute(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<?> job;
        if (task instanceof ForkJoinTask<?>) // avoid re-wrap
            job = (ForkJoinTask<?>) task;
        else
            job = new ForkJoinTask.RunnableExecuteAction(task);
        externalPush(job);
    }
    
    // 返回递归任务自身引用
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }
    
    // 将一个普通任务保证成一个递归任务来执行,并返回该递归任务的引用
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
        externalPush(job);
        return job;
    }
    
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedRunnable<T>(task, result);
        externalPush(job);
        return job;
    }
    
    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<?> job;
        if (task instanceof ForkJoinTask<?>) // avoid re-wrap
            job = (ForkJoinTask<?>) task;
        else
            job = new ForkJoinTask.AdaptedRunnableAction(task);
        externalPush(job);
        return job;
    }
    

ForkJoinTask任务

  • ForkJoin框架中任务分解语义的实现类,提供了任务拆分为子任务,获取子任务执行结果的方法实现。ForkJoinTask为一个抽象类,由应用代码继承该类来定义任务拆分条件等。核心方法为fork,join,invoker和抽象方法exec,其中exec有子类实现来定义该任务的执行逻辑。

    // 支持将该任务分解为更小的任务,汇总各个子任务的结果,提供了任务递归分解的语义实现
    public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
        // 分解子任务
        public final ForkJoinTask<V> fork() {
            Thread t;
            if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
                ((ForkJoinWorkerThread)t).workQueue.push(this);
            else
                ForkJoinPool.common.externalPush(this);
            return this;
        }
    
        // 等待子任务返回结果
        public final V join() {
            int s;
            if ((s = doJoin() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }
    
        
        // 执行当前任务
        public final V invoke() {
            int s;
            if ((s = doInvoke() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }
    
        /**
         * Immediately performs the base action of this task and returns
         * true if, upon return from this method, this task is guaranteed
         * to have completed normally. This method may return false
         * otherwise, to indicate that this task is not necessarily
         * complete (or is not known to be complete), for example in
         * asynchronous actions that require explicit invocations of
         * completion methods. This method may also throw an (unchecked)
         * exception to indicate abnormal exit. This method is designed to
         * support extensions, and should not in general be called
         * otherwise.
         *
         * @return {@code true} if this task is known to have completed normally
         */
        // 应用类实现该方法定义任务处理逻辑
        protected abstract boolean exec();
        
        ...
        
    }
    

RecursiveTask递归分解的最小任务

  • RecursiveTask为递归任务拆分体系中的最小任务单元,即在RecursiveTask中不再执行任务分解,而是执行任务的执行计算,生成该最小任务的单元的计算结果。这是一个抽象类,由具体任务实现compute方法来定义任务计算逻辑,类定义如下:

    public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
        private static final long serialVersionUID = 5232453952276485270L;
    
        /**
         * The result of the computation.
         */
        // 最小任务的计算结果
        V result;
    
        /**
         * The main computation performed by this task.
         * @return the result of the computation
         */
        // 抽象方法,由应用代码定义任务计算、处理逻辑,返回计算结果
        protected abstract V compute();
    
        public final V getRawResult() {
            return result;
        }
    
        protected final void setRawResult(V value) {
            result = value;
        }
    
        /**
         * Implements execution conventions for RecursiveTask.
         */
        // ForkJoinPool对每个任务调用的是这个方法
        protected final boolean exec() {
            result = compute();
            return true;
        }
    }
    

使用示例

  • 在使用层面,主要包含工作线程池ForkJoinPool的对象实例的创建,根据业务需求定义任务实现,即继承ForkJoinTask并实现exec方法;如果需要递归拆分任务,则可以继承RecursiveTask并实现其compute方法,定义任务的处理逻辑。
  1. 斐波那契问题实现:f(n) = f(n-1) + f(n-2),n>=2,其中f(0)=0,f(1)=1

  2. 1到n的和:f(n) = 1+2+3+…+n

    /**
     * @author xyz
     * @date 17/2/2019 17:14
     * @description:
     */
    public class ForkJoinTest {
    
        public static void main(String[] args) {
            ForkJoinPool forkJoinPool = new ForkJoinPool();
            // 第8项为21
            FibonacciTask task = new FibonacciTask(8);
            FibonacciTask2 task2 = new FibonacciTask2(8);
    
            int result1 = forkJoinPool.invoke(task);
            int result2 = forkJoinPool.invoke(task2);
            System.out.println(result1 + ":" + result2);
    
            // 计算1到n的和
            SumNums task3 = new SumNums(5);
            int result3 = forkJoinPool.invoke(task3);
            System.out.println(result3);
    		
    		forkJoinPool.shutdown();
        }
        // 斐波那契数列:0、1、1、2、3、5、8、13、21、……
        // 在数学上,斐波纳契数列以如下被以递归的方法定义:F0=0,F1=1,Fn=F(n-1)+F(n-2)(n>=2,n∈N*)
        // 获取斐波那契数量的第n项的值,如第0项为0,第1项为1
        private static class FibonacciTask extends RecursiveTask<Integer> {
            // n为斐波那契数列的第几项
            final int n;
            FibonacciTask(int n) { this.n = n; }
            protected Integer compute() {
                if (n <= 1)
                    return n;
                FibonacciTask f1 = new FibonacciTask(n - 1);
                // fork分解为子任务
                f1.fork();
    
                FibonacciTask f2 = new FibonacciTask(n - 2);
                // f2不再分解
                return f2.compute() + f1.join();
            }
        }
    
        private static class FibonacciTask2 extends RecursiveTask<Integer> {
            // n为斐波那契数列的第几项
            final int n;
            FibonacciTask2(int n) { this.n = n; }
    
            @Override
            protected Integer compute() {
                if (n <= 1)
                    return n;
                FibonacciTask2 f1 = new FibonacciTask2(n - 1);
                // fork分解为子任务
                f1.fork();
                FibonacciTask2 f2 = new FibonacciTask2(n - 2);
                // f2继续分解
                f2.fork();
                return f2.join() + f1.join();
            }
        }
    
        // 累加1到n的顺序数组:f(n) = 1+2+..+n
        private static class SumNums extends RecursiveTask<Integer> {
            private int n;
            private int result;
    
            public SumNums(int n) {
                this.n = n;
            }
    
            @Override
            protected Integer compute() {
                if (n == 0) {
                    return n;
                }
                SumNums s = new SumNums(n-1);
                s.fork();
                return s.join() + n;
            }
        }
    }
    

你可能感兴趣的:(Java,ForkJoin,ForkJoinPool,递归并行计算)