package test.forkjoin;
import java.awt.SystemColor;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
/**
 * Fork/join task that sums every integer in the inclusive range {@code [start, end]}.
 *
 * <p>Ranges wider than {@link #THRESHOLD} are split in half recursively; narrower ranges are
 * summed directly with {@link #compute(long, long)}. An empty range ({@code start > end})
 * sums to 0. The sum wraps silently on {@code long} overflow.
 */
public class CountTask extends RecursiveTask<Long> {

    private static final long serialVersionUID = 1L;

    /** Ranges whose width ({@code end - start}) is at most this are summed sequentially. */
    private static final long THRESHOLD = 100_000L;

    /** First value of the range (inclusive). */
    private final long start;
    /** Last value of the range (inclusive). */
    private final long end;

    /**
     * Creates a task that sums the inclusive range {@code [start, end]}.
     *
     * @param start first value to add (inclusive)
     * @param end last value to add (inclusive)
     */
    public CountTask(long start, long end) {
        // TODO: validate the arguments
        this.start = start;
        this.end = end;
    }

    /**
     * Sums this task's range, splitting it into two subtasks when it is wider than the
     * threshold.
     *
     * @return the (possibly wrapped) sum of {@code [start, end]}
     */
    @Override
    protected Long compute() {
        if (end - start <= THRESHOLD) {
            // Small enough: sum directly.
            return compute(start, end);
        }
        // start + (end - start) / 2 avoids overflow of (start + end).
        long mid = start + (end - start) / 2;
        CountTask leftTask = new CountTask(start, mid);
        CountTask rightTask = new CountTask(mid + 1, end);
        // Fork one half and compute the other in this thread, instead of forking both and
        // leaving this worker idle while it waits for the subtasks to finish.
        leftTask.fork();
        long rightRes = rightTask.compute();
        long leftRes = leftTask.join();
        return leftRes + rightRes;
    }

    /**
     * Demo: sums a large range with fork/join, then with a single thread, printing both
     * results and the elapsed milliseconds of each.
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        long start = 0;
        // NOTE: the exact sum of [0, 40_000_000_000] (~8e20) exceeds Long.MAX_VALUE, so both
        // printed results wrap around; they still agree because long addition wraps
        // consistently regardless of grouping.
        long end = 40_000_000_000L;
        long timestamp = System.currentTimeMillis();
        ForkJoinPool fjp = new ForkJoinPool();
        try {
            Future<Long> result = fjp.submit(new CountTask(start, end));
            System.out.println(result.get());
        } finally {
            fjp.shutdown();
        }
        long timestampend = System.currentTimeMillis();
        System.out.println("forkjoin used " + (timestampend - timestamp));
        System.out.println(compute(start, end));
        System.out.println("single thread use " + (System.currentTimeMillis() - timestampend));
    }

    /**
     * Sequentially sums every integer in the inclusive range {@code [start, end]}.
     *
     * @param start first value to add (inclusive)
     * @param end last value to add (inclusive)
     * @return the (possibly wrapped) sum, or 0 when {@code start > end}
     */
    public static long compute(long start, long end) {
        long sum = 0;
        for (long i = start; i <= end; i++) {
            sum += i;
        }
        return sum;
    }
}
package test.concurrent;
import java.util.concurrent.RecursiveTask;
import java.util.function.DoublePredicate;
public class ForkJoinTest {
public static void main(String[] args) {
double[] v = new double[10000000];
for (int i = 0; i < 10000000; i++) {
v[i] = i;
}
DoublePredicate filter = (double a) -> {
return a % 7 == 0;
};
long start = System.currentTimeMillis();
Counter counter = new Counter(v, 0, 9999999, filter);
Integer res = counter.compute();
long end = System.currentTimeMillis();
System.out.println(res);
System.out.println("used time: " + (end - start));
start = System.currentTimeMillis();
int count = 0;
for (int i = 0; i < v.length; i++) {
if (filter.test(v[i]))
count++;
}
end = System.currentTimeMillis();
System.out.println(count);
System.out.println("used time: " + (end - start));
}
}
class Counter extends RecursiveTask {
/**
*
*/
private static final long serialVersionUID = 1L;
private static final int THREADHOLD = 100000;
private double[] values;
private int from;
private int to;
private DoublePredicate filter;
public Counter(double[] values, int from, int to, DoublePredicate filter) {
this.values = values;
this.from = from;
this.to = to;
this.filter = filter;
}
@Override
protected Integer compute() {
if (to - from < THREADHOLD) {
int count = 0;
for (int i = from; i < to; i++) {
if (filter.test(values[i]))
count++;
}
return count;
} else {
int mid = (from + to) / 2;
Counter first = new Counter(values, from, mid, filter);
Counter second = new Counter(values, mid, to, filter);
invokeAll(first, second);
return first.join() + second.join();
}
}
}