Java并行编程--并行归并排序

文章目录

    • 一、归并排序回顾
    • 二、Java并行编程框架
    • 三、`RecursiveAction`详解
    • 四、测试和效率分析

一、归并排序回顾

归并排序,想必大家都不陌生,它是我们学习排序算法和分治法的极好例子。它是稳定排序,且有稳定的 O ( n l o g n ) O(nlogn) O(nlogn)时间复杂度,不受数据混乱度影响。唯一的不足是需要 O ( n ) O(n) O(n)的辅助空间。因此,归并排序被认为是综合性能最优的排序算法。

我们先来回顾一下归并排序的基本思想:首先,把数组平分成两半。然后,对这两半分别递归地进行归并排序。最后,把这两半排好序的数组有序合并起来。所以,大致可以写成这样:(用Java)

    public void mergeSort(int[] arr, int l, int r) {
        final int len = r - l;
        if (len <= 0) {
            return;
        } else if (len == 1) {
            if (arr[l] > arr[r]) {
                int tmp = arr[l];
                arr[l] = arr[r];
                arr[r] = tmp;
            }
            return;
        } // 这两个是边界条件
        final int mid = l + (len >> 1);
        mergeSort(arr, l, mid);
        mergeSort(arr, mid + 1, r);
        merge(arr, l, mid, mid + 1, r);
    }

其中merge是合并函数,把两个有序数组(这里表示成一个数组中互不交叠的两段)合并成一个有序数组。代码如下:

    private void merge(int[] arr, int l1, int r1, int l2, int r2) {
        int i = l1, j = l2, k = left;
        while (i <= r1 && j <= r2) {
            if (arr[i] <= arr[j]) {
                tmp[k++] = arr[i++];
            } else {
                tmp[k++] = arr[j++];
            }
        }
        while (j <= r2) {
            tmp[k++] = arr[j++];
        }
        while (i <= r1) {
            tmp[k++] = arr[i++];
        }
        for (i = 0; i < k - left; i++) {
            arr[i + l1] = tmp[i + left];
        }
    }

我们看到,归并排序有两个递归调用,这两个递归调用没有“重叠子问题”,也就是说完全可以互不干扰地进行。所以,我们很自然地想到用并行框架,把两个子问题并行解决,这样就可以进一步提升速度。

二、Java并行编程框架

Java中并行编程有两种方法,一是借助线程和线程池,二是借助并行框架ForkJoinTask。我们分别思考一下。

如果借助线程,则要处理的是两个子数组排序的并行,以及两个子数组排序后再合并的问题。前者并不难,重点在于后者。如果某一个子数组没有排序完毕就开始merge操作,则得不到正确结果。因为merge是建立在两个子数组都有序的大前提下的。所以必须考虑同步问题。这就需要借助信号量等操作,比较繁琐。

而对于归并排序这类问题,Java提供了两个非常合适的框架,分别是ForkJoinTaskForkJoinPool。前者用于定义需要并行的任务,后者提供并行所需的线程池。ForkJoinTask是一个处理并行问题的接口,有两个抽象类实现了它,一类是RecursiveTask,一类是RecursiveAction,分别处理有返回结果(返回结果类型为T)和无返回结果的并行问题。我们的归并排序不需要返回结果,是原地排序,所以用RecursiveAction

三、RecursiveAction详解

要使用RecursiveAction,则必须用一个类来继承它,这个类就定义了一个具体的并行任务。我们的并行任务叫做ParallelMergeSort。注意RecursiveAction是一个抽象类,它有一个抽象方法compute,继承它的类必须实现这个抽象方法。这个compute方法,就是并行任务的执行代码。

我们先看调用并行任务的方法:调用时,不能直接用start或fork等方法,因为并行任务的执行并不在主线程中,而是在线程池提供的线程中,无法得知任务是否执行完毕。

不要担心,Java提供了另外一种阻塞式的启动方法,叫做invoke。它是ForkJoinPool的一个成员方法,表示启动一个并行任务,并阻塞主线程,在所有任务都执行完毕后再唤醒主线程。因此,我们可以这样写:

ParallelMergeSort task = new ParallelMergeSort(...); // 分配一个任务
ForkJoinPool pool = new ForkJoinPool(); // 分配一个并行线程池
pool.invoke(task); // 向线程池提交任务

然后再看并行任务ParallelMergeSort的定义。首先,有序数组的合并函数是完全一样的。重点在于如何将两个子数组的排序并行。所以,我们需要提供数组的排序区间。

import java.util.concurrent.RecursiveAction;

public class ParallelMergeSort extends RecursiveAction {

    private final int[] arr; // 待排序数组
    private final int left; // 排序区间左端点
    private final int right; // 排序区间右端点。注意这里是闭区间。
    private final int min; // min 和 max将在下文解释
    private final int max;
    private final int[] tmp; // tmp就是归并时的辅助数组,和普通归并排序中一样

    public ParallelMergeSort(int[] arr, int left, int right,
                             int min, int max, int[] tmp) {
        this.min = min;
        this.max = max;
        this.tmp = tmp;
        if (arr == null) {
            throw new NullPointerException();
        } else if (left > right) {
            throw new IndexOutOfBoundsException();
        }
        this.arr = arr;
        this.left = left;
        this.right = right;
    }

    private void merge(int[] arr, int l1, int r1, int l2, int r2) {
        int i = l1, j = l2, k = left; // 疑问1:为什么k初始化为left而不是0?
        while (i <= r1 && j <= r2) {
            if (arr[i] <= arr[j]) {
                tmp[k++] = arr[i++];
            } else {
                tmp[k++] = arr[j++];
            }
        }
        while (j <= r2) {
            tmp[k++] = arr[j++];
        }
        while (i <= r1) {
            tmp[k++] = arr[i++];
        }
        for (i = 0; i < k - left; i++) {
            arr[i + l1] = tmp[i + left];
        }
    }

    public void mergeSort(int[] arr, int l, int r) {
    // 疑问2:为什么要保留这个函数
        final int len = r - l;
        if (len <= 0) {
            return;
        } else if (len == 1) {
            if (arr[l] > arr[r]) {
                int tmp = arr[l];
                arr[l] = arr[r];
                arr[r] = tmp;
            }
            return;
        }
        final int mid = l + (len >> 1);
        mergeSort(arr, l, mid);
        mergeSort(arr, mid + 1, r);
        merge(arr, l, mid, mid + 1, r);
    }

    @Override
    protected void compute() { // 核心代码
        final int len = right - left;
        if (len < 50 || (len + 1) << 4 <= max - min + 1) {
            mergeSort(arr, left, right);
            return;
        } // 当数组规模很小或线程开得太多,就转为普通归并排序
        final int mid = left + (len >> 1);
        ParallelMergeSort leftTask =
                new ParallelMergeSort(arr, left, mid, min, max, tmp);
        ParallelMergeSort rightTask =
                new ParallelMergeSort(arr, mid + 1, right, min, max, tmp);
        invokeAll(leftTask, rightTask); // 和归并排序思路完全相同,只不过这里是并行的
        merge(arr, left, mid, mid + 1, right);
    }
}

这段代码并不难理解,但如代码中注释所示,有两个疑惑的地方,一个是为什么要保留普通归并排序函数,一个是merge函数中tmp数组为什么要从left开始存储。

其实归根结底,就是因为并行。有人说这不废话吗,诶,还真不废话。

我在compute函数的注释中已经写明白,如果数组规模太小,或者线程开得太多,就改用普通归并排序,因此需要保留普通归并排序函数。

然后就是merge函数中从left开始的问题。我们这里为了节省空间,始终用同一个tmp数组,它的真正内容由构造方法传入。任务中,每个子任务都需要使用tmp数组。如果归并操作统一从0开始,则由于子任务的并行,导致tmp同位置的元素可能被不同线程共享,造成归并操作混乱(线程安全问题)。因此,我们需要使得不同子任务使用的tmp数组区间互不重叠。用left可保证tmp和原数组位置一一对应,这样就避免了线程安全问题。

这两点是最难理解的,如果理解了这两点,剩下的工作就是水到渠成的。

最后考虑一下并行程度的问题。是不是开的线程越多,运行速度越快呢?并不是。因为线程分为用户线程和硬件线程两种。前者只是逻辑上的线程,而后者才是CPU中真正支持并行的东西。我们的CPU经常说“四核八线程”就是指支持四个独立计算单元以及8个真正并行的任务。其它逻辑线程仅仅是并发而非并行。而且,一般操作系统需要长期占用两个线程,其余6个线程才是真正供程序使用。无论你在程序中表面上开了多少个线程,都只有最多6个线程是真正并行的。

另外,开线程需要占用大量的操作系统资源,如果线程开得过多,则时间空间消耗是非常可观的,如果开线程本身的负面影响已经抵消甚至超过并行带来的好处,则使用并行就没有意义。

而根据我们上面的代码,并行任务数量大致和递归层数是指数关系(递归到第k层,则共有约 2 k 2^k 2k个并行任务)。因此我们应该在递归到第3-5层,就转为普通归并排序。所以,compute函数中有这样的特判:

        if (len < 50 || (len + 1) << 4 <= max - min + 1) {
            mergeSort(arr, left, right);
            return;
        } 

根据测试,保留8-32个并行任务时运行最快。

四、测试和效率分析

时间单位均为毫秒。CPU corei5-7200,四核八线程,操作系统win10 64bit,jdk1.8

排序规模 MergeSort时间 ParallelMergeSort时间 加速比
1000 2 2 1.000
5000 2 10 0.200
10000 4 15 0.267
50000 18 18 1.000
100000 32 16 2.000
500000 108 50 2.160
1000000 218 90 2.422
5000000 960 550 1.745
10000000 2100 950 2.211
30000000 6200 2520 2.460

可见,当规模较小时,由于并行与否的影响不大,而且开线程需要消耗时间空间,所以比普通归并排序慢一些。但规模较大时就体现出来差距了,并行版本的速度是普通版本的2-3倍。

你可能感兴趣的:(Java)