【TopK问题】基于堆的方法&基于分治策略的方法

说明

  1. TopK问题:对于给定的数组,选出其中最大/最小的k个元素,或是选出第k大/第k小元素;
  2. 本文整理了两种实现方法,分别是
    • 基于堆的实现方法:和堆排序有所不同的是,仅仅通过构建含有k个元素的堆,最终得到最大/最小的k个元素
    • 基于分治策略的方法:采用了快速排序的思想,对原数组进行划分,但和快排不同的是,每次仅处理划分后的一边
  3. 文章内容为个人学习整理,如有错误,欢迎指正。

文章目录

  • 1. 基于堆的方法
    • 1.1 算法步骤
    • 1.2 算法实现
  • 2. 基于分治策略的方法
    • 2.1 算法步骤
    • 2.2 算法实现

1. 基于堆的方法

1.1 算法步骤

问题要求得到最大的k个元素,就可以构建含有k个元素的小根堆(相应地,若是求最小的k个元素,就构建大根堆)。

  1. 首先利用原数组的前k个元素构建小根堆;
  2. 从原数组的第k+1个元素开始向后遍历,并依次比较元素与堆顶元素大小,若大于堆顶元素则替换堆顶元素并及时调整堆;否则继续向后遍历;
  3. 当数组遍历完毕后,小根堆中存储的k个元素就是原数组中最大的k个元素。

1.2 算法实现

LeetCode相关题目:215. 数组中的第K个最大元素

//使用数组的前k个元素构造含有k个元素的小根堆
    //从k+1开始遍历,每次和堆顶元素比较,若被遍历到的元素大于堆顶元素,则替换堆顶元素并调整堆,保证堆内的k个元素总是当前最大的k个元素。
    int findKthLargest(vector<int>& nums, int k) {
        vector<int> heap_k(nums.begin(), nums.begin()+k); //选取nums中的前k个元素     
        BuildMinHeap(heap_k); //将这k个元素建成小根堆
        
        for(int i=k; i<nums.size(); i++){//从第k+1个元素(下标为k)开始依次和堆顶元素比较
            if(nums[i] > heap_k[0]){
                heap_k[0] = nums[i]; 
                MinHeapAdjust(heap_k, 0, k);//若被遍历到的元素比堆顶元素大,则替换堆顶元素并调整堆
            }
        }               
        return heap_k[0];//heap_k是小根堆,heap_k[0]中就是原数组第k大元素
    }
	//构建小根堆
    void BuildMinHeap(vector<int>& nums){
        int n = nums.size();
        for(int i=n/2; i>=0; i--){//从第一个非叶结点开始调整
            MinHeapAdjust(nums, i, n);
        }
    }
    //调整小根堆
    void MinHeapAdjust(vector<int>& nums, int i, int n){
        int temp = nums[i]; //暂存被筛选的结点
        for(int j=i*2+1; j<n; j=i*2+1){//j初始时指向i结点的左孩子
            if(j+1<n && nums[j+1]<nums[j]) j++;//调整j,使其指向i的左右孩子中的较小值

            if(temp <= nums[j]) break;//若当前被筛选结点temp更小,说明自此结点向下度符合小根堆的要求,可以提前终止筛选
            else{
                nums[i] = nums[j];//否则将孩子结点中的更小者调整到双亲位置上
                i = j; //更新i指针以便继续向下筛选
            }
        }
        nums[i] = temp; //被筛选结点放在其最终位置
    }

2. 基于分治策略的方法

2.1 算法步骤

在快速排序中,最主要的步骤是pivotPos = Partition(nums, left, right),它利用数组中的一个元素作为pivot,将下标从left到right的元素分为两部分,并以pivot为枢轴将比pivot小的元素放在左边,比pivot大的元素放在右边,通过不断划分数组,最终获得整体的排序。

在TopK问题中也利用了这种不断划分的分治策略,但是在快排中,每次要处理左右两部分,在TopK问题中对这一步骤做了简化,即每次只处理一边。因为要找的是最大/最小的k个元素,因此可以通过比较pivotPosk的大小来判断下一次要处理的是左边还是右边。

2.2 算法实现

P.S. 关于RANDOMIZED-SELECT的算法对应《算法导论(第3版)》9.2期望为线性时间的选择算法,关于SELECT的算法对应9.3最坏情况为线性时间的选择算法

  • 基于随机选择的算法实现
//方法4:分治法,随机选择 
    int findKthLargest(vector<int>& nums, int k) {
        randomizedSelect(nums, 0, nums.size()-1, k);
        return nums[k-1];
    }

    //划分:随机选择RANDOMIZED-SELECT
    int Partition(vector<int>& nums, int left, int right){
        int pivotPos = rand()%(right-left+1) + left;//生成[left,right]范围内的随机数
        int pivot = nums[pivotPos];//随机选择元素作为枢轴

        swap(nums[left], nums[pivotPos]);//将枢轴元素和最左元素交换,之后将最左元素作为枢轴(算法书中称为主元)

        //获得降序序列
        while(left<right){
            while(left<right && nums[right]<=pivot) right--;//因为将最左元素作为枢轴,因此也要先移动右侧指针
            nums[left] = nums[right];
            while(left<right && nums[left]>=pivot) left++;
            nums[right] = nums[left];
        }
        pivotPos = left;//最终左右指针相遇,该位置即为pivot的最终位置
        nums[pivotPos] = pivot;
        return pivotPos;
    }
    //随机选择递归函数
    void randomizedSelect(vector<int>& nums, int left, int right, int k){
        if(left >= right) return;//递归返回条件

        int pivotPos = Partition(nums, left, right);
        if(pivotPos == k) return; //找到kth
        else if(pivotPos > k) randomizedSelect(nums, left, pivotPos-1,k);//按降序排列,因此当pivotPos比k大时,说明要找的kth在序列的左侧
        else randomizedSelect(nums, pivotPos+1, right, k);
    }
  • 最坏情况为线性时间的选择算法 SELECT,以下内容来自《算法导论》
  1. 将输入数组的n个元素划分为n/5组,每组5个元素,且至多只有一组由剩下的n mod 5个元素组成;
  2. 寻找n/5组中每一组的中位数:首先对每组元素进行插入排序,然后确定每组有序元素的中位数;
  3. 对第二步找到的n/5个中位数,递归调用SELECT函数找出其中位数x(若有偶数个中位数,约定x是较小的中位数);
  4. 按中位数的中位数x对数组进行划分,让k比划分的低区中的元素数目多1,因此x是第k小元素,且有n-k个元素在划分的高区;
  5. 若i=k,则返回x;若i.k,则在高区递归查找第i-k小的元素。

参考学习:
线性时间选择问题
BFPRT——Top k问题的终极解法

选出第k大的元素:

//找到第k大的元素

//划分函数
int Partition(vector<int>& nums, int left, int right, int pivot){//pivot是传入的中位数的中位数
    for(int index=left; index<=right; index++){//在left,right范围内寻找pivot的下标
        if(nums[index] == pivot){
            swap(nums[left], nums[index]);//和最左元素交换作为主元
            break;
        }
    }
    //降序序列
    int i=left, j=right;
    while(i<j){
        while(i<j && nums[j]<=pivot) j--;
        while(i<j && nums[i]>=pivot) i++;
        swap(nums[i], nums[j]);
    }
    swap(nums[left], nums[i]);
    return i;
}

//对[begin,end]范围内数据进行排序,并返回中位数下标
int indexOfMedian(vector<int>& nums, int begin, int end){
    sort(nums.begin()+begin, nums.begin()+end+1, greater<int>());
    int index = begin + (end - begin)/2;
    return index;
}

int select(vector<int>& nums, int left, int right, int kth){
    if(right-left+1 <= 5){
        //元素个数在5个以内则直接排序并返回此次的kth
        sort(nums.begin()+left, nums.begin()+right+1, greater<int>());
        return nums[left + kth -1];//注意下标要减一
    }

    int count = right - left + 1;
    int groups = count/5 + (count%5 > 0 ? 1 : 0);//总共有多少组
    for(int i=0; i<groups; i++){//i是组号,从0开始,按组遍历
        int index = indexOfMedian(nums, left+i*5, min(left+i*5+4, right));//这里要对最后一组进行处理,最后一组元素个数可能不足5个
        swap(nums[left+i], nums[index]);//将中位数换到数组的前面,方便下一次取数
    }

    int pivot = select(nums, left, left+groups-1, groups/2);//中位数的中位数,若个数为偶数,选择较小中位数
    int pivotPos = Partition(nums, left, right, pivot);//按照pivot进行划分

    int num_left = pivotPos - left + 1;//[left, pivotPos]之间一共多少个元素
    if(num_left == kth) return nums[pivotPos];//若下标为pivotPos的元素恰好为kth个元素,返回
    else if(num_left > kth) return select(nums, left, pivotPos-1, kth);//kth在左半区
    else return select(nums, pivotPos+1, right, kth-num_left);//kth在右半区,kth-num_left是kth在右半区的相对位置    
}

P.S.对SELECT算法自己理解得还不够透彻,以后继续把这一块儿补充一下。

你可能感兴趣的:(算法导论,数据结构,算法)