BFPRT 算法

本节以今天leetcode打卡题为例来介绍下BFPRT算法。



最小的k个数

输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。

示例 1:

输入:arr = [3,2,1], k = 2
输出:[1,2] 或者 [2,1]

示例 2:

输入:arr = [0,1,2,1], k = 1
输出:[0]

限制:
0 <= k <= arr.length <= 10000
0 <= arr[i] <= 10000

这个题很简单,常规做法,排序后选择k个数。

方法一

使用C++自带的sort函数。

class Solution {
public:
    vector getLeastNumbers(vector& arr, int k) {
        sort(arr.begin(),arr.end());
        if(k==0){
            return {};
        }
        vector res(arr.begin(),arr.begin()+k);
        return res;
    }
};

方法二

使用堆排序方法。

class Solution {
public:
    vector getLeastNumbers(vector& arr, int k) {
        vector res;
        if(k == 0){
            return res;
        }
        priority_queue h;
        for(int num : arr){
            if(h.size() < k){
                h.push(num);
            }else{
                if(h.top() <= num){
                    continue;
                }else{
                    h.pop();
                    h.push(num);
                }
            }
        }
        while(!h.empty()){
            res.push_back(h.top());
            h.pop();
        }
        return res;
    }
};

有关priority_queue可以参考往期博客。

这些方法都比较简单,接下来我们将介绍本节重点,BFPRT算法。

BFPRT 算法

分析这个题目,它是一个TOP-K问题,先排序,后选择k个数当然是一种不错的想法。

对于排序,如果使用性能比较好的快速排序,其平均时间复杂度为,最坏时间复杂度为,如果使用堆排序,需要维护一个大小为k的堆(大顶堆,小顶堆),时间复杂度为,但是无论哪种排序方法,对于本题而言,其实会有些多余,因为我们只需要前k个数或者说后n-k个数,那些不需要的数也排序了。

针对这个问题,有一个比较好的算法-BFPTR算法,又称为中位数的中位数算法,它的最坏时间复杂度为。该算法的思想是修改快速选择算法的主元选取方法,提高算法在最坏情况下的时间复杂度。

在BFPTR算法中,改变了快速排序主元素pivot值的选取,在快速排序中,我们始终选择第一个元素或者最后一个元素作为pivot,而在BFPTR算法中,每次选择五分中位数的中位数作为pivot,这样做的目的就是使得划分比较合理,从而避免了最坏情况的发生。
其算法步骤为:

  1. 将个元素划为组,每组5个,至多只有一组有个元素。
  2. 寻找这个组中每一个组的中位数,这个过程可以用插入排序。
  3. 对步骤2中的个中位数,重复步骤1和步骤2,递归下去,直到剩下一个数字。
  4. 最终剩下的数字即为pivot,把大于它的数全放左边,小于等于它的数全放右边。
  5. 判断pivot的位置与k的大小,有选择的对左边或右边递归。

求第大与求第等价。

来看BFPRT算法的程序:

#include

using namespace std;

//插入排序
void insert_sort(vector& arr,int l,int r){
    for(int i=l+1;i<=r;i++){
        if(arr[i-1]>arr[i]){
            int t=arr[i];
            int j=i;
            while(j>l && arr[j-1]>t){
                arr[j]=arr[j-1];
                j--;
            }
            arr[j]=t;
        }
    }
}

//寻找中位数的中位数
int find_mid(vector& arr,int l,int r){
    if(l==r){
        return l;
    }
    int i=0,n=0;
    for(i=l;i0){
        insert_sort(arr,i,i+num-1);
        n=i-l;
        swap(arr[l+n/5],arr[i+num/2]);
    }
    n=n/5;
    if(n==l){
        return l;
    }
    return find_mid(arr,l,l+n);
}

//进行划分过程
int partion(vector& arr,int l,int r,int p){
    swap(arr[p],arr[l]);
    int i=l;
    int j=r;
    int pivot=arr[l];
    while(i=pivot && i& arr,int l,int r,int k){
    int p=find_mid(arr,l,r);
    int i=partion(arr,l,r,p);

    int m=i-l+1;
    if(m==k){
        return arr[i];
    }
    if(m>k){
        return BFPRT(arr,l,i-1,k);
    }
    return BFPRT(arr,i+1,r,k-m);
}

int main(){
    int n;
    cin>>n;
    vector arr(n,0);
    for(int i=0;i>arr[i];
    }
    int k;
    cin>>k;
    cout<<"The "<

测试一下:

10
9 8 2 6 4 7 2 1 3 0
3
The 3th number is 2
0 1 2 2 3 4 6 9 8 7

那么这题的BFPRT解法便是:

class Solution {
public:
    vector getLeastNumbers(vector& arr, int k) {
        if(k==0){
            return {};
        }
        BFPRT(arr,0,arr.size()-1,k);
        vector res(arr.begin(),arr.begin()+k);
        return res;
    }
    //插入排序
    void insert_sort(vector& arr,int l,int r){
        for(int i=l+1;i<=r;i++){
            if(arr[i-1]>arr[i]){
                int t=arr[i];
                int j=i;
                while(j>l && arr[j-1]>t){
                    arr[j]=arr[j-1];
                    j--;
                }
                arr[j]=t;
            }
        }
    }

    //寻找中位数的中位数
    int find_mid(vector& arr,int l,int r){
        if(l==r){
            return l;
        }
        int i=0,n=0;
        for(i=l;i0){
            insert_sort(arr,i,i+num-1);
            n=i-l;
            swap(arr[l+n/5],arr[i+num/2]);
        }
        n=n/5;
        if(n==l){
            return l;
        }
        return find_mid(arr,l,l+n);
    }

    //进行划分过程
    int partion(vector& arr,int l,int r,int p){
        swap(arr[p],arr[l]);
        int i=l;
        int j=r;
        int pivot=arr[l];
        while(i=pivot && i& arr,int l,int r,int k){
        int p=find_mid(arr,l,r);
        int i=partion(arr,l,r,p);

        int m=i-l+1;
        if(m==k){
            return arr[i];
        }
        if(m>k){
            return BFPRT(arr,l,i-1,k);
        }
        return BFPRT(arr,i+1,r,k-m);
    }
};

结果:


图片.png

本文参考了:

  • BFPRT算法原理

你可能感兴趣的:(BFPRT 算法)