package forkjoin;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
/**
 * Demonstrates {@link RecursiveAction}: sums the integers 0..10 (inclusive) by recursively
 * splitting the range and adding each leaf's partial sum to a shared counter.
 */
public class ForkJoinRecursiveAction {
    /** Largest {@code end - start} that is summed directly; a larger value means coarser tasks. */
    private final static int MAX_THRESHOLD = 1;
    /** Shared accumulator that every leaf task adds its partial sum to. */
    private final static AtomicInteger SUM = new AtomicInteger(0);

    /**
     * Runs the summation on a fresh pool and prints the accumulated total.
     *
     * @throws InterruptedException if interrupted while waiting for the computation
     * @throws ExecutionException if the root task fails
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        final ForkJoinPool pool = new ForkJoinPool();
        try {
            // get() returns only after the root task completes; since compute() waits for its
            // subtasks, every partial sum has been added to SUM by then.
            pool.submit(new CalculateRecursiveAction(0, 10)).get();
        } finally {
            pool.shutdown();
        }
        pool.awaitTermination(100, TimeUnit.MILLISECONDS);
        System.out.println(SUM.get());
    }

    /** Adds the sum of the closed range [start, end] to {@link #SUM}. */
    @SuppressWarnings({ "serial" })
    private static class CalculateRecursiveAction extends RecursiveAction {
        private final int start;
        private final int end;

        public CalculateRecursiveAction(int start, int end) {
            this.start = start;
            this.end = end;
        }

        @Override
        protected void compute() {
            if (end - start <= MAX_THRESHOLD) {
                SUM.addAndGet(IntStream.rangeClosed(start, end).sum());
                return;
            }
            // Split [start, end] into the disjoint halves [start, middle] and [middle + 1, end].
            int middle = start + (end - start) / 2;
            // invokeAll forks the subtasks and waits for both, so this task is not reported as
            // complete while its partial sums are still pending (fork() alone never waits).
            invokeAll(new CalculateRecursiveAction(start, middle),
                    new CalculateRecursiveAction(middle + 1, end));
        }
    }
}
package forkjoin;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
/**
 * Demonstrates {@link RecursiveTask}: sums a random array both sequentially and with a
 * fork/join task, and prints both totals so they can be compared.
 */
public class ForkJoinRecursiveTask {
    /**
     * Fills a 1000-element array with random values in [0, 100), then prints the sequential
     * total followed by the fork/join total.
     *
     * @throws InterruptedException if interrupted while waiting for the result
     * @throws ExecutionException if the summing task fails
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        int[] arr = new int[1000];
        Random random = new Random();
        int total = 0;
        // Fill every element and keep a sequential running total to check the parallel result.
        for (int i = 0; i < arr.length; i++) {
            arr[i] = random.nextInt(100);
            total += arr[i];
        }
        System.out.println("初始化时的总和=" + total);
        // The no-arg constructor uses Runtime.getRuntime().availableProcessors() as parallelism.
        ForkJoinPool forkJoinPool = new ForkJoinPool();
        try {
            // Submit the divisible summing task and wait for its result.
            Future<Integer> future = forkJoinPool.submit(new RecursiveTaskDemo(arr, 0, arr.length));
            System.out.println("计算出来的总和=" + future.get());
            // Integer integer = forkJoinPool.invoke(new RecursiveTaskDemo(arr, 0, arr.length));
            // System.out.println("计算出来的总和=" + integer);
        } finally {
            // Shut the pool down even if the task failed.
            forkJoinPool.shutdown();
        }
    }

    /** Sums {@code arr[start..end)} by splitting it until slices are smaller than {@link #MAX}. */
    private static class RecursiveTaskDemo extends RecursiveTask<Integer> {
        /**
         * Slices with fewer than this many elements are summed directly; a larger value means
         * coarser-grained tasks.
         */
        private static final int MAX = 70;
        private final int[] arr;
        private final int start;
        private final int end;

        public RecursiveTaskDemo(int[] arr, int start, int end) {
            this.arr = arr;
            this.start = start;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            // Small enough: sum the slice sequentially.
            if (end - start < MAX) {
                System.err.println("=====执行累计加法======");
                int sum = 0;
                for (int i = start; i < end; i++) {
                    sum += arr[i];
                }
                return sum;
            }
            System.err.println("=====任务分解======");
            // Split the slice into the adjacent halves [start, middle) and [middle, end).
            int middle = start + (end - start) / 2;
            RecursiveTaskDemo left = new RecursiveTaskDemo(arr, start, middle);
            RecursiveTaskDemo right = new RecursiveTaskDemo(arr, middle, end);
            // Fork one half and compute the other on this thread instead of idling until both
            // forked halves finish; then combine the two partial sums.
            left.fork();
            int rightSum = right.compute();
            return left.join() + rightSum;
        }
    }
}
ForkJoinTask 的第三个子类:CountedCompleter(前两个是上面演示的 RecursiveAction 和 RecursiveTask)。
请参考:https://blog.csdn.net/hudmhacker/article/details/106544897
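如果任务在执行过程中需要阻塞等待(例如多个子任务会重复计算同一个子问题,后到的任务必须等待先到的任务算完),则应通过 ForkJoinPool.managedBlock 配合 ForkJoinPool.ManagedBlocker 来阻塞,这样线程池可以在阻塞期间补偿额外的工作线程,保持并行度。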
package forkjoin;
import java.math.BigInteger;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
/**
 * Fibonacci numbers: each term is the sum of the two before it (0, 1, 1, 2, 3, 5, 8, 13, 21, ...).
 *
 * <p>Values are computed with the fast-doubling identities on fork/join tasks and memoized in a
 * shared cache. The first caller to reach an index claims it by storing the {@code RESERVED}
 * sentinel; later callers wait for the published value through a
 * {@link ForkJoinPool.ManagedBlocker}, so the pool may compensate for the blocked thread.
 */
public class Fibonacci {
    /** Computes F(1000) and prints its bit count and the elapsed time in milliseconds. */
    public static void main(String[] args) {
        long time = System.currentTimeMillis();
        Fibonacci fib = new Fibonacci();
        int result = fib.f(1_000).bitCount();
        time = System.currentTimeMillis() - time;
        System.out.println("result = " + result);
        System.out.println("test1_000() time = " + time);
    }

    /**
     * Returns the {@code n}-th Fibonacci number, where F(0) = 0 and F(1) = 1.
     *
     * @param n index of the term, must be non-negative
     * @return F(n)
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public BigInteger f(int n) {
        Map<Integer, BigInteger> cache = new ConcurrentHashMap<>();
        cache.put(0, BigInteger.ZERO);
        cache.put(1, BigInteger.ONE);
        return f(n, cache);
    }

    /**
     * Sentinel stored in the cache while a value is being computed (compared by identity), and
     * the monitor used to signal that a computed value has been published.
     */
    private final BigInteger RESERVED = BigInteger.valueOf(-1000);

    /**
     * Returns F(n), memoizing intermediate values in {@code cache}.
     *
     * @param n index of the term, must be non-negative
     * @param cache thread-safe memo shared by all tasks of one computation; it must only ever
     *     contain correct Fibonacci values or {@code RESERVED}
     * @return F(n)
     * @throws IllegalArgumentException if {@code n} is negative
     * @throws CancellationException if interrupted while waiting for another task's value
     */
    public BigInteger f(int n, Map<Integer, BigInteger> cache) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (n < 2) {
            // Base cases are answered directly so that an unseeded cache cannot make the
            // recursion wait on its own reservation (which would deadlock).
            return BigInteger.valueOf(n);
        }
        BigInteger result = cache.putIfAbsent(n, RESERVED);
        if (result == null) {
            // This caller won the reservation and must compute and publish F(n).
            return computeAndPublish(n, cache);
        }
        if (result == RESERVED) {
            // Another task is computing F(n); wait for it instead of recomputing.
            return awaitReserved(n, cache);
        }
        return result;
        // return f(n - 1).add(f(n - 2));
    }

    /** Computes F(n) for a reservation held by this caller, publishes it and wakes waiters. */
    private BigInteger computeAndPublish(int n, Map<Integer, BigInteger> cache) {
        boolean published = false;
        try {
            // With k = (n + 1) / 2: F(2k - 1) = F(k - 1)^2 + F(k)^2 and
            // F(2k) = (2 F(k - 1) + F(k)) F(k).
            int half = (n + 1) / 2;
            RecursiveTask<BigInteger> f0Task = new RecursiveTask<BigInteger>() {
                @Override
                protected BigInteger compute() {
                    return f(half - 1, cache);
                }
            };
            f0Task.fork();
            BigInteger f1 = f(half, cache);
            BigInteger f0 = f0Task.join();
            BigInteger result;
            long time = n > 10_000 ? System.currentTimeMillis() : 0;
            try {
                if (n % 2 == 1) {
                    result = f0.multiply(f0).add(f1.multiply(f1));
                } else {
                    result = f0.shiftLeft(1).add(f1).multiply(f1);
                }
                // Publish under the monitor so a waiter cannot miss the notification between
                // its isReleasable() check and its wait().
                synchronized (RESERVED) {
                    cache.put(n, result);
                    RESERVED.notifyAll();
                }
                published = true;
            } finally {
                time = n > 10_000 ? System.currentTimeMillis() - time : 0;
                if (time > 50) {
                    System.out.printf("f(%d) took %d%n", n, time);
                }
            }
            return result;
        } finally {
            if (!published) {
                // The computation failed: release the reservation so waiters stop blocking
                // forever and retry on their own.
                synchronized (RESERVED) {
                    cache.remove(n, RESERVED);
                    RESERVED.notifyAll();
                }
            }
        }
    }

    /** Blocks (pool-aware) until another task publishes F(n), then returns it. */
    private BigInteger awaitReserved(int n, Map<Integer, BigInteger> cache) {
        ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache);
        try {
            ForkJoinPool.managedBlock(blocker);
        } catch (InterruptedException e) {
            // Preserve the interrupt status and the cause for callers further up.
            Thread.currentThread().interrupt();
            CancellationException cancelled = new CancellationException("interrupted");
            cancelled.initCause(e);
            throw cancelled;
        }
        BigInteger result = blocker.result;
        // null means the owner removed its reservation after failing; retry so this caller
        // either computes the value itself or surfaces the failure.
        return result != null ? result : f(n, cache);
    }

    /** ManagedBlocker that waits on {@code RESERVED} until the cache entry for n is no longer reserved. */
    private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker {
        /** Last value read from the cache by {@link #isReleasable()}. */
        private BigInteger result;
        private final int n;
        private final Map<Integer, BigInteger> cache;

        public ReservedFibonacciBlocker(int n, Map<Integer, BigInteger> cache) {
            this.n = n;
            this.cache = cache;
        }

        @Override
        public boolean block() throws InterruptedException {
            synchronized (RESERVED) {
                while (!isReleasable()) {
                    RESERVED.wait();
                }
            }
            return true;
        }

        @Override
        public boolean isReleasable() {
            return (result = cache.get(n)) != RESERVED;
        }
    }
}