说明:
问题要求得到最大的k个元素,就可以构建含有k个元素的小根堆(相应地,若是求最小的k个元素,就构建大根堆)。
LeetCode相关题目:215. 数组中的第K个最大元素
//使用数组的前k个元素构造含有k个元素的小根堆
//从k+1开始遍历,每次和堆顶元素比较,若被遍历到的元素大于堆顶元素,则替换堆顶元素并调整堆,保证堆内的k个元素总是当前最大的k个元素。
int findKthLargest(vector<int>& nums, int k) {
vector<int> heap_k(nums.begin(), nums.begin()+k); //选取nums中的前k个元素
BuildMinHeap(heap_k); //将这k个元素建成小根堆
for(int i=k; i<nums.size(); i++){//从第k+1个元素(下标为k)开始依次和堆顶元素比较
if(nums[i] > heap_k[0]){
heap_k[0] = nums[i];
MinHeapAdjust(heap_k, 0, k);//若被遍历到的元素比堆顶元素大,则替换堆顶元素并调整堆
}
}
return heap_k[0];//heap_k是小根堆,heap_k[0]中就是原数组第k大元素
}
//构建小根堆
void BuildMinHeap(vector<int>& nums){
int n = nums.size();
for(int i=n/2; i>=0; i--){//从第一个非叶结点开始调整
MinHeapAdjust(nums, i, n);
}
}
//调整小根堆
void MinHeapAdjust(vector<int>& nums, int i, int n){
int temp = nums[i]; //暂存被筛选的结点
for(int j=i*2+1; j<n; j=i*2+1){//j初始时指向i结点的左孩子
if(j+1<n && nums[j+1]<nums[j]) j++;//调整j,使其指向i的左右孩子中的较小值
if(temp <= nums[j]) break;//若当前被筛选结点temp更小,说明自此结点向下度符合小根堆的要求,可以提前终止筛选
else{
nums[i] = nums[j];//否则将孩子结点中的更小者调整到双亲位置上
i = j; //更新i指针以便继续向下筛选
}
}
nums[i] = temp; //被筛选结点放在其最终位置
}
在快速排序中,最主要的步骤是pivotPos = Partition(nums, left, right)
,它利用数组中的一个元素作为pivot,将下标从left到right的元素分为两部分,并以pivot为枢轴将比pivot小的元素放在左边,比pivot大的元素放在右边,通过不断划分数组,最终获得整体的排序。
在TopK问题中也利用了这种不断划分的分治策略,但是在快排中,每次要处理左右两部分,在TopK问题中对这一步骤做了简化,即每次只处理一边。因为要找的是最大/最小的k个元素,因此可以通过比较pivotPos
和k
的大小来判断下一次要处理的是左边还是右边。
P.S. 关于RANDOMIZED-SELECT的算法对应《算法导论(第3版)》9.2期望为线性时间的选择算法,关于SELECT的算法对应9.3最坏情况为线性时间的选择算法
//方法4:分治法,随机选择
int findKthLargest(vector<int>& nums, int k) {
randomizedSelect(nums, 0, nums.size()-1, k);
return nums[k-1];
}
//划分:随机选择RANDOMIZED-SELECT
int Partition(vector<int>& nums, int left, int right){
int pivotPos = rand()%(right-left+1) + left;//生成[left,right]范围内的随机数
int pivot = nums[pivotPos];//随机选择元素作为枢轴
swap(nums[left], nums[pivotPos]);//将枢轴元素和最左元素交换,之后将最左元素作为枢轴(算法书中称为主元)
//获得降序序列
while(left<right){
while(left<right && nums[right]<=pivot) right--;//因为将最左元素作为枢轴,因此也要先移动右侧指针
nums[left] = nums[right];
while(left<right && nums[left]>=pivot) left++;
nums[right] = nums[left];
}
pivotPos = left;//最终左右指针相遇,该位置即为pivot的最终位置
nums[pivotPos] = pivot;
return pivotPos;
}
//随机选择递归函数
void randomizedSelect(vector<int>& nums, int left, int right, int k){
if(left >= right) return;//递归返回条件
int pivotPos = Partition(nums, left, right);
if(pivotPos == k) return; //找到kth
else if(pivotPos > k) randomizedSelect(nums, left, pivotPos-1,k);//按降序排列,因此当pivotPos比k大时,说明要找的kth在序列的左侧
else randomizedSelect(nums, pivotPos+1, right, k);
}
参考学习:
线性时间选择问题
BFPRT——Top k问题的终极解法
选出第k大的元素:
//找到第k大的元素
//划分函数
int Partition(vector<int>& nums, int left, int right, int pivot){//pivot是传入的中位数的中位数
for(int index=left; index<=right; index++){//在left,right范围内寻找pivot的下标
if(nums[index] == pivot){
swap(nums[left], nums[index]);//和最左元素交换作为主元
break;
}
}
//降序序列
int i=left, j=right;
while(i<j){
while(i<j && nums[j]<=pivot) j--;
while(i<j && nums[i]>=pivot) i++;
swap(nums[i], nums[j]);
}
swap(nums[left], nums[i]);
return i;
}
//对[begin,end]范围内数据进行排序,并返回中位数下标
int indexOfMedian(vector<int>& nums, int begin, int end){
sort(nums.begin()+begin, nums.begin()+end+1, greater<int>());
int index = begin + (end - begin)/2;
return index;
}
int select(vector<int>& nums, int left, int right, int kth){
if(right-left+1 <= 5){
//元素个数在5个以内则直接排序并返回此次的kth
sort(nums.begin()+left, nums.begin()+right+1, greater<int>());
return nums[left + kth -1];//注意下标要减一
}
int count = right - left + 1;
int groups = count/5 + (count%5 > 0 ? 1 : 0);//总共有多少组
for(int i=0; i<groups; i++){//i是组号,从0开始,按组遍历
int index = indexOfMedian(nums, left+i*5, min(left+i*5+4, right));//这里要对最后一组进行处理,最后一组元素个数可能不足5个
swap(nums[left+i], nums[index]);//将中位数换到数组的前面,方便下一次取数
}
int pivot = select(nums, left, left+groups-1, groups/2);//中位数的中位数,若个数为偶数,选择较小中位数
int pivotPos = Partition(nums, left, right, pivot);//按照pivot进行划分
int num_left = pivotPos - left + 1;//[left, pivotPos]之间一共多少个元素
if(num_left == kth) return nums[pivotPos];//若下标为pivotPos的元素恰好为kth个元素,返回
else if(num_left > kth) return select(nums, left, pivotPos-1, kth);//kth在左半区
else return select(nums, pivotPos+1, right, kth-num_left);//kth在右半区,kth-num_left是kth在右半区的相对位置
}
P.S.对SELECT算法自己理解得还不够透彻,以后继续把这一块儿补充一下。