Fork/Join浅探与使用

声明 本文部分文字介绍,直接摘录自《精通Java并发编程(第二版)》, 该书写得通俗易懂、且分析相对透彻,推荐阅读,具体信息见文末。

声明 本文不会介绍具体的方法调用API,但是给出CountedCompleterRecursiveTaskRecursiveAction的简单使用示例。

序言 JDK8开始,提供/优化了很多非常好用的并发组件,如parallelStreamCompletableFutureForkJoin等,本文初步学习ForkJoin。


Fork/Join框架

  • 简介
      Java7并发API引入了Fork/Join框架。该框架基于(Executor的实现类)ForkJoinPool,除了具备基础的Executor功能外,ForkJoinPool主要由fork()方法、join()方法(以及它们的不同变体),以及一个被称作工作窃取算法的内部算法组成。

  • Fork/Join框架的基本特征
      Fork/Join框架主要用于解决基于分治方法的问题。将原始问题拆分为较小的问题,直到问题很小,可以直接解决。即:拆分大问题为小问题,解决小问题并得到一系列结果,归并这些结果得到大问题的结果。
      Fork/Join框架还有一个非常重要的特性——工作窃取算法。当一个任务使用join()方法等待某个子任务结束时,执行该任务的线程将会从任务池中选取另一个正等待被执行的任务去执行。Java8开始,为Fork/Join框架提供了一个默认的执行器ForkJoinPool,可通过ForkJoinPool.commonPool()来获得。

  • Fork/Join框架的局限性

    • 出于性能考虑,不再进行细分的子任务的规模不能太大也不能太小。当一个大任务最终被拆分为100~10000个小任务时,比较适用于Fork/Join框架。
    • 出于性能考虑,不要使用Fork/Join框架来处理IO相关的任务。
    • 在Fork/Join框架中,不要/不能抛出检查异常,如果需要,可以抛出运行时异常,并使用特殊的处理方式进行处理。

Fork/Join框架的组件

Fork/Join框架主要包含四个基本类

Fork/Join浅探与使用_第1张图片

  • ForkJoinPool类 该类实现了Executor接口和ExecutorService接口,而执行Fork/Join任务时将用到Executor接口。Java8开始,提供了一个默认的ForkJoinPool对象作为公用池,但是如果需要,你还可以创建一些构造函数。你可以指定并行处理的任务的最大线程数目。默认情况下,它将使用可用处理器的数目作为最大并发线程数。
    Fork/Join浅探与使用_第2张图片

  • RecursiveTask类 这是一个抽象类,其继承了ForkJoinTask. 这是一个有返回值的task类。RecursiceTask类提供有抽象的compute方法,实际的计算任务逻辑,应该在子类的compute实现方法中完成。

  • RecursiceAction类 这是一个抽象类,其继承了ForkJoinTask. 这是一个无返回值的task类。RecursiceAction类提供有抽象的compute方法,实际的计算任务逻辑,应该在子类的compute实现方法中完成。

  • CountedCompleter类 这是一个抽象类,其继承了ForkJoinTask. 这个类除了有与RecursiceAction类类似的功能外,还主要用于作为触发器,当当前任务的所有子任务全部都已经完成后,会触发当前任务的onComplete()方法,完成当前任务。

    • 注:CountedCompleter采用了和CountDownLatch一样的思路。在实例中维护了一个pending变量(这个变量定义在CountedCompleter中)用于标识计数(其默认值为0)。当我们编写实现类时,在compute方法中,fork出新的子任务时,应该调用addToPendingCount(xxx)增加pengding计数(增加多少由自己的逻辑写法决定);在compute方法中,应该有调用this.tryComplete()的地方(在哪里由自己的逻辑写法决定)。tryComplete方法里,会判断:若pending的值为0(即:倒计时结束),那么会调用onCompletion方法(,此时,如果completer不为null,那么还会判断completer是否为null,若不为null,还会对completer的pending进行相关操作);若pending的值不为0,那么,那么会对pending原子性减一。具体细节可详见java.util.concurrent.CountedCompleter#tryComplete
    • 注:在构造CountedCompleter时,可以传一个CountedCompleter参数作为completer,以辅助b标识任务的完成情况,以及让子任务通过pending判断completer的pending,回调父任务的onCompletion方法)。一般的,我们传递正创建的任务的父任务作为completer(若没有父任务,则传null),当pending为0时,如果completer不为null,那么会将completer标识为完成状态。具体细节可详见java.util.concurrent.CountedCompleter#tryComplete
    • (关键部分)示例:
      Fork/Join浅探与使用_第3张图片
    • 等价于:
      Fork/Join浅探与使用_第4张图片
    • 注:上面两图都没有贴出onCompletion方法实现,这个倒是无所谓,因为业务逻辑不同,onCompletion多半也是不同的。
    • 注:虽然两种写法在效果上是一样的,但是下面的写法会多调用一次tryComplete方法,所以性能上还是推荐使用上面的那种写法。

CountedCompleter、RecursiveTask、RecursiveAction简单使用示例

CountedCompleter实现多线程归并排序算法
  • 啥也不说,直接上代码
    import com.aspire.demo.author.JustryDeng;
    import java.util.concurrent.CountedCompleter;
    
    /**
     * Fork/Join之CountedCompleter实现 多线程归并排序
     *
     * P.S. 好吧,我写的归并算法的实现, 没有把归并算法的最佳性能发挥出来。。。。。。
     *      简单测试发现: 当 数据量处于(0, 1万]时, Collections.sort性能优于MergeSortCompleter
     *                  当 数据量处于(1万, 100万]时, MergeSortCompleter性能优于Collections.sort
     *                  当 数据量处于(100万, 2000万]时, Collections.sort性能优于MergeSortCompleter
     *                  。。。
     *
     * @author {@link JustryDeng}
     * @since 2020/7/9 16:22:33
     */
    @SuppressWarnings("unused")
    public class MergeSortCompleter<T extends Comparable<T>> extends CountedCompleter<Void> {
        
        private final Comparable<T>[] data;
        
        private int startIndex, middleIndex, endIndex;
        
        private final boolean asc;
        
        /**
         * 进行fork的数组长度阈值
         */
        private final int FORK_THRESHOLD;
        
        /**
         * 默认的进行fork的数组长度阈值
         */
        private static final int DEFAULT_FORK_THRESHOLD = 200;
        
        /**
         * @see this#MergeSortCompleter(MergeSortCompleter, Comparable[], int, int, int, boolean)
         */
        public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data, int startIndex, int endIndex) {
            this(parent, data, startIndex, endIndex, DEFAULT_FORK_THRESHOLD, true);
        }
        
        /**
         * @see this#MergeSortCompleter(MergeSortCompleter, Comparable[], int, int, int, boolean)
         */
        public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data, int startIndex, int endIndex, boolean asc) {
            this(parent, data, startIndex, endIndex, DEFAULT_FORK_THRESHOLD, asc);
        }
        
        /**
         * 构造器
         *
         * @param parent
         *            父任务
         * @param data
         *            数据容器
         * @param startIndex
         *            要被排序的数据的起始索引
         * @param endIndex
         *            要被排序的数据的结尾引
         * @param forkThreshold
         *            进行fork的数组长度阈值
         * @param asc
         *            true-升序; false-降序
         */
        public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data,
                                  int startIndex, int endIndex, int forkThreshold, boolean asc) {
            super(parent);
            this.data = data;
            this.startIndex = startIndex;
            this.endIndex = endIndex;
            this.asc = asc;
            FORK_THRESHOLD = forkThreshold;
        }
        
        @Override
        public void compute() {
            // 如果长度>=指定的阈值, 那么fork
            if (endIndex - startIndex >= FORK_THRESHOLD - 1) {
                middleIndex = (endIndex + startIndex) >> 1;
                MergeSortCompleter<T> task1 = new MergeSortCompleter<>(this, data, startIndex, middleIndex, asc);
                MergeSortCompleter<T> task2 = new MergeSortCompleter<>(this, data, middleIndex + 1, endIndex, asc);
                // 对pending进行add操作,必须在fork之前
                this.addToPendingCount(1);
                task1.fork();
                task2.fork();
            // 任务粒度已经足够笑了, 不再fork, 直接进行逻辑处理
            } else {
                // 执行排序
                doSort(data, startIndex, endIndex, asc);
                // 主要逻辑处理完后,调用tryComplete, 使执行onCompletion如果需要的话
                tryComplete();
            }
        }
        
        /**
         * 触发onCompletion逻辑
         *
         * @param caller
         *         触发调用onCompletion方法的对象
         */
        @Override
        public void onCompletion(CountedCompleter<?> caller) {
            // middle == 0 说明没有fork过
            if (middleIndex == 0) {
                return;
            }
            merge(data, startIndex, middleIndex, endIndex, asc);
        }
        
        /// ********************************************** 下面的是归并排序实现
        
        /**
         * 归并排序
         *
         * @param data
         *            数据容器
         * @param start
         *            要被排序的数据的起始索引
         * @param end
         *            要被排序的数据的结尾引
         * @param asc
         *            true-升序; false-降序
         */
        public void doSort(Comparable<T>[] data, int start, int end, boolean asc) {
            if (end - start < 2) {
                return;
            }
            int middle = (end + start) >> 1;
            splitAndMerge(data, start, middle, asc);
            splitAndMerge(data, middle + 1, end, asc);
            merge(data, start, middle, end, asc);
        }
        
        /**
         * (两路)拆分、归并 数组
         *
         * @param originArray
         *         数组
         * @param left
         *         数组的起始元素索引
         * @param right
         *         数组的结尾元素索引
         * @param asc
         *         升序/降序。 true-升序; false-降序
         */
        public void splitAndMerge(Comparable<T>[] originArray, int left, int right, boolean asc) {
            // 中间那个数的索引
            int middle = (left + right) >> 1;
            /*
             * 当目标区域要只有一个元素时,不再进行拆分
             *
             * 已知originArray长度大于0, 这里简单数学证明: 当middle = right时,originArray长度为1
             * ∵ middle = (left + right) / 2 且 middle = right
             * ∴ right = (left + right) / 2
             * ∴ 2 * right = left + right
             * ∴ right = left
             * ∴ right = left
             * ∴ originArray长度为1
             */
            if (middle == right) {
                return;
            }
            // 二叉树【前序遍历】, 再次进行拆分
            splitAndMerge(originArray, left, middle, asc);
            splitAndMerge(originArray, middle + 1, right, asc);
            // 合并
            merge(originArray, left, middle, right, asc);
        }
        
        /**
         * 归并两个有序的数组
         *
         * @param originArray
         *         数组。 注:该数组由两个紧邻的 有序数组组成
         * @param left
         *         要归并的第一个数组的起始元素索引
         * @param middle
         *         要归并的第一个数组的结尾元素索引
         * @param right
         *         要归并的第二个数组的结尾元素索引 注:要合并的第二个数组的结尾元素索引为middle + 1
         * @param asc
         *         升序/降序。 true-升序; false-降序
         */
        @SuppressWarnings("unchecked")
        private void merge(Comparable<T>[] originArray, int left, int middle, int right, boolean asc) {
            Comparable<T>[] tmpArray = new Comparable[right - left + 1];
            int i = left, j = middle + 1, tmpIndex = 0;
            int result;
            // 循环比较, 直至其中一个数组所有元素 拷贝至 tmpArray
            while (i <= middle && j <= right) {
                result = originArray[i].compareTo((T) originArray[j]);
                // 控制升序降序
                boolean ascFlag = asc ? result <= 0 : result >= 0;
                if (ascFlag) {
                    tmpArray[tmpIndex] = originArray[i];
                    i++;
                } else {
                    tmpArray[tmpIndex] = originArray[j];
                    j++;
                }
                tmpIndex++;
            }
            // 将剩余那个没拷贝完的数组中剩余的元素 拷贝至 tmpArray
            while (i <= middle) {
                tmpArray[tmpIndex] = originArray[i];
                i++;
                tmpIndex++;
            }
            while (j <= right) {
                tmpArray[tmpIndex] = originArray[j];
                j++;
                tmpIndex++;
            }
            // 将临时数组中的元素按顺序拷贝至originArray
            System.arraycopy(tmpArray, 0, originArray, left, tmpArray.length);
        }
        
    }
    
  • 测试一下
    • 编写一个简单的测试类
      Fork/Join浅探与使用_第5张图片
    • 运行main方法,程序输出
      在这里插入图片描述
RecursiveTask使用示例
  • 封装一个RecursiveTask抽象模板
    import com.aspire.demo.author.JustryDeng;
    import org.springframework.util.Assert;
    
    import java.util.Collection;
    import java.util.List;
    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.ForkJoinTask;
    import java.util.concurrent.RecursiveTask;
    import java.util.stream.Collectors;
    
    /**
     * 定义抽象模板,使用RecursiveTask
     *
     * 
      *
    • P: 参数泛型
    • *
    • R: 结果泛型
    • *
    * * @author {@link JustryDeng} * @since 2020/7/30 19:28:12 */
    @SuppressWarnings("unused") public abstract class AbstractRecursiveTask<P, R> extends RecursiveTask<R> { /** if non-null, to use it */ protected final ForkJoinPool forkJoinPool; /** * 源数据 * * P.S. 本次分析的范围为 [lowerLimitIndex, upperLimitIndex) */ protected final P[] originDataArray; /** 当前RecursiveTask要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) */ protected final int lowerLimitIndex; /** 当前RecursiveTask要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) */ protected final int upperLimitIndex; /** 触发进行任务拆分的阈值 */ protected final int triggerForkSize; /** 默认的触发进行任务拆分的阈值 */ private static final int DEFAULT_TRIG_FORK_SIZE = 2; public AbstractRecursiveTask(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex) { this(originDataArray, lowerLimitIndex, upperLimitIndex, DEFAULT_TRIG_FORK_SIZE, null); } public AbstractRecursiveTask(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex, int triggerForkSize, ForkJoinPool forkJoinPool) { Assert.notNull(originDataArray, "originDataArray cannot be null"); Assert.isTrue(upperLimitIndex > lowerLimitIndex, "upperLimitIndex must great-than lowerLimitIndex, but curr upperLimitIndex is -> " + lowerLimitIndex + ", curr lowerLimitIndex is -> " + lowerLimitIndex); Assert.isTrue(triggerForkSize > 1, "triggerForkSize must great-than 1, but curr triggerForkSize is -> " + triggerForkSize); this.originDataArray = originDataArray; this.lowerLimitIndex = lowerLimitIndex; this.upperLimitIndex = upperLimitIndex; this.triggerForkSize = triggerForkSize; this.forkJoinPool = forkJoinPool; } @Override protected R compute() { // -> 如果不需要拆分, 那么直接计算 if (shouldComputeDirectly()) { return this.computeDirectly(originDataArray, lowerLimitIndex, upperLimitIndex); } // -> 如果需要任务拆分 // map (任务-拆) List<ForkJoinTask<? extends R>> tasks = this.mapTask(); Collection<ForkJoinTask<? extends R>> forkJoinTasks; if (forkJoinPool == null) { forkJoinTasks = invokeAll(tasks); } else { forkJoinTasks = tasks.stream().peek(forkJoinPool::submit).collect(Collectors.toList()); } List<R> resultList = forkJoinTasks.stream().map(ForkJoinTask::join).collect(Collectors.toList()); // reduce (结果-并) return this.reduceResult(resultList); } /** * 是否应该直接计算 * * @return 是否应该直接计算 */ protected boolean shouldComputeDirectly() { return upperLimitIndex - lowerLimitIndex <= triggerForkSize; } /** * 直接计算结果 * * @param originDataArray * 源数据 * @param lowerLimitIndex * 当前RecursiveTask要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) * @param upperLimitIndex * 当前RecursiveTask要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) * @return 计算结果 */ protected abstract R computeDirectly(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex); /** * 将当前大任务拆分为一个一个小任务 * * @return 拆分出来的小任务 */ protected abstract List<ForkJoinTask<? extends R>> mapTask(); /** * 将所有结果进行合并并返回 * * @param resultList * 要进行合并处理的结果集 * @return 所有任务结果合并后的总结果 */ protected abstract R reduceResult(List<R> resultList); }
  • 简单实现抽象模板
    import com.aspire.demo.author.JustryDeng;
    
    import java.util.*;
    import java.util.concurrent.ForkJoinTask;
    
    /**
     * 简单实现AbstractRecursiveTask
     *
     * @author {@link JustryDeng}
     * @since 2020/7/30 20:13:35
     */
    public class DemoRecursiveTask extends AbstractRecursiveTask<Integer, Integer[]> {
        
        public DemoRecursiveTask(Integer[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
            super(originDataArray, lowerLimitIndex, upperLimitIndex);
        }
    
        @Override
        protected Integer[] computeDirectly(Integer[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
            Set<Integer> tmpSet = new HashSet<>();
            Integer item;
            for (int i = lowerLimitIndex; i < upperLimitIndex; i++) {
                item = originDataArray[i];
                if (item == null) {
                    continue;
                }
                // 算闰年
                if (item % 4 == 0 && item % 100 != 0) {
                    tmpSet.add(item);
                } else if (item % 400 == 0) {
                    tmpSet.add(item);
                }
            }
            return tmpSet.toArray(new Integer[0]);
        }
        
        @Override
        protected List<ForkJoinTask<? extends Integer[]>> mapTask() {
            int middleIndex = (upperLimitIndex + lowerLimitIndex) / 2;
            DemoRecursiveTask taskOne = new DemoRecursiveTask(originDataArray, lowerLimitIndex, middleIndex);
            DemoRecursiveTask taskTwo = new DemoRecursiveTask(originDataArray, middleIndex, upperLimitIndex);
            List<ForkJoinTask<? extends Integer[]>> list = new ArrayList<>(2);
            list.add(taskOne);
            list.add(taskTwo);
            return list;
        }
        
        @Override
        protected Integer[] reduceResult(List<Integer[]> resultList) {
            Set<Integer> tmpSet = new HashSet<>();
            resultList.forEach(x -> tmpSet.addAll(Arrays.asList(x)));
            return tmpSet.toArray(new Integer[0]);
        }
        
    }
    
  • 测试一下
    Fork/Join浅探与使用_第6张图片
RecursiveAction使用示例
  • 封装一个RecursiveTask抽象模板

    import com.aspire.demo.author.JustryDeng;
    import org.springframework.util.Assert;
    
    import java.util.Collection;
    import java.util.List;
    import java.util.concurrent.ForkJoinPool;
    import java.util.concurrent.ForkJoinTask;
    import java.util.concurrent.RecursiveAction;
    import java.util.stream.Collectors;
    
    /**
     * 定义抽象模板,使用RecursiveAction
     *
     * 
      *
    • P: 参数泛型
    • *
    * * @author {@link JustryDeng} * @since 2020/7/30 19:28:12 */
    @SuppressWarnings("unused") public abstract class AbstractRecursiveAction<P> extends RecursiveAction { /** if non-null, to use it */ protected final ForkJoinPool forkJoinPool; /** * 源数据 * * P.S. 本次分析的范围为 [lowerLimitIndex, upperLimitIndex) */ protected final P[] originDataArray; /** 当前RecursiveAction要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) */ protected final int lowerLimitIndex; /** 当前RecursiveAction要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) */ protected final int upperLimitIndex; /** 触发进行任务拆分的阈值 */ protected final int triggerForkSize; /** 默认的触发进行任务拆分的阈值 */ private static final int DEFAULT_TRIG_FORK_SIZE = 2; public AbstractRecursiveAction(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex) { this(originDataArray, lowerLimitIndex, upperLimitIndex, DEFAULT_TRIG_FORK_SIZE, null); } public AbstractRecursiveAction(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex, int triggerForkSize, ForkJoinPool forkJoinPool) { Assert.notNull(originDataArray, "originDataArray cannot be null"); Assert.isTrue(upperLimitIndex > lowerLimitIndex, "upperLimitIndex must great-than lowerLimitIndex, but curr upperLimitIndex is -> " + lowerLimitIndex + ", curr lowerLimitIndex is -> " + lowerLimitIndex); Assert.isTrue(triggerForkSize > 1, "triggerForkSize must great-than 1, but curr triggerForkSize is -> " + triggerForkSize); this.originDataArray = originDataArray; this.lowerLimitIndex = lowerLimitIndex; this.upperLimitIndex = upperLimitIndex; this.triggerForkSize = triggerForkSize; this.forkJoinPool = forkJoinPool; } @Override protected void compute() { // -> 如果不需要拆分, 那么直接计算 if (shouldComputeDirectly()) { this.computeDirectly(originDataArray, lowerLimitIndex, upperLimitIndex); return; } // -> 如果需要任务拆分 // map (任务-拆) List<ForkJoinTask<Void>> tasks = this.mapTask(); Collection<ForkJoinTask<Void>> forkJoinTasks; if (forkJoinPool == null) { forkJoinTasks = invokeAll(tasks); } else { forkJoinTasks = tasks.stream().peek(forkJoinPool::submit).collect(Collectors.toList()); } forkJoinTasks.forEach(ForkJoinTask::join); } /** * 是否应该直接计算 * * @return 是否应该直接计算 */ protected boolean shouldComputeDirectly() { return upperLimitIndex - lowerLimitIndex <= triggerForkSize; } /** * 直接计算结果 * * @param originDataArray * 源数据 * @param lowerLimitIndex * 当前RecursiveAction要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) * @param upperLimitIndex * 当前RecursiveAction要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) */ protected abstract void computeDirectly(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex); /** * 将当前大任务拆分为一个一个小任务 * * @return 拆分出来的小任务 */ protected abstract List<ForkJoinTask<Void>> mapTask(); }
  • 简单实现抽象模板

    import com.aspire.demo.author.JustryDeng;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ForkJoinTask;
    
    /**
     * 简单实现AbstractRecursiveAction
     *
     * @author {@link JustryDeng}
     * @since 2020/7/31 12:31:44
     */
    public class DemoRecursiveAction extends AbstractRecursiveAction<Character> {
        
        public DemoRecursiveAction(Character[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
            super(originDataArray, lowerLimitIndex, upperLimitIndex);
        }
        
        @Override
        protected void computeDirectly(Character[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
            for (int i = lowerLimitIndex; i < upperLimitIndex; i++) {
                System.err.println(Thread.currentThread().getName() + "\t" + originDataArray[i]);
            }
        }
        
        @Override
        protected List<ForkJoinTask<Void>> mapTask() {
            int middleIndex = (upperLimitIndex + lowerLimitIndex) / 2;
            DemoRecursiveAction taskOne = new DemoRecursiveAction(originDataArray, lowerLimitIndex, middleIndex);
            DemoRecursiveAction taskTwo = new DemoRecursiveAction(originDataArray, middleIndex, upperLimitIndex);
            List<ForkJoinTask<Void>> list = new ArrayList<>(2);
            list.add(taskOne);
            list.add(taskTwo);
            return list;
        }
    }
    
  • 测试一下
    Fork/Join浅探与使用_第7张图片


^_^ 如有不当之处,欢迎指正

^_^ 参考链接
         https://www.jianshu.com/p/42e9cd16f705

^_^ 参考资料
        《精通Java并发编程(第二版)》
[西]哈维尔·费尔南德斯·冈萨雷斯 著,唐富年译

^_^ 本文已经被收录进《程序员成长笔记(三)》,笔者JustryDeng

你可能感兴趣的:(多线程与高并发)