top-K问题指的是从一个数组中找出里面前k大或是前k小的问题。解决这类问题可以有以下的集中方法。
1、排序。排序之后取前k大或是前k小,时间复杂度是O(nlog(n));
2、使用堆。与之对应的是最大堆和最小堆,时间复杂度是O(nlog(k));
3、使用快排中的partition,将数组分成小于等于大于三部分,根据k除去一部分数据,在对剩下的数据进行partition,直至找到前k大或是前k小的数,时间复杂度是O(n),不过这个时间复杂度是概率统计下的结果,并不是严格的O(n)的。
下面介绍的BFPRT算法求解top-k问题的时间复杂度是严格的O(n)的。
一、BFPRT流程
我们将BFPRT算法看成一个函数bfprt(int[] arr,int k),返回值是第k大或是第k小的值。
将数组按照5个为一组进行分组,最后剩下的不足5个的为一组,此项操作时间复杂度是O(n);
将每组的5个数进行排序,因为每组只有5个数进行排序,一个组排序的时间复杂度是O(1),这个操作的时间复杂度是O(n);
将每组的中位数取出来构成一个新的数组newArr,这个数组的长度大约是n/5,所以这个操作的时间复杂度是O(n);
求出新数组的中位数,即递归调用bfprt(newArr,newArr.length/2),假设原来的问题时间复杂度是T(n),则这个操作的时间是T(n/5);
使用步骤4求出的中位数进行partition,这一步骤最少可以排除掉arr的3/10,在对剩下的进行bfprt,这个操作的时间是T(7n/10);
整个过程的时间复杂度是T(n) = T(n/5) + T(7n/10) + O(n) = O(n),这个的证明过程大家可以看算法导论。
二、BFPRT算法实现
public static void main(String[] args) {
int[] arr = { 6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9 };
int[] res = minKthNumsByBFPRT(arr,11);
System.out.println(Arrays.toString(res));
Arrays.sort(arr);
System.out.println(Arrays.toString(arr));
}
public static int[] minKthNumsByBFPRT(int[] arr,int k){
if(arr == null || k < 1 || arr.length < 1 ||arr.length < k){
return null;
}
// 使用BFPRT算法得到第k小的数kthNum
int kthNum = getKthMInNumByBFPRT(arr,0,arr.length - 1,k);
int[] res = new int[k];
int index = 0;
// 将小于kthNum的数放到结果数组中
for(int i = 0; i < arr.length; i++){
if(arr[i] < kthNum){
res[index++] = arr[i];
}
}
//小于kthNum的数个数小于k个时,剩下的数全部用kthNum填充
while(index < k){
res[index++] = kthNum;
}
// 返回最小的前k的数的数组
return res;
}
private static int getKthMInNumByBFPRT(int[] arr, int left, int right, int k) {
if(left == right){
return arr[left];
}
// 将原数组进行复制
int[] copy = copyArr(arr);
// 得到中位数的中位数median
int median = getMedianOfMedian(copy,left,right);
// 按median进行划分小于等于大于三部分,中间等于区域的索引范围
int[] range = partition(copy,left,right,median);
// 在等于区域的范围内,直接返回
if(k >= range[0] && k <= range[1]){
return copy[k];
// k在等于区域的左边,即在小于区域,再对左边的区域进行BFPRT算法即可
}else if(k < range[0]){
return getKthMInNumByBFPRT(copy,left,range[0]-1,k);
// k在等于区域的右边,即在大于区域,再对右边的区域进行BFPRT算法即可
}else{
return getKthMInNumByBFPRT(copy,range[1]+1,right,k);
}
}
// 对copy数组使用median进行partition操作,大于median放左边,等于median放中间,大于median放右边
private static int[] partition(int[] copy, int left, int right, int median) {
int less = left - 1;
int more = right + 1;
int cur = left;
while(cur < more){
if(copy[cur] < median){
swap(copy,++less,cur++);
}else if(copy[cur] > median){
swap(copy,cur,--more);
}else{
cur++;
}
}
return new int[]{less+1,more-1};
}
private static int getMedianOfMedian(int[] copy, int left, int right) {
int len = right - left + 1;
// 检查区间长度是否能被5整除,如果不足5个剩下的数作为一组
int offset = len%5==0?0:1;
// median是存放每个数组中位数的数组
int[] median = new int[len/5+offset];
int index = 0;
for(int i = left;i <= right;i = i+5){
// 取最小值是因为最后一组可能没有5个数
int end = Math.min(i + 4,right);
// 采用插入排序
insertSort(copy,i,end);
// 取每组的中位数
median[index++] = copy[(i+end)>>1];
}
// 求中位数组成的数组的中位数
return getKthMInNumByBFPRT(median,0,median.length-1,median.length/2);
}
private static void insertSort(int[] copy, int left, int right) {
for(int i = left+1;i<=right;i++){
for (int j = i-1; j >=left ; j--) {
if(copy[j] > copy[j+1]){
swap(copy,j,j+1);
}else{
break;
}
}
}
}
private static void swap(int[] copy, int j, int i) {
int temp = copy[j];
copy[j]= copy[i];
copy[i] = temp;
}
private static int[] copyArr(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i < arr.length; i++) {
res[i] = arr[i];
}
return res;
}