归并排序,想必大家都不陌生,它是我们学习排序算法和分治法的极好例子。它是稳定排序,且有稳定的 O ( n l o g n ) O(nlogn) O(nlogn)时间复杂度,不受数据混乱度影响。唯一的不足是需要 O ( n ) O(n) O(n)的辅助空间。因此,归并排序被认为是综合性能最优的排序算法。
我们先来回顾一下归并排序的基本思想:首先,把数组平分成两半。然后,对这两半分别递归地进行归并排序。最后,把这两半排好序的数组有序合并起来。所以,大致可以写成这样:(用Java)
public void mergeSort(int[] arr, int l, int r) {
final int len = r - l;
if (len <= 0) {
return;
} else if (len == 1) {
if (arr[l] > arr[r]) {
int tmp = arr[l];
arr[l] = arr[r];
arr[r] = tmp;
}
return;
} // 这两个是边界条件
final int mid = l + (len >> 1);
mergeSort(arr, l, mid);
mergeSort(arr, mid + 1, r);
merge(arr, l, mid, mid + 1, r);
}
其中merge是合并函数,把两个有序数组(这里表示成一个数组中互不交叠的两段)合并成一个有序数组。代码如下:
private void merge(int[] arr, int l1, int r1, int l2, int r2) {
int i = l1, j = l2, k = left;
while (i <= r1 && j <= r2) {
if (arr[i] <= arr[j]) {
tmp[k++] = arr[i++];
} else {
tmp[k++] = arr[j++];
}
}
while (j <= r2) {
tmp[k++] = arr[j++];
}
while (i <= r1) {
tmp[k++] = arr[i++];
}
for (i = 0; i < k - left; i++) {
arr[i + l1] = tmp[i + left];
}
}
我们看到,归并排序有两个递归调用,这两个递归调用没有“重叠子问题”,也就是说完全可以互不干扰地进行。所以,我们很自然地想到用并行框架,把两个子问题并行解决,这样就可以进一步提升速度。
Java中并行编程有两种方法,一是借助线程和线程池,二是借助并行框架ForkJoinTask
。我们分别思考一下。
如果借助线程,则要处理的是两个子数组排序的并行,以及两个子数组排序后再合并的问题。前者并不难,重点在于后者。如果某一个子数组没有排序完毕就开始merge操作,则得不到正确结果。因为merge是建立在两个子数组都有序的大前提下的。所以必须考虑同步问题。这就需要借助信号量等操作,比较繁琐。
而对于归并排序这类问题,Java提供了两个非常合适的框架,分别是ForkJoinTask
和ForkJoinPool
。前者用于定义需要并行的任务,后者提供并行所需的线程池。ForkJoinTask
是一个处理并行问题的接口,有两个抽象类实现了它,一类是RecursiveTask
,一类是RecursiveAction
,分别处理有返回结果(返回结果类型为T)和无返回结果的并行问题。我们的归并排序不需要返回结果,是原地排序,所以用RecursiveAction
。
RecursiveAction
详解要使用RecursiveAction
,则必须用一个类来继承它,这个类就定义了一个具体的并行任务。我们的并行任务叫做ParallelMergeSort
。注意RecursiveAction
是一个抽象类,它有一个抽象方法compute
,继承它的类必须实现这个抽象方法。这个compute
方法,就是并行任务的执行代码。
我们先看调用并行任务的方法:调用时,不能直接用start或fork等方法,因为并行任务的执行并不在主线程中,而是在线程池提供的线程中,无法得知任务是否执行完毕。
不要担心,Java提供了另外一种阻塞式的启动方法,叫做invoke
。它是ForkJoinPool
的一个成员方法,表示启动一个并行任务,并阻塞主线程,在所有任务都执行完毕后再唤醒主线程。因此,我们可以这样写:
ParallelMergeSort task = new ParallelMergeSort(...); // 分配一个任务
ForkJoinPool pool = new ForkJoinPool(); // 分配一个并行线程池
pool.invoke(task); // 向线程池提交任务
然后再看并行任务ParallelMergeSort
的定义。首先,有序数组的合并函数是完全一样的。重点在于如何将两个子数组的排序并行。所以,我们需要提供数组的排序区间。
import java.util.concurrent.RecursiveAction;
public class ParallelMergeSort extends RecursiveAction {
private final int[] arr; // 待排序数组
private final int left; // 排序区间左端点
private final int right; // 排序区间右端点。注意这里是闭区间。
private final int min; // min 和 max将在下文解释
private final int max;
private final int[] tmp; // tmp就是归并时的辅助数组,和普通归并排序中一样
public ParallelMergeSort(int[] arr, int left, int right,
int min, int max, int[] tmp) {
this.min = min;
this.max = max;
this.tmp = tmp;
if (arr == null) {
throw new NullPointerException();
} else if (left > right) {
throw new IndexOutOfBoundsException();
}
this.arr = arr;
this.left = left;
this.right = right;
}
private void merge(int[] arr, int l1, int r1, int l2, int r2) {
int i = l1, j = l2, k = left; // 疑问1:为什么k初始化为left而不是0?
while (i <= r1 && j <= r2) {
if (arr[i] <= arr[j]) {
tmp[k++] = arr[i++];
} else {
tmp[k++] = arr[j++];
}
}
while (j <= r2) {
tmp[k++] = arr[j++];
}
while (i <= r1) {
tmp[k++] = arr[i++];
}
for (i = 0; i < k - left; i++) {
arr[i + l1] = tmp[i + left];
}
}
public void mergeSort(int[] arr, int l, int r) {
// 疑问2:为什么要保留这个函数
final int len = r - l;
if (len <= 0) {
return;
} else if (len == 1) {
if (arr[l] > arr[r]) {
int tmp = arr[l];
arr[l] = arr[r];
arr[r] = tmp;
}
return;
}
final int mid = l + (len >> 1);
mergeSort(arr, l, mid);
mergeSort(arr, mid + 1, r);
merge(arr, l, mid, mid + 1, r);
}
@Override
protected void compute() { // 核心代码
final int len = right - left;
if (len < 50 || (len + 1) << 4 <= max - min + 1) {
mergeSort(arr, left, right);
return;
} // 当数组规模很小或线程开得太多,就转为普通归并排序
final int mid = left + (len >> 1);
ParallelMergeSort leftTask =
new ParallelMergeSort(arr, left, mid, min, max, tmp);
ParallelMergeSort rightTask =
new ParallelMergeSort(arr, mid + 1, right, min, max, tmp);
invokeAll(leftTask, rightTask); // 和归并排序思路完全相同,只不过这里是并行的
merge(arr, left, mid, mid + 1, right);
}
}
这段代码并不难理解,但如代码中注释所示,有两个疑惑的地方,一个是为什么要保留普通归并排序函数,一个是merge函数中tmp数组为什么要从left开始存储。
其实归根结底,就是因为并行。有人说这不废话吗,诶,还真不废话。
我在compute函数的注释中已经写明白,如果数组规模太小,或者线程开得太多,就改用普通归并排序,因此需要保留普通归并排序函数。
然后就是merge函数中从left开始的问题。我们这里为了节省空间,始终用同一个tmp数组,它的真正内容由构造方法传入。任务中,每个子任务都需要使用tmp数组。如果归并操作统一从0开始,则由于子任务的并行,导致tmp同位置的元素可能被不同线程共享,造成归并操作混乱(线程安全问题)。因此,我们需要使得不同子任务使用的tmp数组区间互不重叠。用left可保证tmp和原数组位置一一对应,这样就避免了线程安全问题。
这两点是最难理解的,如果理解了这两点,剩下的工作就是水到渠成的。
最后考虑一下并行程度的问题。是不是开的线程越多,运行速度越快呢?并不是。因为线程分为用户线程和硬件线程两种。前者只是逻辑上的线程,而后者才是CPU中真正支持并行的东西。我们的CPU经常说“四核八线程”就是指支持四个独立计算单元以及8个真正并行的任务。其它逻辑线程仅仅是并发而非并行。而且,一般操作系统需要长期占用两个线程,其余6个线程才是真正供程序使用。无论你在程序中表面上开了多少个线程,都只有最多6个线程是真正并行的。
另外,开线程需要占用大量的操作系统资源,如果线程开得过多,则时间空间消耗是非常可观的,如果开线程本身的负面影响已经抵消甚至超过并行带来的好处,则使用并行就没有意义。
而根据我们上面的代码,并行任务数量大致和递归层数是指数关系(递归到第k层,则共有约 2 k 2^k 2k个并行任务)。因此我们应该在递归到第3-5层,就转为普通归并排序。所以,compute函数中有这样的特判:
if (len < 50 || (len + 1) << 4 <= max - min + 1) {
mergeSort(arr, left, right);
return;
}
根据测试,保留8-32个并行任务时运行最快。
时间单位均为毫秒。CPU corei5-7200,四核八线程,操作系统win10 64bit,jdk1.8
排序规模 | MergeSort时间 | ParallelMergeSort时间 | 加速比 |
---|---|---|---|
1000 | 2 | 2 | 1.000 |
5000 | 2 | 10 | 0.200 |
10000 | 4 | 15 | 0.267 |
50000 | 18 | 18 | 1.000 |
100000 | 32 | 16 | 2.000 |
500000 | 108 | 50 | 2.160 |
1000000 | 218 | 90 | 2.422 |
5000000 | 960 | 550 | 1.745 |
10000000 | 2100 | 950 | 2.211 |
30000000 | 6200 | 2520 | 2.460 |
可见,当规模较小时,由于并行与否的影响不大,而且开线程需要消耗时间空间,所以比普通归并排序慢一些。但规模较大时就体现出来差距了,并行版本的速度是普通版本的2-3倍。