top k 问题的几种解决方法

top k问题是指给定一组数量为n的数,从中找出前k大的数或第k大的数(k <= n)。由于只要能找出前k大的数,即可以得到第k大的数。所以下面先介绍解决前k大数问题的几种思路:

1.部分排序

由于我们只需要找到数组nums的前k大的数,所以不需要对整个数据进行排序,只需要保持前k大的数有序即可。所以我们可以维护一个大小为k的数组tk:

  • 首先将数组nums的前k个元素放入数组tk中,然后从大到小对tk排序
  • 之后从数组nums的第k+1个元素开始遍历,将遇到的每个元素nums[i]和tk中的最小数tk[k-1]比较,如果比tk[k-1]大,就将tk[k-1]删去,然后将nums[i]插入到数组tk中,保证tk依旧有序
  • 遍历结束后,tk中的元素就是前k大的数

复杂度:对tk排序的复杂度为O(k*logk),向tk插入数据的复杂度为O(k), 所以遍历数组nums并向tk插入数据的复杂度为O(n*k),总的复杂度为O(k*logk + n*k) 近似为 O(k*n)
如果k的值过大,算法的复杂度会相应增大

代码

    vector<int> solve1(vector<int> &nums, int k){
        int n = nums.size();
        vector<int> tk(nums.begin(), nums.begin()+k);
        sort(tk.rbegin(), tk.rend());

        for(int i = k; i < n; i++){
            if(nums[i]  > tk[k-1]){
                int j = k-1;
                int t = nums[i];
                while(j > 0){
                    if(tk[j-1] >= t){
                        tk[j] = t;
                        break;
                    }
                    tk[j] = tk[j-1];
                    j--;
                }
                if(j == 0)
                    tk[j] = t;
            }
        }
        return tk;
    }

2.基于大根堆

我们可以将待找数组nums建立为一个大根堆,然后从建好的堆中一次找出最大的k个数即可。
复杂度:使用筛选法建堆的复杂度为O(n), 然后从大根堆中找出前k大数的复杂度为O(k*logn),所以总的复杂度为:O(n + k*logn)
显然这个算法的复杂度要低于部分排序。

代码

    vector<int> solve2(vector<int> &nums, int k){
        int n = nums.size();
        vector<int> result;
        //建堆
        for(int i = (n-2)/2; i >= 0; i--)
            adjust(nums, i, n);
        //找出前k大的数
        for(int i = n-1; i >= n-k; i--){
            int t = nums[0];
            nums[0] = nums[i];
            nums[i] = t;
            result.push_back(t);
            adjust(nums, 0, i);
        }
        return result;
    }

    void adjust(vector<int> &nums, int i, int n){
        int parent = i;
        int t = nums[i];
        while(parent*2+1 <= n-1){
            int child = parent*2+1;
            if(child != n-1 && nums[child] < nums[child+1])
                child++;
            if(t >= nums[child])
                break;
            nums[parent] = nums[child];
            parent = child;
        }
        nums[parent] = t;
    }

基于小根堆的部分排序

分析前面的部分排序算法,我们可以发现有太多的时间浪费在了对数组tk的插入操作中,为了提高插入的效率,我们可以将数组tk组织为一个小根堆,对于小根堆的插入操作复杂度为O(logk),这显然要优于直接插入的复杂度O(k)。
复杂度:总的复杂度为 O(n*logk)

代码

    vector<int> solve3(vector<int> &nums, int k){
        int n = nums.size();
        vector<int> tk(nums.begin(), nums.begin()+k);
        //建堆
        for(int i = (k-2)/2; i >= 0; i--)
            adjust(tk, i, k);
        //遍历
        for(int i = k; i < n; i++)
            if(nums[i] > tk[0]){
                tk[0] = nums[i];
                adjust(tk, 0, k);
            }
        //对tk排序
        for(int i = k-1; i >= 0; i--){
            int t = tk[0];
            tk[0] = tk[i];
            tk[i] = t;
            adjust(tk, 0, i);
        }
        return tk;
    }

    void adjust(vector<int> &nums, int i, int n){
        int t = nums[i];
        int parent = i;

        while(parent*2+1 <= n-1){
            int child = parent*2+1;
            if(child != n-1 && nums[child+1]if(t <= nums[child])
                break;
            nums[parent] = nums[child];
            parent = child;
        }
        nums[parent] = t;
    }

基于快速排序

还有一种算法是基于快速排序的,我们知道每趟快排都会选定一个基准值,一趟快排后,基准值右边的所有数都大于这个基准值,所以我们可以通过选取合适的部分递归地对这些部分进行一趟快排,直到基准值右边的数为k个,那么我们就得到了数组的前k大的数:

1. 首先对数组nums进行一趟快排
2. 然后根据关键值key的位置进行判断
3. 如果key的下标 i < n-k : 对i右边的部分进行一趟快排,然后重复步骤2
4. 如果key的下标 i > n-k : 对i左边的部分进行一趟快排,然后重复步骤2
5. 如果key的下标 i == n-k ,那么就返回key(或 i )

上述算中,如果返回key就是数组中第k大的数,如果返回i就是前k大数的位置,下面的算法给出的是一个寻找第k大数的算法,稍作修改就可以得到前k大的数。

复杂度:O(n)

代码

    int qselect(vector<int> &nums, int left, int right, int k){
        if(left <= right){
            int low = left;
            int high = right;
            int key = nums[left];
            while(low < high){
                while(low < high && nums[high] >= key)
                    high--;
                nums[low] = nums[high];
                while(low < high && nums[low] <= key)
                    low++;
                nums[high] = nums[low];
            }
            nums[low] = key;
            if(low == nums.size()-k)
                return key;
            else if(low < nums.size()-k)
                return qselect(nums, low+1, right, k);
            else
                return qselect(nums, left, low-1, k);
        }
    }

你可能感兴趣的:(数据结构与算法)