杨氏矩阵找第N大(小)的O(N)线性算法 LeetCode 378. Kth Smallest Element in a Sorted Matrix

杨氏矩阵:一个N*N的矩阵,它的每行每列都单调递增(或者宽松一些,单调不减),即a[i][j]<=a[i+1][j], a[i][j]<=a[i][j+1]。

遇到的两道面试题:
1. 输出杨氏矩阵中最小的N个数。
2. 两个升序数组A和B,长度都是N。从两个数组中分别取出一个数,相加得到一个和。求这N*N个和的前N小。

本质上第2题可以转化成第一题:把A[0]+B[k]的结果填入矩阵第一行,A[1]+B[k]的结果填入第二行……就得到一个杨氏矩阵。所以现在就只考虑第1题咯。

 

此题常见的一种做法:N路归并,用一个大小为N的堆,可以O(NlgN)得到解。但是利用杨氏矩阵的性质,这题是有O(N)的算法的……

(为了方便,把矩阵记为num)

首先根据杨氏矩阵的性质得到最关键的一点:前N小的数,肯定不大于矩阵中的num[sqrt(N)+1][sqrt(N)+1]。

(为了方便,令M = sqrt(N)+1)

因此要找的前N小的数,肯定在矩阵的前M行和前M列中。

 

所以,要找的是一个M*N的矩阵和另外一个(N-M)*M的矩阵。

这样的规模相当于M*(2N-M),相当于M*N

这样问题可以转化为:在M个长度为N的有序数组中,查找前N小的数。(*)

除了之前提到的方法,此题还有一个比较容易想到的方法:二分上界并计数。在INT_MIN~INT_MAX中二分第K大数的上界,每次对所有数组二分统计其中不大于上界的数的个数。总体的复杂度是O(R*M*lgN)。其中R是最坏情况下二分的次数。对于32位整数,最多二分32次,R=32。但是对于浮点数,需要的二分次数会增多。

在这个思路的基础上加以改进,把R改进为lgN,便可得到线性算法。

基本思路是,把二分时取数的范围从INT_MIN~INT_MAX缩小到这MN个数中。每次从这些数中选一个,来作为计数的上界。

选的方法:

每一轮计数时,先找出这M个数组的中位数,作为每个数组潜在的切分点,然后选择这些切分点的中位数作为上界。O(M)选出M个切分点,O(MlgM)把这些数排个序再选中间的,所以这一步可以O(MlgM)(注:无序数组选中位数有均摊O(M)的算法)。

但是为了每一轮都能缩小查找范围,所以对于每个数组,还要维护一个“潜在切分点的可能区间”,选择该数组的新切分点时,取这个区间的中位数。实际上就是对每个数组,维护一个二分切分点的过程信息。

这样一轮统计过后,某些偏大(或偏小)的切分点所在区间长度就需要减半。并且,至少有半数区间的长度是要减半的。(对于一个数组,不大于中位数的数的个数至少是一半。“不小于”同理)

由于所有数组一共有MN个数,因此在lg(MN)轮后,所有区间长度都会减到1。

整理一下复杂度。一共要进行lg(MN)次计数;每次计数需要O(MlgM)找切分点的中位数,以及O(MlgN)对一个数组计数。因此整体的复杂度是:
O(lg(MN)*(MlgM+MlgN)) = O(sqrt(N)*(lgN)^2) = o(sqrt(N)*sqrt(N)) = o(N)
ps. 所以这个算法复杂度其实是低于O(N)的.

(*)附代码:用以上算法实现在M个有序数组中,查找第K小的数。

转自: 

http://wolf5x.cc/blog/algorithm/young-tableau-smallest-kth#comment-123

 

 

#include 
#include 
#include 
using namespace std;

// i, partition_point_lower_index_i, partition_point_upper_index_i
typedef pair > PartRange;
typedef vector > VVI;

class PartComparator {
  const VVI &ary;
public:
  PartComparator(const VVI &a): ary(a){}
  bool operator()(const PartRange &x, const PartRange &y) const
  {
    return ary[x.first][(x.second.first+x.second.second)/2]
    < ary[y.first][(y.second.first+y.second.second)/2];
  }
};

// Get the count of numbers less than or equal to upper
int getCount(VVI &num, int upper) {
  int ret = 0;
  for (int i = 0; i < num.size(); i++) {
    ret += upper_bound(num[i].begin(), num[i].end(), upper) - num[i].begin();
  }
  return ret;
}

int chooseKthSmallest(VVI num, int k) {
  int n = num.size();
  vector part(n);
  for (int i = 0; i < n; i++) {
    part[i] = make_pair(i, make_pair(0, num[i].size()-1));
  }
  int ans = 1<<30; // INT_MAX;
  while(part.size() > 0) {
    // sort all the medians
    sort(part.begin(), part.end(), PartComparator(num));
    // choose the median of medians
    int mid = part.size()/2;
    int upper = num[part[mid].first][(part[mid].second.first+part[mid].second.second)/2];
    int count = getCount(num, upper);
    if (count >= k) {
      // update answer
      ans = min(ans, upper);
      // halve the median intervals of which the median is too large
      for(int i = 0; i < part.size(); i++) {
        int mid = (part[i].second.first+part[i].second.second)/2;
        if (num[part[i].first][mid] >= upper) {
          part[i].second.second = mid-1;
        }
      }
    } else {
      // halve the median intervals of which the median is too small
      for (int i = 0; i < part.size(); i++) {
        int mid = (part[i].second.first+part[i].second.second)/2;
        if (num[part[i].first][mid] <= upper) {
          part[i].second.first = mid+1;
        }
      }
    }
    // remove the empty median intervals
    for (int i = part.size()-1; i >= 0; i--) {
      if (part[i].second.first > part[i].second.second) {
        swap(part[i], part[part.size()-1]);
        part.erase(part.end()-1);
      }
    }
  }
  return ans;
}
int main() {
  int v[][3] = {{1,2,3},{2,3,4},{3,4,5}};
  vector vec0(v[0],v[0]+3); 
  vector vec1(v[1],v[1]+3); 
  vector vec2(v[2],v[2]+3); 
  int arr[] = {1,2,2,3,4,4,4,4,5,6,7,8,9,9,10};
  vector vec3(arr, arr+sizeof(arr)/sizeof(int));

  VVI num;
  num.push_back(vec0);
  num.push_back(vec1);
  num.push_back(vec2);
  int up = distance(vec3.begin(), upper_bound(vec3.begin(), vec3.end(), 11));
  int low = distance(vec3.begin(), lower_bound(vec3.begin(), vec3.end(), 11));
  
  int res = chooseKthSmallest(num, 3);

  return 0;
}

--------------------------------------------------------------------------------------------------------

很久之后发现一种更好的解法,仍然二分,假设数据二分的范围是整个整数,那么log(2^32)次最多是32次,实际中范围可以由左上角和右下角来确定,anyway,二分的次数可以看做是一个常数。。。

剩下的问题就是给定一个target,求这个matrix中<=target的元素个数,这个是可以O(n)实现的,也就是从右上角开始往下找。。所以整体复杂度是O(n)。LeetCode上恰好有这么一道题:

Given a n x n matrix where each of the rows and columns are sorted in ascending order, find the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

Example:

matrix = [
   [ 1,  5,  9],
   [10, 11, 13],
   [12, 13, 15]
],
k = 8,

return 13.

 

Note:
You may assume k is always valid, 1 ≤ k ≤ n2.

-----------------------------------------------------------

class Solution:
    def count_lower(self, nums, target, n):
        j, res = n - 1, 0
        for i in range(n):
            while (j >= 0 and nums[i][j] > target):
                j -= 1
            res += (j + 1)
        return res

    def kthSmallest(self, matrix, k: int) -> int:
        n = len(matrix)
        left, right = matrix[0][0], matrix[n - 1][n - 1]
        while (left <= right):
            target = ((right - left) >> 1) + left
            lower = self.count_lower(matrix, target, n)
            if (lower < k):
                left = target + 1
            else:
                right = target - 1
        return left

 

 

 

 

 

你可能感兴趣的:(c++,算法)