java并发编程学习4--forkJoin

【概念

    分支和并框架的目的是以递归的方式将可以并行的任务拆分成更小的任务,然后将每个子任务的结果合并起来生成整体的结果,它是ExecutorService的一个实现,
    它把子任务分配给线程池(ForkJoinPool)中的工作线程。
    某些应用可能对每个处理器内核饭别试用一个线程,来完成计算密集任务,例如图像处理。java7引入forkjoin框架,专门用来支持这一类应用。
    假设有一个处理任务,它可以很自然的分解成子任务。

【使用方式

    要把任务提交到线程池,必须创建RecursiveTask的一个子类,其中R是并行化任务产生的结果(如果没有结果使用RecursiveAction类型)。
    然后在子类中实现product abstract R compute()方法即可。这个方法同时实现了“拆分子任务”,“任务不可拆时”的处理逻辑。如下所示:

    if(任务足够小){
        顺序计算该任务的值;
    }else{
        将任务分成两个子任务;
        递归调用本方法;
        合并每个子任务的结果;
    }

【最佳实践

    1.对一个任务调用join()方法会阻塞调用方,直到该任务结束。因此,有必要在两个子任务的计算都开始之后再调用join()。否则,你得到的版本会比原始
    的顺序执行更加缓慢,因为每个子任务都需要等到另一个子任务完成后才能开始计算,中途还要加上开启线程的开销。

    2.不应该在RecursiveTask的内部使用ForkJoinPool.invoke(),相反你应该始终调用compute()或者fork(),只有顺序代码才使用ForkJoinPool.invoke()
    来启动并行运算。

    3.对子任务调用fork()可以把他排进ForkJoinPool。同时对左边和右边的子任务调用它似乎很自然,但是这样做的效率比直接对其中一个调用compute()低。
    对一个子任务调用compute()的话,可以使一个子任务重用当前线程,避免线程池中多分配一个任务带来的开销。

    4.不应该认为多核系统中,分支合并就比顺序计算要快。一个任务可以分解成多个独立的子任务,才能让性能在并行化时有所提升。所有的子任务运行时间
    应该比分出新任务花费的时间要长。通常我们把输入输出放在一个方法中,计算在另一个方法中。

【工作窃取

        我们很难确定要满足什么条件才可以不再拆分任务。但是分出大量的小任务是一个好的选择,因为在理想情况下,划分并行任务时因该让每个任务都花费相同的时间。
        让cpu的所有内核都一样的繁忙,但是现实中我们的子任务花费的时间大不相同,这是因为有许多我们无法确认的情况出现:io,rpc,分配效率等等。

        分支合并框架使用:工作窃取来解决内核之间任务不匹配的问题。让所有任务大致相同的被平均分配到forkjoinpool的每个线程上。

        每个线程都拥有一个双向链式队列来保存它的任务,每完成一个任务就从队列头部取出下一个任务执行。当一个线程的任务队列已空,而其他线程还在繁忙,
        这个空闲线程就随机选择一个繁忙线程并从其队列尾部拿走一个任务开始执行。直到所有的任务执行完毕。
 
  
【例子
 
  
1.输出数组中有多少个数字小于0.5
 
  
 
  
public class ExerciseFilter {

    //数据源
    static double numbers[] = new double[100];
    static {
        for(int i = 0 ; i < 100 ; i++){
            numbers[i] = i + 1;
        }
        numbers[0] = 0.08;
        numbers[1] = 0.18;
        numbers[11] = 0.18;
    }

    public static void main(String[] args) {
        Counter counter = new Counter(numbers,x -> x < 0.5);
        //使用单例
        ForkJoinPool pool = ForkJoinPool.commonPool();
        long st = System.currentTimeMillis();
        //启动并行任务
        pool.invoke(counter);
        System.out.println((System.currentTimeMillis() - st) + " : " + counter.join());
    }
}

class Counter extends RecursiveTask{
    //分界线,当一个数组的长度 < 1000 就不再继续拆分
    public static final int THRESHOLD = 1000;
    //数组
    private double[] values;
    //判断条件
    private DoublePredicate filter;

    public Counter(double [] values,DoublePredicate filter){
        this.values = values;
        this.filter = filter;
    }

    @Override
    protected Integer compute() {
        //任务足够小就不再拆分
        if(values.length < THRESHOLD ){
            //返回该数组中有多少数字满足判断逻辑
            int count = 0;
            for(int i = 0; i < values.length ; i++){
                if(filter.test(values[i])){
                    count++;
                }
            }
            return count;
        }else {
            //将打数组拆分成两个
            int mid = values.length / 2;
            Counter first = new Counter(Arrays.copyOfRange(values,0,mid),filter);
            //第一个子任务提交到线程池
            first.fork();
            Counter second = new Counter(Arrays.copyOfRange(values,mid,values.length),filter);
            //当前线程执行第二个子任务,节约一个线程的开销
            int secondResult = second.compute();
            //等待第一个子任务执行完毕
            int firstResult = first.join();
            return firstResult + secondResult;
        }
    }
}

2.列表中求和
 
  
 
  
public class ExerciseSum {
    //数据源
    static int sum[] = new int[100];
    static {
        for(int i = 0 ; i < 100 ; i++){
            sum[i] = i + 1;
        }
    }

    public static void main(String[] args) {
        CounterSum counter = new CounterSum(sum);
        ForkJoinPool pool = ForkJoinPool.commonPool();
        long st = System.currentTimeMillis();
        pool.invoke(counter);
        System.out.println((System.currentTimeMillis() - st) + " : " + counter.join());
    }
}

class CounterSum extends RecursiveTask {

    //最小拆分单位:每个小数组length = 10
    public static final int THRESHOLD = 10;
    private int[] values;

    public CounterSum(int [] values){
        this.values = values;
    }

    @Override
    protected Integer compute() {
        if(values.length < THRESHOLD ){
            int count = 0;
            for(int i = 0; i < values.length ; i++){
                count += values[i];
            }
            return count;
        }else {
            int mid = values.length / 2;
            CounterSum first = new CounterSum(Arrays.copyOfRange(values,0,mid));
            first.fork();
            CounterSum second = new CounterSum(Arrays.copyOfRange(values,mid,values.length));
            int secondResult = second.compute();
            int firstResult = first.join();
            return firstResult + secondResult;
        }
    }
}

3.排序
 
  
 
  
public class ExerciseSort {

    //数据源
    static int num[] = new int[100];
    static {
        Random r = new Random();
        for(int i = 0 ; i < 100 ; i++){
            num[i] = r.nextInt(100);
        }
    }

    public static void main(String[] args) {
        CounterSort counter = new CounterSort(num);
        //使用单例
        ForkJoinPool pool = ForkJoinPool.commonPool();
        long st = System.currentTimeMillis();
        //启动并行任务
        pool.invoke(counter);
        System.out.println((System.currentTimeMillis() - st));
        Arrays.stream(counter.join()).forEach(System.out::println);
    }
}

class CounterSort extends RecursiveTask<int[]> {

    //最小拆分单位:每个小数组length = 10
    public static final int THRESHOLD = 10;
    //待排序数组
    private int[] values;

    public CounterSort(int [] values){
        this.values = values;
    }

    @Override
    protected int[] compute() {
        if(values.length < THRESHOLD ){
            int[] result = new int[values.length];
            //1.单数组排序
            result = Arrays.stream(values).sorted().toArray();
            return result;
        }else {
            int mid = values.length / 2;
            CounterSort first = new CounterSort(Arrays.copyOfRange(values,0,mid));
            first.fork();
            CounterSort second = new CounterSort(Arrays.copyOfRange(values,mid,values.length));
            int[] secondResult = second.compute();
            int[] firstResult = first.join();
            //两个数组混合排序
            int[] combineRsult = combineIntArray(firstResult,secondResult);
            return combineRsult;
        }
    }
    private int[] combineIntArray(int[] a1,int[] a2){
        int result[] = new int[a1.length + a2.length];
        int a1Len = a1.length;
        int a2Len = a2.length;
        int destLen = result.length;

        for(int index = 0 , a1Index = 0 , a2Index = 0 ; index < destLen ; index++) {
            int value1 = a1Index >= a1Len?Integer.MAX_VALUE:a1[a1Index];
            int value2 = a2Index >= a2Len?Integer.MAX_VALUE:a2[a2Index];
            if(value1 < value2) {
                a1Index++;
                result[index] = value1;
            }
            else {
                a2Index++;
                result[index] = value2;
            }
        }

        return result;
    }
}

你可能感兴趣的:(学习)