题目链接
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1:
nums1 = [1, 3]
nums2 = [2]The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]The median is (2 + 3)/2 = 2.5
首先,看到是两个已经排好序的数组,自然想到最直接简单的方法,就是归并排序,再得到中值。显然这样的时间复杂度是O(m+n)的。代码实现如下:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int size = nums1.size() + nums2.size();
if (size == 0)
return 0;
int count = 0, i = 0, j = 0;
while (count < (size - 1) / 2) {
if (i >= nums1.size())
j++;
else if (j >= nums2.size())
i++;
else if (nums1[i] < nums2[j])
i++;
else
j++;
count++;
}
double median = 0;
if (size % 2 == 0) {
for (int k = 0; k < 2; k++) {
if (i >= nums1.size()) {
median += nums2[j];
j++;
}
else if (j >= nums2.size()) {
median += nums1[i];
i++;
}
else if (nums1[i] < nums2[j]) {
median += nums1[i];
i++;
}
else {
median += nums2[j];
j++;
}
}
median /= 2;
}
else {
if (i >= nums1.size())
median = nums2[j];
else if (j >= nums2.size())
median = nums1[i];
else if (nums1[i] < nums2[j])
median = nums1[i];
else
median = nums2[j];
}
return median;
}
};
虽然思路简单,但是由于两个数组的长度是不定的,因此需要很多的条件判断。
另一种思路,可以利用分治的方法。例如现在有数组A,B:
___ ___ ___ ___ ___
A: | a | b | c | d | e |
--- --- --- --- ---
___ ___ ___
B: | x | y | z |
--- --- ---
我们得到A+B的中值,可以分别先得到A的中值c和B的中值y。那么c和y有三种关系。我们以 c<y 为例,我们在下一次分治的时候,只需要处理c,d,e和x,y即可。原因很简单,当 c<y ,那么有 a<b<c<y<z 和 a<b<c<d<e ,即我们可以简单的理解为这样一个理论:分别根据A,B的中值,将A分为A1,A2,将B分为B1,B2。那么,在这种情况下,显然有 A1<A2,A1<B2,B1<B2 ,即总体被分成了四部分,且A2和B1组成了中间的两部分,而A1和B2分别是首尾的两部分。即最后的中值,必定不会在A1和B2中,于是可以将它们舍去。
然而,其实这个思路实现起来很繁琐(至少对我而言),因为太多的边界条件需要判断,元素总数的奇偶性也造成了很大的麻烦。
于是。在此基础上,我们可以退一步,利用求第K小数的思路来解题。在上例中,假设c=A[k/2-1],y=B[k/2-1],那么同样,A1这部分是可以舍去的 ,但是B2则不能。这样,虽然每次问题规模缩小的幅度没有上面的方法大,但是实现起来却方便得多。因为k的奇偶性问题可以很方便解决,且边界判定也少了很多麻烦。代码实现如下:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int size = nums1.size() + nums2.size();
deque<int> que1, que2;
for (int i = 0; i < nums1.size(); i++)
que1.push_back(nums1[i]);
for (int i = 0; i < nums2.size(); i++)
que2.push_back(nums2[i]);
if (size % 2 == 0) {
deque<int> que3(que1), que4(que2);
double x = getKsmallest(que1, que2, size / 2);
double y = getKsmallest(que3, que4, size / 2 + 1);
return (x + y) / 2;
}
else {
return getKsmallest(que1, que2, size / 2 + 1);
}
}
double getKsmallest(deque<int>& que1, deque<int>& que2, int k) {
int m = que1.size(), n = que2.size();
if (m > n)
return getKsmallest(que2, que1, k);
if (m == 0)
return que2[k - 1];
if (k == 1) {
return min(que1[0], que2[0]);
}
int index1 = min(k / 2, m) - 1, index2 = k - index1 - 1 - 1;
if (que1[index1] < que2[index2]) {
int x = 0;
while (x++ < index1 + 1)
que1.pop_front();
return getKsmallest(que1, que2, k - index1 - 1);
}
else if (que1[index1] > que2[index2]) {
int y = 0;
while (y++ < index2 + 1)
que2.pop_front();
return getKsmallest(que1, que2, k - index2 - 1);
}
else {
return que1[index1];
}
}
};
同时,这样也将时间复杂度降低到了O(log(m + n))