【排序算法】02:归并排序、快速排序

本文为王争《数据结构与算法之美》笔记。

目录

  • 前言
  • 一、归并排序
    • 原理与代码
    • 性能分析
      • 时间复杂度
      • 空间复杂度
      • 稳定性
    • 改成非递归形式
  • 二、快速排序(Quicksort)
    • 原理
    • 初步代码
    • 最好和最坏时间复杂度
    • 优化分区算法
      • 随机法
      • 三数取中法
    • 性能分析
      • 时间复杂度
      • 空间复杂度
      • 稳定性
    • 改成非递归形式
    • 和其他排序算法的结合

前言

时间复杂度为O(nlogn)的算法主要是归并排序和快速排序。
一般来说,归并排序是从下到上的:先处理子问题,再合并;而快速排序是从上到下的,先分区,再处理子问题。
归并排序的时间复杂度稳定在O(n
logn),但是它是非原地排序算法,原因是合并操作无法原地执行。而快速排序的分区操作是可以原地执行的。
归并排序可以是稳定的算法,而快速排序是不稳定的算法。

一、归并排序

原理与代码

思想:先分解,一分为二,即“从上到下”;再合并,逐级地“合并两个有序数组”,即“从下到上”。
mergeSort(int[] nums, int[] result, int l, int r)方法的作用:对nums数组的[l, r)区间进行排序。
merge(int[] nums, int[] result, int l1, int r1, int l2, int r2)方法的作用:对nums数组中相邻的[l1, r1)和[l2, r2)进行合并排序,最后的效果就是[l1, r2)变为了有序区间。注意,因为两个区间是相邻的,且区间是左闭右开的,所以r1 = l2。
重点关注merge()方法,用到了result数组进行存储。正是因为

class Solution {
    public int[] sortArray(int[] nums) {
        int n = nums.length;
        if(n<=1) return nums;
        //左闭右开
        int[] result = new int[n];
        mergeSort(nums, result, 0, n);
        return nums;
    }
    void mergeSort(int[] nums, int[] result, int l, int r){
        if(r-l<=1) return;
        int mid = l + (r - 1 - l) / 2;
        mergeSort(nums, result, l, mid+1);
        mergeSort(nums, result, mid+1, r);
        merge(nums, result, l, mid+1, mid+1, r);
        return;
    }
    void merge(int[] nums, int[] result, int l1, int r1, int l2, int r2){
        int i = l1, j = l2, k = l1;
        while(i<r1 && j<r2)
        	//nums[i]<=nums[j]保证了归并排序的稳定性
        	//改成nums[i]>=nums[j],结果就能变成降序
            result[k++] = nums[i]<=nums[j] ? nums[i++] : nums[j++];
        while(i<r1)
            result[k++] = nums[i++];
        while(j<r2)
            result[k++] = nums[j++];
        for(i = l1; i<r2; i++)
            nums[i] = result[i];
    }
}

性能分析

时间复杂度

递推公式很好列:T(n),则T(n) = 2*T(n/2) + n,其中的时间复杂度为O(n),所以耗时为n。另外T(1) = C,即n = 1时,只需要常数时间就可以。
有了上面的递推公式,我们就可以开始推导:

T(n) = 2*T(n/2) + n 
= 2*(2*T(n/4) + n/2) + n 
= 4*T(n/4) + 2*n 
= 4*(2*T(n/8) + n/4) + 2*n 
= 8*T(n/8) + 3*n 
...... 
= 2^k * T(n/2^k) + k * n

当n/2^k = 1时,递推结束,k = log2n。T(n) = 2Clog2n + n*log2n,如果用大O表示法的话,T(n)就等于O(nlogn)。

还可以用递归树的方法,归并排序的递归树是一棵满二叉树,每层合并总耗时都是n,共logn层,故时间复杂度为O(n*logn)。

和快排不同,归并排序无论何时时间复杂度都是O(n*logn),不会退化。

空间复杂度

辅助数组result带来O(n)的空间复杂度。递归带来的空间复杂度是O(logn),数量级较小就忽略了。
所以归并排序虽然时间复杂度稳定在O(nlogn),但其空间复杂度是O(n),不适用于对数据规模大的数组进行排序,因为那样额外的内存消耗就太高了。
但是归并排序可以对较小规模的数组进行排序。事实上,C标准库中的qsort()函数就优先使用归并排序,对小规模数据排序。1KB、2KB都没关系,换取稳定的O(n
logn)复杂度,“以空间换时间”。而像100MB这种规模的数据,就不适合用归并排序了。

稳定性

合并过程中的“<=”:nums[i]<=nums[j]保证了归并排序的稳定性。

改成非递归形式

改成递归形式的话,就没有“从上到下”的分解过程。而是直接开始“并”的过程:将一个元素与相邻元素合并成有序数组,再与旁边数组合并成有序数组,逐级进行,直至整个数组有序。这个过程中仍然需要不停合并,所以merge()函数跟递归代码中一样。

class Solution {
    public int[] sortArray(int[] nums) {
        int n = nums.length;
        if(n<=1) return nums;
        //左闭右开
        int[] result = new int[n];
        mergeSort(nums, result, n);
        return nums;
    }
    //mergeSort()方法改变了输入参数
    void mergeSort(int[] nums, int[] result, int n){
        int k = 1;
        while(k<n){
        	//以下操作为:k个一组合并相邻数组
            int i = 0;
            //当左区间和右区间都在nums内时
            while(i+2*k <= n){
                merge(nums, result, i, i+k, i+k, i+2*k);
                i += 2*k;
            }
            //当左区间都在nums内,右区间部分在nums内时
            //如果只有左区间存在,则本次不合并,在最后一次k个一组合并时再处理
            if(i+k<n) merge(nums, result, i, i+k, i+k, n);
            k *= 2;
        }
    }
	//merge()函数和递归方法一样,复制就行了
}

本方法理解难度较高,更多分析可见这篇文章:常见排序算法(6)–归并排序(非递归版)。

二、快速排序(Quicksort)

原理

假如要对数组左闭右开区间 [l, r) 的部分进行排序,我们就选择[l, r)中任意一个数作为pivot(分区点)。
然后进行分区操作:遍历[l, r)中的数据,将小于pivot的数放到左边,将大于pivot的放到右边,这样中间的pivot便成了“分区点”。假设分区结束后pivot的索引是p,那么区间[l, p)内部的数都是小于pivot的,区间[p+1, r)内部的数都是大于pivot的。(注:其他等于pivot的数在哪个区间要看代码实现,但对排序过程没有什么影响,会对所有等于pivot的数据的先后顺序有影响)
根据分治的思想,我们递归地对区间[l, p)和[p+1, r)也进行上述分区操作,直到区间长度缩小为1,就说明排序完成了。

初步代码

刚才整个过程写成代码如下:

void quickSort(int[] nums, int l, int r){
    if(l>=r) return;
    int k = partition(nums, l, r);
    //递归的前提是先得有左区间
    if(k > l) quickSort(nums, l, k);
    //递归的前提是先得有右区间
    if(k < r-1) quickSort(nums, k+1, r);
}

那么,该怎么实现分区函数partition()呢?我们可能首先会想到,直接顺序遍历原数组,用临时数组A接收所有小于pivot的值,临时数组B接收所有大于pivot的值,最后再按顺序把数组A和B拷贝到原数组。
这样做可以,但是空间复杂度太高了。
**如何不用额外空间就进行分区呢?**这里给出一种思路:定义两个指针i, j,先初始化为左边界l,然后j遍历数组进行探路,找到小于pivot的数就和nums[i]交换,然后i前进一格,如此循环。

int partition(int[] nums, int l, int r){
    int pivot = nums[r-1];
    int i = l;
    for(int j=l; j<r-1; j++){
        if(nums[j] < pivot){
            int temp = nums[j];
            nums[j] = nums[i];
            nums[i] = temp;
            i++;
        }
    }
    int temp = nums[i];
    nums[i] = nums[r-1];
    nums[r-1] = temp;
    return i;
}

j前进到pivot的位置就跳出循环,然后交换nums[i]和pivot,返回i。
为什么能放心地将nums[i]和pivot交换?因为j已经探过路了,[i, j)之间的数肯定都>=pivot。

于是我们得到了代码:

class Solution {
    public int[] sortArray(int[] nums) {
        int n = nums.length;
        //如果数组长度<=1,则没必要排序,直接返回
        if(n<=1) return nums;
        quickSort(nums, 0, n);
        return nums;
    }
    void quickSort(int[] nums, int l, int r){
        if(l>=r) return;
        int k = partition(nums, l, r);
        //递归的前提是先得有左区间
        if(k > l) quickSort(nums, l, k);
        //递归的前提是先得有右区间
        if(k < r-1) quickSort(nums, k+1, r);
    }
    int partition(int[] nums, int l, int r){
        int pivot = nums[r-1];
        int i = l;
        for(int j=l; j<r-1; j++){
            if(nums[j] < pivot){
                int temp = nums[j];
                nums[j] = nums[i];
                nums[i] = temp;
                i++;
            }
        }
        int temp = nums[i];
        nums[i] = nums[r-1];
        nums[r-1] = temp;
        return i;
    }
}

如果你将代码复制到LeetCode 912.排序数组并提交,就会发现超时了,无法通过。
为什么呢?

最好和最坏时间复杂度

我们先来初步分析一下快排的时间复杂度。
先假设每次选分区点都很巧,恰好最后分区点都在区间的中间。那么用时的递推公式就和归并排序一样,所以此时快排的时间复杂度是O(n*logn)。

但是,想要每次选分区点都那么巧,落到最中间,那是几乎不可能的。我们每次都选右边界作为pivot的取值,基本不会每次都是中位数。
所以,O(nlogn)只是快排的最好时间复杂度,代表分界最均匀的情况下快排的耗时。
如果数组本来就已经升序,那么按照这种取pivot的方法,每次分区之后分区点都还在右边界,这就是分区极不均匀的情况,最后需要分n次区,每次分区平均要扫描n/2个元素。这样,快排的时间复杂度就退化为了O(n^2),即为快排的最坏时间复杂度。

可以看到,现在选择分区点的方法还可以优化,避免陷入分区极不均匀、以至于时间复杂度变为O(n^2)的情况。
言归正传,优化了分区点选择方法,就有可能不再超时,成功通过了。

优化分区算法

随机法

利用Random类在区间内随机选一个值,虽然不能保证每次都是最佳,但也很难出现时间复杂度退化到O(n^2)的情况。

class Solution {
    public int[] sortArray(int[] nums) {
        int n = nums.length;
        if(n<=1) return nums;
        //左闭右开
        quickSort(nums, 0, n);
        return nums;
    }
    void quickSort(int[] nums, int l, int r){
        if(l>=r-1) return;
        int k = partition(nums, l, r);
        //递归的前提是先得有左区间
        if(k > l) quickSort(nums, l, k);
        //递归的前提是先得有右区间
        if(k < r-1) quickSort(nums, k+1, r);
    }
    int partition(int[] nums, int l, int r){
        //随机法选择分区点
        //获取l到r的一个随机整数,以随机选择分区点
        int p = new Random().nextInt(r - l) + l;
        //把随机到的轴值交换到数组尾部
        swap(nums, p, r-1);
        int pivot = nums[r-1];
        int i = l;
        for(int j=l; j<r-1; j++){
        	//改成nums[j] > pivot,结果就能变成降序
            if(nums[j] < pivot){
                swap(nums, i, j);
                i++;
            }
        }
        swap(nums, i, r-1);
        return i;
    }
    //交换函数
    void swap(int[] nums, int a, int b){
        int temp = nums[b];
        nums[b] = nums[a];
        nums[a] = temp;
    }
}

只是把分区的代码多加了随机取数,性能方面就得到了很大优化,成功AC了。

PS:
这里还将交换的操作抽象成了函数。注意交换函数的形参,数组在Java中属于引用类型。有了引用类型,实参才能被改变。
下面是错误的交换函数实现:

void swap(int a, int b){
    int temp = a;
    a = b;
    b = temp;
}
//等到用的时候就这么写:
swap(nums[i], nums[j]);

因为Java中基本类型都是值传递的,所以swap(nums[i], nums[j]);这样的语句根本不会改变数组内的数据。

三数取中法

跟上面随机法的代码只有partition()函数不同:

int partition(int[] nums, int l, int r){
	//注意,移位运算优先级比加减低,因此a>>b这类运算要加括号,像l + (r-1-l)>>1就是错的
    int mid = l + ((r - 1 - l)>>1);
    //对左中右进行排序
    if(nums[l] > nums[r-1]) swap(nums, l, r-1);
    if(nums[l] > nums[mid]) swap(nums, l, mid);
    if(nums[mid] > nums[r-1]) swap(nums, mid, r-1);
    swap(nums, mid, r-2);
    int pivot = nums[r-2];
    int i = l;
    for(int j=l; j<r-2; j++){
        if(nums[j] < pivot){
            swap(nums, i, j);
            i++;
        }
    }
    swap(nums, i, r-2);
    return i;
}

三数取中法的分区算法改动较大,但最基本的扫描数据方式没变,还是用j指针探路,i与j交换。
有一张图片直观地展示了这种分区算法的过程,来自文章:图解排序算法(五)之快速排序——三数取中法,作者: dreamcatcher-cx。

【排序算法】02:归并排序、快速排序_第1张图片
由图可知,经过三数取一之后的第一个元素一定<=6,理论上来说可以跳过它直接从第二个开始扫描。
但注意,如果用之前的j指针探路的扫描方式,就不能这样做。
因为第二个元素是l+1,而分区点的索引是r-2,如果区间长度是2,就会出现第二个元素直接大于分区点索引的情况。而这是扫描数组时所不能出现的情况,j指针探路的扫描方式要求:指针到达分区点索引后就立刻跳出。
更直观一点:如果真的跳过第一个元素,这种算法在处理[3, 5]时反而会错误地输出[5, 3]。

另外,和随机法不同,如果要改成降序,不只要把nums[j] < pivot修改成nums[j] > pivot。还需要在“对左中右进行排序”部分,把">“全改成”<":

if(nums[l] < nums[r-1]) swap(nums, l, r);
if(nums[l] < nums[mid]) swap(nums, l, mid);
if(nums[mid] < nums[r]) swap(nums, mid, r);

性能分析

时间复杂度

最好和最坏的上面已经推导了,分别是O(nlogn)和O(n)。
平均时间复杂度
利用递推公式的完整的数学推导点这里,只要你有高中知识就不难理解。
《数据结构与算法之美》第27节那个递归树解法,我认为不太行……首先它得假设每次分区的比例不变,然后“根据概率论”可得平均时间复杂度为O(n
logn),但我实在看不出来是怎么“根据概率论”的。

空间复杂度

快排是原地排序算法,空间复杂度来自递归调用。最好空间复杂度为O(logn),最坏空间复杂度为O(n)。

稳定性

快排是不稳定的排序算法。比如输入[5,6,5,1,2,4],其中pivot为最后一个元素4,那么扫描到1的时候会跟第一个5交换,这样同为5的数据的先后顺序就变了。

改成非递归形式

递归要警惕堆栈溢出,而上面快排的递归实现就有这个风险,改成非递归形式就可以解决这个问题。
任何递归代码都可以改成非递归形式。递归代码在执行过程中有系统或虚拟机隐式地维护栈,而如果我们在内存堆上自己实现并维护一个栈,就可以做到手动模拟递归。
我们只需要自己创建一个栈,存储待排序区间的左右边界(先存储右边界),然后开始迭代过程就好。
还是只修改quickSort()方法,改成非递归形式:

void quickSort(int[] nums, int l, int r){
    Stack<Integer> stk = new Stack<>();
    int i, j;
    //先存右指针再存左指针
    stk.push(r);
    stk.push(l);
    while(!stk.empty()){
        i = stk.pop();
        j = stk.pop();
        if(i < j-1){
            int k = partition(nums, i, j);
            if(k > i){
                stk.push(k);
                stk.push(i);
            }
            if(k < j-1){
                stk.push(j);
                stk.push(k+1);
            }
        }
    }
}

和其他排序算法的结合

在数据规模较小的时候,O(n^2)算法未必比O(nlogn)算法执行时间长。复杂度分析都是偏理论的,并不能和实际运行时间划等号。
大O表示法是忽略了系数、常数和低阶的,即O(n
logn)在没有忽略系数、常数和低阶前可能是O(knlogn+c)。而如果k、c比较大的话,执行时间可能会反超O(n^2)算法。

比如k = 500, c = 200,而n = 100(数据规模较小)时:
k*n*logn+c = 500*100*log100 + 200,比100^2大得多

因此在区间长度较小时,可以直接使用O(n^2)的插入排序来替代快速排序,还是直接修改quickSort()函数就可以:

void quickSort(int[] nums, int l, int r){
    //如果区间长度<=1,直接返回
    if(r-l<=1){
        return;
    }else if(r-l<=16){
        insertionSort(nums, l, r);
        return;
    }else {
        int k = partition(nums, l, r);
        if (k > l) quickSort(nums, l, k);
        if (k < r - 1) quickSort(nums, k + 1, r);
    }
}

插入排序代码:

void insertionSort(int[] nums, int l, int r) {
    if(r-l<=1) return;
    for(int i=l+1; i<r; i++){
        int value = nums[i];
        int j = i-1;
        for(; j>=l; j--){
            if(nums[j] > value){
                nums[j+1] = nums[j];
            }else{
                break;
            }
        }
        nums[j+1] = value;
    }
    return;
}

你可能感兴趣的:(数据结构与算法,java,算法,排序算法)