声明: 本文部分文字介绍,直接摘录自《精通Java并发编程(第二版)》
, 该书写得通俗易懂、且分析相对透彻,推荐阅读,具体信息见文末。
声明: 本文不会介绍具体的方法调用API,但是给出CountedCompleter
、RecursiveTask
、RecursiveAction
的简单使用示例。
序言: JDK8开始,提供/优化了很多非常好用的并发组件,如parallelStream
、CompletableFuture
、ForkJoin
等,本文初步学习ForkJoin。
简介
Java7并发API引入了Fork/Join框架。该框架基于(Executor的实现类)ForkJoinPool,除了具备基础的Executor功能外,ForkJoinPool主要由fork()方法、join()方法(以及它们的不同变体),以及一个被称作工作窃取算法的内部算法组成。
Fork/Join框架的基本特征
Fork/Join框架主要用于解决基于分治方法的问题。将原始问题拆分为较小的问题,直到问题很小,可以直接解决。即:拆分大问题为小问题,解决小问题并得到一系列结果,归并这些结果得到大问题的结果。
Fork/Join框架还有一个非常重要的特性——工作窃取算法。当一个任务使用join()方法等待某个子任务结束时,执行该任务的线程将会从任务池中选取另一个正等待被执行的任务去执行。Java8开始,为Fork/Join框架提供了一个默认的执行器ForkJoinPool,可通过ForkJoinPool.commonPool()来获得。
Fork/Join框架的局限性
ForkJoinPool类: 该类实现了Executor接口和ExecutorService接口,而执行Fork/Join任务时将用到Executor接口。Java8开始,提供了一个默认的ForkJoinPool对象作为公用池,但是如果需要,你还可以创建一些构造函数。你可以指定并行处理的任务的最大线程数目。默认情况下,它将使用可用处理器的数目作为最大并发线程数。
RecursiveTask类: 这是一个抽象类,其继承了ForkJoinTask. 这是一个有返回值的task类。RecursiceTask类提供有抽象的compute方法,实际的计算任务逻辑,应该在子类的compute实现方法中完成。
RecursiceAction类: 这是一个抽象类,其继承了ForkJoinTask. 这是一个无返回值的task类。RecursiceAction类提供有抽象的compute方法,实际的计算任务逻辑,应该在子类的compute实现方法中完成。
CountedCompleter类: 这是一个抽象类,其继承了ForkJoinTask. 这个类除了有与RecursiceAction类类似的功能外,还主要用于作为触发器,当当前任务的所有子任务全部都已经完成后,会触发当前任务的onComplete()方法,完成当前任务。
java.util.concurrent.CountedCompleter#tryComplete
。java.util.concurrent.CountedCompleter#tryComplete
。import com.aspire.demo.author.JustryDeng;
import java.util.concurrent.CountedCompleter;
/**
* Fork/Join之CountedCompleter实现 多线程归并排序
*
* P.S. 好吧,我写的归并算法的实现, 没有把归并算法的最佳性能发挥出来。。。。。。
* 简单测试发现: 当 数据量处于(0, 1万]时, Collections.sort性能优于MergeSortCompleter
* 当 数据量处于(1万, 100万]时, MergeSortCompleter性能优于Collections.sort
* 当 数据量处于(100万, 2000万]时, Collections.sort性能优于MergeSortCompleter
* 。。。
*
* @author {@link JustryDeng}
* @since 2020/7/9 16:22:33
*/
@SuppressWarnings("unused")
public class MergeSortCompleter<T extends Comparable<T>> extends CountedCompleter<Void> {
private final Comparable<T>[] data;
private int startIndex, middleIndex, endIndex;
private final boolean asc;
/**
* 进行fork的数组长度阈值
*/
private final int FORK_THRESHOLD;
/**
* 默认的进行fork的数组长度阈值
*/
private static final int DEFAULT_FORK_THRESHOLD = 200;
/**
* @see this#MergeSortCompleter(MergeSortCompleter, Comparable[], int, int, int, boolean)
*/
public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data, int startIndex, int endIndex) {
this(parent, data, startIndex, endIndex, DEFAULT_FORK_THRESHOLD, true);
}
/**
* @see this#MergeSortCompleter(MergeSortCompleter, Comparable[], int, int, int, boolean)
*/
public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data, int startIndex, int endIndex, boolean asc) {
this(parent, data, startIndex, endIndex, DEFAULT_FORK_THRESHOLD, asc);
}
/**
* 构造器
*
* @param parent
* 父任务
* @param data
* 数据容器
* @param startIndex
* 要被排序的数据的起始索引
* @param endIndex
* 要被排序的数据的结尾引
* @param forkThreshold
* 进行fork的数组长度阈值
* @param asc
* true-升序; false-降序
*/
public MergeSortCompleter(MergeSortCompleter parent, Comparable<T>[] data,
int startIndex, int endIndex, int forkThreshold, boolean asc) {
super(parent);
this.data = data;
this.startIndex = startIndex;
this.endIndex = endIndex;
this.asc = asc;
FORK_THRESHOLD = forkThreshold;
}
@Override
public void compute() {
// 如果长度>=指定的阈值, 那么fork
if (endIndex - startIndex >= FORK_THRESHOLD - 1) {
middleIndex = (endIndex + startIndex) >> 1;
MergeSortCompleter<T> task1 = new MergeSortCompleter<>(this, data, startIndex, middleIndex, asc);
MergeSortCompleter<T> task2 = new MergeSortCompleter<>(this, data, middleIndex + 1, endIndex, asc);
// 对pending进行add操作,必须在fork之前
this.addToPendingCount(1);
task1.fork();
task2.fork();
// 任务粒度已经足够笑了, 不再fork, 直接进行逻辑处理
} else {
// 执行排序
doSort(data, startIndex, endIndex, asc);
// 主要逻辑处理完后,调用tryComplete, 使执行onCompletion如果需要的话
tryComplete();
}
}
/**
* 触发onCompletion逻辑
*
* @param caller
* 触发调用onCompletion方法的对象
*/
@Override
public void onCompletion(CountedCompleter<?> caller) {
// middle == 0 说明没有fork过
if (middleIndex == 0) {
return;
}
merge(data, startIndex, middleIndex, endIndex, asc);
}
/// ********************************************** 下面的是归并排序实现
/**
* 归并排序
*
* @param data
* 数据容器
* @param start
* 要被排序的数据的起始索引
* @param end
* 要被排序的数据的结尾引
* @param asc
* true-升序; false-降序
*/
public void doSort(Comparable<T>[] data, int start, int end, boolean asc) {
if (end - start < 2) {
return;
}
int middle = (end + start) >> 1;
splitAndMerge(data, start, middle, asc);
splitAndMerge(data, middle + 1, end, asc);
merge(data, start, middle, end, asc);
}
/**
* (两路)拆分、归并 数组
*
* @param originArray
* 数组
* @param left
* 数组的起始元素索引
* @param right
* 数组的结尾元素索引
* @param asc
* 升序/降序。 true-升序; false-降序
*/
public void splitAndMerge(Comparable<T>[] originArray, int left, int right, boolean asc) {
// 中间那个数的索引
int middle = (left + right) >> 1;
/*
* 当目标区域要只有一个元素时,不再进行拆分
*
* 已知originArray长度大于0, 这里简单数学证明: 当middle = right时,originArray长度为1
* ∵ middle = (left + right) / 2 且 middle = right
* ∴ right = (left + right) / 2
* ∴ 2 * right = left + right
* ∴ right = left
* ∴ right = left
* ∴ originArray长度为1
*/
if (middle == right) {
return;
}
// 二叉树【前序遍历】, 再次进行拆分
splitAndMerge(originArray, left, middle, asc);
splitAndMerge(originArray, middle + 1, right, asc);
// 合并
merge(originArray, left, middle, right, asc);
}
/**
* 归并两个有序的数组
*
* @param originArray
* 数组。 注:该数组由两个紧邻的 有序数组组成
* @param left
* 要归并的第一个数组的起始元素索引
* @param middle
* 要归并的第一个数组的结尾元素索引
* @param right
* 要归并的第二个数组的结尾元素索引 注:要合并的第二个数组的结尾元素索引为middle + 1
* @param asc
* 升序/降序。 true-升序; false-降序
*/
@SuppressWarnings("unchecked")
private void merge(Comparable<T>[] originArray, int left, int middle, int right, boolean asc) {
Comparable<T>[] tmpArray = new Comparable[right - left + 1];
int i = left, j = middle + 1, tmpIndex = 0;
int result;
// 循环比较, 直至其中一个数组所有元素 拷贝至 tmpArray
while (i <= middle && j <= right) {
result = originArray[i].compareTo((T) originArray[j]);
// 控制升序降序
boolean ascFlag = asc ? result <= 0 : result >= 0;
if (ascFlag) {
tmpArray[tmpIndex] = originArray[i];
i++;
} else {
tmpArray[tmpIndex] = originArray[j];
j++;
}
tmpIndex++;
}
// 将剩余那个没拷贝完的数组中剩余的元素 拷贝至 tmpArray
while (i <= middle) {
tmpArray[tmpIndex] = originArray[i];
i++;
tmpIndex++;
}
while (j <= right) {
tmpArray[tmpIndex] = originArray[j];
j++;
tmpIndex++;
}
// 将临时数组中的元素按顺序拷贝至originArray
System.arraycopy(tmpArray, 0, originArray, left, tmpArray.length);
}
}
import com.aspire.demo.author.JustryDeng;
import org.springframework.util.Assert;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
/**
* 定义抽象模板,使用RecursiveTask>
*
*
* - P: 参数泛型
* - R: 结果泛型
*
*
* @author {@link JustryDeng}
* @since 2020/7/30 19:28:12
*/
@SuppressWarnings("unused")
public abstract class AbstractRecursiveTask<P, R> extends RecursiveTask<R> {
/** if non-null, to use it */
protected final ForkJoinPool forkJoinPool;
/**
* 源数据
*
* P.S. 本次分析的范围为 [lowerLimitIndex, upperLimitIndex)
*/
protected final P[] originDataArray;
/** 当前RecursiveTask要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) */
protected final int lowerLimitIndex;
/** 当前RecursiveTask要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) */
protected final int upperLimitIndex;
/** 触发进行任务拆分的阈值 */
protected final int triggerForkSize;
/** 默认的触发进行任务拆分的阈值 */
private static final int DEFAULT_TRIG_FORK_SIZE = 2;
public AbstractRecursiveTask(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
this(originDataArray, lowerLimitIndex, upperLimitIndex, DEFAULT_TRIG_FORK_SIZE, null);
}
public AbstractRecursiveTask(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex, int triggerForkSize,
ForkJoinPool forkJoinPool) {
Assert.notNull(originDataArray, "originDataArray cannot be null");
Assert.isTrue(upperLimitIndex > lowerLimitIndex, "upperLimitIndex must great-than lowerLimitIndex, but curr upperLimitIndex is -> "
+ lowerLimitIndex + ", curr lowerLimitIndex is -> " + lowerLimitIndex);
Assert.isTrue(triggerForkSize > 1, "triggerForkSize must great-than 1, but curr triggerForkSize is -> " + triggerForkSize);
this.originDataArray = originDataArray;
this.lowerLimitIndex = lowerLimitIndex;
this.upperLimitIndex = upperLimitIndex;
this.triggerForkSize = triggerForkSize;
this.forkJoinPool = forkJoinPool;
}
@Override
protected R compute() {
// -> 如果不需要拆分, 那么直接计算
if (shouldComputeDirectly()) {
return this.computeDirectly(originDataArray, lowerLimitIndex, upperLimitIndex);
}
// -> 如果需要任务拆分
// map (任务-拆)
List<ForkJoinTask<? extends R>> tasks = this.mapTask();
Collection<ForkJoinTask<? extends R>> forkJoinTasks;
if (forkJoinPool == null) {
forkJoinTasks = invokeAll(tasks);
} else {
forkJoinTasks = tasks.stream().peek(forkJoinPool::submit).collect(Collectors.toList());
}
List<R> resultList = forkJoinTasks.stream().map(ForkJoinTask::join).collect(Collectors.toList());
// reduce (结果-并)
return this.reduceResult(resultList);
}
/**
* 是否应该直接计算
*
* @return 是否应该直接计算
*/
protected boolean shouldComputeDirectly() {
return upperLimitIndex - lowerLimitIndex <= triggerForkSize;
}
/**
* 直接计算结果
*
* @param originDataArray
* 源数据
* @param lowerLimitIndex
* 当前RecursiveTask要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素)
* @param upperLimitIndex
* 当前RecursiveTask要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素)
* @return 计算结果
*/
protected abstract R computeDirectly(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex);
/**
* 将当前大任务拆分为一个一个小任务
*
* @return 拆分出来的小任务
*/
protected abstract List<ForkJoinTask<? extends R>> mapTask();
/**
* 将所有结果进行合并并返回
*
* @param resultList
* 要进行合并处理的结果集
* @return 所有任务结果合并后的总结果
*/
protected abstract R reduceResult(List<R> resultList);
}
import com.aspire.demo.author.JustryDeng;
import java.util.*;
import java.util.concurrent.ForkJoinTask;
/**
* 简单实现AbstractRecursiveTask
*
* @author {@link JustryDeng}
* @since 2020/7/30 20:13:35
*/
public class DemoRecursiveTask extends AbstractRecursiveTask<Integer, Integer[]> {
public DemoRecursiveTask(Integer[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
super(originDataArray, lowerLimitIndex, upperLimitIndex);
}
@Override
protected Integer[] computeDirectly(Integer[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
Set<Integer> tmpSet = new HashSet<>();
Integer item;
for (int i = lowerLimitIndex; i < upperLimitIndex; i++) {
item = originDataArray[i];
if (item == null) {
continue;
}
// 算闰年
if (item % 4 == 0 && item % 100 != 0) {
tmpSet.add(item);
} else if (item % 400 == 0) {
tmpSet.add(item);
}
}
return tmpSet.toArray(new Integer[0]);
}
@Override
protected List<ForkJoinTask<? extends Integer[]>> mapTask() {
int middleIndex = (upperLimitIndex + lowerLimitIndex) / 2;
DemoRecursiveTask taskOne = new DemoRecursiveTask(originDataArray, lowerLimitIndex, middleIndex);
DemoRecursiveTask taskTwo = new DemoRecursiveTask(originDataArray, middleIndex, upperLimitIndex);
List<ForkJoinTask<? extends Integer[]>> list = new ArrayList<>(2);
list.add(taskOne);
list.add(taskTwo);
return list;
}
@Override
protected Integer[] reduceResult(List<Integer[]> resultList) {
Set<Integer> tmpSet = new HashSet<>();
resultList.forEach(x -> tmpSet.addAll(Arrays.asList(x)));
return tmpSet.toArray(new Integer[0]);
}
}
封装一个RecursiveTask抽象模板
import com.aspire.demo.author.JustryDeng;
import org.springframework.util.Assert;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.stream.Collectors;
/**
* 定义抽象模板,使用RecursiveAction
*
*
* - P: 参数泛型
*
*
* @author {@link JustryDeng}
* @since 2020/7/30 19:28:12
*/
@SuppressWarnings("unused")
public abstract class AbstractRecursiveAction<P> extends RecursiveAction {
/** if non-null, to use it */
protected final ForkJoinPool forkJoinPool;
/**
* 源数据
*
* P.S. 本次分析的范围为 [lowerLimitIndex, upperLimitIndex)
*/
protected final P[] originDataArray;
/** 当前RecursiveAction要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素) */
protected final int lowerLimitIndex;
/** 当前RecursiveAction要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素) */
protected final int upperLimitIndex;
/** 触发进行任务拆分的阈值 */
protected final int triggerForkSize;
/** 默认的触发进行任务拆分的阈值 */
private static final int DEFAULT_TRIG_FORK_SIZE = 2;
public AbstractRecursiveAction(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
this(originDataArray, lowerLimitIndex, upperLimitIndex, DEFAULT_TRIG_FORK_SIZE, null);
}
public AbstractRecursiveAction(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex, int triggerForkSize,
ForkJoinPool forkJoinPool) {
Assert.notNull(originDataArray, "originDataArray cannot be null");
Assert.isTrue(upperLimitIndex > lowerLimitIndex, "upperLimitIndex must great-than lowerLimitIndex, but curr upperLimitIndex is -> "
+ lowerLimitIndex + ", curr lowerLimitIndex is -> " + lowerLimitIndex);
Assert.isTrue(triggerForkSize > 1, "triggerForkSize must great-than 1, but curr triggerForkSize is -> " + triggerForkSize);
this.originDataArray = originDataArray;
this.lowerLimitIndex = lowerLimitIndex;
this.upperLimitIndex = upperLimitIndex;
this.triggerForkSize = triggerForkSize;
this.forkJoinPool = forkJoinPool;
}
@Override
protected void compute() {
// -> 如果不需要拆分, 那么直接计算
if (shouldComputeDirectly()) {
this.computeDirectly(originDataArray, lowerLimitIndex, upperLimitIndex);
return;
}
// -> 如果需要任务拆分
// map (任务-拆)
List<ForkJoinTask<Void>> tasks = this.mapTask();
Collection<ForkJoinTask<Void>> forkJoinTasks;
if (forkJoinPool == null) {
forkJoinTasks = invokeAll(tasks);
} else {
forkJoinTasks = tasks.stream().peek(forkJoinPool::submit).collect(Collectors.toList());
}
forkJoinTasks.forEach(ForkJoinTask::join);
}
/**
* 是否应该直接计算
*
* @return 是否应该直接计算
*/
protected boolean shouldComputeDirectly() {
return upperLimitIndex - lowerLimitIndex <= triggerForkSize;
}
/**
* 直接计算结果
*
* @param originDataArray
* 源数据
* @param lowerLimitIndex
* 当前RecursiveAction要分析的数据范围的下限位置,(含lowerLimitIndex对应的元素)
* @param upperLimitIndex
* 当前RecursiveAction要分析的数据范围的上限位置,(不含lowerLimitIndex对应的元素)
*/
protected abstract void computeDirectly(P[] originDataArray, int lowerLimitIndex, int upperLimitIndex);
/**
* 将当前大任务拆分为一个一个小任务
*
* @return 拆分出来的小任务
*/
protected abstract List<ForkJoinTask<Void>> mapTask();
}
简单实现抽象模板
import com.aspire.demo.author.JustryDeng;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ForkJoinTask;
/**
* 简单实现AbstractRecursiveAction
*
* @author {@link JustryDeng}
* @since 2020/7/31 12:31:44
*/
public class DemoRecursiveAction extends AbstractRecursiveAction<Character> {
public DemoRecursiveAction(Character[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
super(originDataArray, lowerLimitIndex, upperLimitIndex);
}
@Override
protected void computeDirectly(Character[] originDataArray, int lowerLimitIndex, int upperLimitIndex) {
for (int i = lowerLimitIndex; i < upperLimitIndex; i++) {
System.err.println(Thread.currentThread().getName() + "\t" + originDataArray[i]);
}
}
@Override
protected List<ForkJoinTask<Void>> mapTask() {
int middleIndex = (upperLimitIndex + lowerLimitIndex) / 2;
DemoRecursiveAction taskOne = new DemoRecursiveAction(originDataArray, lowerLimitIndex, middleIndex);
DemoRecursiveAction taskTwo = new DemoRecursiveAction(originDataArray, middleIndex, upperLimitIndex);
List<ForkJoinTask<Void>> list = new ArrayList<>(2);
list.add(taskOne);
list.add(taskTwo);
return list;
}
}
^_^ 如有不当之处,欢迎指正
^_^ 参考链接
https://www.jianshu.com/p/42e9cd16f705
^_^ 参考资料
《精通Java并发编程(第二版)》[西]哈维尔·费尔南德斯·冈萨雷斯 著,唐富年译
^_^ 本文已经被收录进《程序员成长笔记(三)》,笔者JustryDeng