LeetCode Solution: Median of Two Sorted Arrays

Original problem: https://leetcode.com/problems/median-of-two-sorted-arrays/description/

There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).

You may assume nums1 and nums2 cannot be both empty.
Example 1:

nums1 = [1, 3]
nums2 = [2]
The median is 2.0

Example 2:

nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5
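A brute-force reference makes the definition of the median concrete: concatenate, sort, and take the middle element (or the average of the two middle elements). This is a minimal sketch; the name `median_by_sorting` is hypothetical, and at $O((m+n)\log(m+n))$ it does not meet the required bound, so it is only useful for checking faster solutions against the examples.

```python
def median_by_sorting(nums1, nums2):
    """Reference median (hypothetical helper): concatenate, sort, read the middle.

    O((m + n) log(m + n)); only meant as a test oracle.
    """
    s = sorted(nums1 + nums2)
    total = len(s)
    if total == 0:
        raise ValueError("nums1 and nums2 cannot both be empty")
    mid = total // 2
    if total % 2 == 1:
        return float(s[mid])              # odd: the single middle element
    return (s[mid - 1] + s[mid]) / 2.0    # even: average of the two middle elements


print(median_by_sorting([1, 3], [2]))     # 2.0
print(median_by_sorting([1, 2], [3, 4]))  # 2.5
```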

The problem gives two arrays sorted in ascending order and asks for the median of all the numbers they contain, with a required time complexity of $O(\log(m+n))$.

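Merging the two arrays into one sorted array and then reading off the median is comparatively inefficient: we only need the median, not the position of every element, and the merge costs $O(m+n)$, which exceeds the required $O(\log(m+n))$.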
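For comparison, here is a sketch of the merge idea that stops at the middle instead of building the whole merged array. The function name `median_by_merging` is hypothetical. It still visits about half of all elements, so it remains $O(m+n)$ rather than $O(\log(m+n))$.

```python
def median_by_merging(A, B):
    """Two-pointer merge that stops once the middle is reached (hypothetical helper).

    Visits (m + n) // 2 + 1 elements of the merged order: O(m + n).
    """
    total = len(A) + len(B)
    if total == 0:
        raise ValueError("A and B cannot both be empty")
    i = j = 0
    prev = cur = None
    for _ in range(total // 2 + 1):
        prev = cur
        # take from A when B is exhausted or A's head is not larger than B's
        if j >= len(B) or (i < len(A) and A[i] <= B[j]):
            cur = A[i]
            i += 1
        else:
            cur = B[j]
            j += 1
    if total % 2 == 1:
        return float(cur)          # cur is the middle element
    return (prev + cur) / 2.0      # prev and cur are the two middle elements


print(median_by_merging([1, 3], [2]))     # 2.0
print(median_by_merging([1, 2], [3, 4]))  # 2.5
```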

Below is a more efficient algorithm that exploits the fact that both arrays are already sorted.

```python
# python 56ms
def median(A, B):
    m, n = len(A), len(B)
    if m > n:
        # make A the shorter array so that 0 <= j <= n for every i in [0, m]
        A, B, m, n = B, A, n, m
    if n == 0:
        raise ValueError("A and B cannot both be empty")

    # integer division (//) keeps half_len, i and j valid indices in Python 3
    imin, imax, half_len = 0, m, (m + n + 1) // 2
    while imin <= imax:
        i = (imin + imax) // 2
        j = half_len - i
        if i < m and B[j-1] > A[i]:
            # i is too small, must increase it
            imin = i + 1
        elif i > 0 and A[i-1] > B[j]:
            # i is too big, must decrease it
            imax = i - 1
        else:
            # i is perfect

            if i == 0: max_of_left = B[j-1]
            elif j == 0: max_of_left = A[i-1]
            else: max_of_left = max(A[i-1], B[j-1])

            if (m + n) % 2 == 1:
                return max_of_left

            if i == m: min_of_right = B[j]
            elif j == n: min_of_right = A[i]
            else: min_of_right = min(A[i], B[j])

            return (max_of_left + min_of_right) / 2.0
```

The code shows the basic idea: cut the sorted (ascending) arrays $A$ ($m$ elements) and $B$ ($n$ elements) into $A_1, A_2, B_1, B_2$ (all still sorted), where $A_1$ holds the first $i$ elements of $A$ and $B_1$ holds the first $j$ elements of $B$. Then $A_2$ and $B_2$ hold $m - i$ and $n - j$ elements respectively.
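A small illustration of the cut, assuming $j$ = `half_len - i` with `half_len = (m + n + 1) // 2`, as in the code above. The helper `split` and the sample arrays are made up for this sketch.

```python
def split(A, B, i):
    """Cut A after its first i elements and B after its first j elements,
    where j = half_len - i (hypothetical helper for illustration)."""
    m, n = len(A), len(B)
    j = (m + n + 1) // 2 - i
    A1, A2 = A[:i], A[i:]   # A1 has i elements, A2 has m - i
    B1, B2 = B[:j], B[j:]   # B1 has j elements, B2 has n - j
    return A1, A2, B1, B2


# m = 3, n = 4, half_len = 4; choosing i = 2 forces j = 2
print(split([1, 3, 8], [2, 4, 5, 9], 2))  # ([1, 3], [8], [2, 4], [5, 9])
```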

假设 A A A B B B包含的所有数从小到大排序后得到数组 S S S A 1 A_1 A1 B 1 B_1 B1包含了S中较小的数,记为 S 1 S_1 S1 A 2 A_2 A2 B 2 B_2 B2包含了S中较大的数,记为 S 2 S_2 S2。若 S 1 S_1 S1包含的元素数量刚好是S的一半,那么中位数就是 S 1 S_1 S1中最大的那个数,且由于 A 1 A_1 A1 B 1 B_1 B1都是有序的, 找出两者中最大的数也是简单的!

那么问题的关键就在于如何切分 A A A B B B,使其满足:

  • S 1 S_1 S1包含的元素数为 S S S的一半,这里的一半规定为half_len = (m + n + 1) / 2,即 i + j = h a l f _ l e n i + j = half\_len i+j=half_len
  • S 1 S_1 S1包含的元素都比 S 2 S_2 S2中的元素小,即
    m a x ( A 1 ) < = m i n ( B 2 ) max(A_1) <= min(B_2) max(A1)<=min(B2) m a x ( B 1 ) < = m i n ( B 1 ) max(B_1) <= min(B_1) max(B1)<=min(B1)
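The two conditions can be checked directly. In this sketch (the helper `is_valid_cut` is hypothetical), neighbours that do not exist because a half is empty are replaced by $-\infty$ or $+\infty$, so an empty half never violates a condition; the first condition is enforced by computing $j$ from $i$.

```python
import math


def is_valid_cut(A, B, i):
    """Check both conditions for cutting A after i elements (hypothetical helper)."""
    m, n = len(A), len(B)
    j = (m + n + 1) // 2 - i            # condition 1: i + j == half_len
    if not (0 <= i <= m and 0 <= j <= n):
        return False                    # no such cut exists
    max_a1 = A[i - 1] if i > 0 else -math.inf
    min_a2 = A[i] if i < m else math.inf
    max_b1 = B[j - 1] if j > 0 else -math.inf
    min_b2 = B[j] if j < n else math.inf
    # condition 2: max(A1) <= min(B2) and max(B1) <= min(A2)
    return max_a1 <= min_b2 and max_b1 <= min_a2


A, B = [1, 3, 8], [2, 4, 5, 9]
print([i for i in range(len(A) + 1) if is_valid_cut(A, B, i)])  # [2]
```

Only $i$ is varied here, since $j$ follows from it.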

So the key is to find a pair $(i, j)$ that satisfies both conditions. Since the sum $i + j$ is known, it suffices to find $i$.

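For convenience of implementation we require the size of $A$ to be at most that of $B$, i.e. $m \le n$. Then letting $i$ range over $[0, m]$ covers every candidate $(i, j)$ pair, and $j = \text{half\_len} - i$ always stays within $[0, n]$. We binary-search $i$ over this range: if $B[j-1] > A[i]$, $i$ is too small; if $A[i-1] > B[j]$, $i$ is too large; otherwise both conditions hold, and the median is computed according to whether the length of $S$ is odd or even. The complexity is $O(\log(\min(m, n)))$.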
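The same search can also be written with $\pm\infty$ sentinels in place of the explicit `i == 0` / `j == n` branches. This is a variant sketch, not the author's code; `find_median` is a hypothetical name, and its comparisons mirror the two conditions above.

```python
import math


def find_median(A, B):
    """Binary search over i in [0, m] with -inf/+inf sentinels (variant sketch)."""
    if len(A) > len(B):
        A, B = B, A                     # ensure m <= n so that 0 <= j <= n
    m, n = len(A), len(B)
    if n == 0:
        raise ValueError("A and B cannot both be empty")
    half_len = (m + n + 1) // 2
    lo, hi = 0, m
    while lo <= hi:
        i = (lo + hi) // 2
        j = half_len - i
        max_a1 = A[i - 1] if i > 0 else -math.inf
        min_a2 = A[i] if i < m else math.inf
        max_b1 = B[j - 1] if j > 0 else -math.inf
        min_b2 = B[j] if j < n else math.inf
        if max_b1 > min_a2:
            lo = i + 1                  # i is too small
        elif max_a1 > min_b2:
            hi = i - 1                  # i is too big
        else:
            max_left = max(max_a1, max_b1)
            if (m + n) % 2 == 1:
                return float(max_left)
            return (max_left + min(min_a2, min_b2)) / 2.0
    raise AssertionError("unreachable for sorted inputs")


print(find_median([1, 3], [2]))     # 2.0
print(find_median([1, 2], [3, 4]))  # 2.5
```

The sentinels trade a few extra comparisons for fewer boundary branches; Python compares `math.inf` with ints directly.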

Implementations in other languages, for reference:

```go
// golang 24 ms
func findMedianSortedArrays(A []int, B []int) float64 {
    m, n := len(A), len(B)
    if m > n {
        // make A the shorter slice so that 0 <= j <= n for every i in [0, m]
        A, B, m, n = B, A, n, m
    }

    imin, imax, halfLen := 0, m, (m+n+1)/2
    for imin <= imax {
        i := (imin + imax) / 2
        j := halfLen - i
        if i < m && B[j-1] > A[i] {
            // i is too small, must increase it
            imin = i + 1
        } else if i > 0 && A[i-1] > B[j] {
            // i is too big, must decrease it
            imax = i - 1
        } else {
            // i is perfect
            var maxOfLeft int
            if i == 0 {
                maxOfLeft = B[j-1]
            } else if j == 0 {
                maxOfLeft = A[i-1]
            } else {
                // integer max without converting through float64 (no "math" import needed)
                maxOfLeft = A[i-1]
                if B[j-1] > maxOfLeft {
                    maxOfLeft = B[j-1]
                }
            }

            if (m+n)%2 == 1 {
                return float64(maxOfLeft)
            }

            var minOfRight int
            if i == m {
                minOfRight = B[j]
            } else if j == n {
                minOfRight = A[i]
            } else {
                minOfRight = A[i]
                if B[j] < minOfRight {
                    minOfRight = B[j]
                }
            }

            return float64(maxOfLeft+minOfRight) / 2.0
        }
    }
    return -9999 // unreachable: a valid cut always exists when the arrays are not both empty
}
```
```java
// java 49ms
class Solution {
    public double findMedianSortedArrays(int[] A, int[] B) {
        int m = A.length;
        int n = B.length;
        if (m > n) { // to ensure m<=n
            int[] temp = A; A = B; B = temp;
            int tmp = m; m = n; n = tmp;
        }
        int iMin = 0, iMax = m, halfLen = (m + n + 1) / 2;
        while (iMin <= iMax) {
            int i = (iMin + iMax) / 2;
            int j = halfLen - i;
            // iMax <= m, so i < iMax also guarantees i < m and A[i] is in range
            if (i < iMax && B[j-1] > A[i]){
                iMin = i + 1; // i is too small
            }
            // iMin >= 0, so i > iMin also guarantees i > 0 and A[i-1] is in range
            else if (i > iMin && A[i-1] > B[j]) {
                iMax = i - 1; // i is too big
            }
            else { // i is perfect
                int maxLeft = 0;
                if (i == 0) { maxLeft = B[j-1]; }
                else if (j == 0) { maxLeft = A[i-1]; }
                else { maxLeft = Math.max(A[i-1], B[j-1]); }
                if ( (m + n) % 2 == 1 ) { return maxLeft; }

                int minRight = 0;
                if (i == m) { minRight = B[j]; }
                else if (j == n) { minRight = A[i]; }
                else { minRight = Math.min(B[j], A[i]); }

                return (maxLeft + minRight) / 2.0;
            }
        }
        return 0.0;
    }
}
```
```cpp
// c++ 32ms
class Solution {
public:
    double findMedianSortedArrays(vector<int>& A, vector<int>& B) {
        // vector<int> A(nums1), B(nums2);
        // Copying the arguments before swapping would leave the caller's vectors
        // unmodified, but the copy raised the judged runtime from 32 ms to 52 ms.
        int m = A.size(), n = B.size();
        if (m > n) {
            A.swap(B);  // note: this also swaps the caller's vectors
            std::swap(m, n);
        }

        int imin = 0, imax = m, half_len = (m + n + 1) / 2;
        while (imin <= imax) {
            int i = (imin + imax) / 2;
            int j = half_len - i;
            if (i < m && B[j-1] > A[i]) {
                // i is too small, must increase it
                imin = i + 1;
            }
            else if (i > 0 && A[i-1] > B[j]) {
                // i is too big, must decrease it
                imax = i - 1;
            }
            else {
                // i is perfect
                int max_of_left = 0;
                int min_of_right = 0;
                if (i == 0) {
                    max_of_left = B[j-1];
                }
                else if (j == 0) {
                    max_of_left = A[i-1];
                }
                else {
                    max_of_left = A[i-1] > B[j-1] ? A[i-1] : B[j-1];
                }

                if ((m + n) % 2 == 1) {
                    return max_of_left;
                }

                if (i == m) {
                    min_of_right = B[j];
                }
                else if (j == n) {
                    min_of_right = A[i];
                }
                else {
                    min_of_right = A[i] < B[j] ? A[i] : B[j];
                }
                return (max_of_left + min_of_right) / 2.0;
            }
        }

        return 0.0;
    }
};
```

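Judging by the judge timings of these implementations, Go appears to perform best: its 24 ms is noticeably faster than the other languages. This is worth comparing further in the future.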
