n sum 问题总结

文章目录

  • 简介
    • two sum
      • hash map
      • two point
      • 借助二分搜索 BinarySearch
      • 改进的O(NlogN)算法
    • three sum
      • hash map
      • two point方法 O(N^2)
    • 总结
    • 参考

简介

初始问题:给定一个数组,找到数组里面n个数之和为0的组合
常见的问题有two sum,three sum等等
这些问题在leetcode上都有准备,搜一下即可

其中变化在数组里面

  1. 数组是否有序
  2. 数组是否有重复数字
  3. 如何处理0?(0加任意个0都等于0)

直观感受一下不同级别的算法对于问题规模和运行时间的差异
n sum 问题总结_第1张图片

针对n sum问题,写出暴力算法是简单的,暴力算法的时间复杂度是O(N^n)
我们可以在暴力算法上面提升,或是牺牲空间,或是采用two point的思想

  1. 使用hash,牺牲O(N)的空间,换取大约O(1)的查找(关于hashmap的查找时间复杂度这里不再赘述,理想情况下查找操作是趋于O(1)的)
  2. two point帮助我们将需要两次遍历的操作降低为一次遍历

本文以two sum、three sum为例,介绍自己编写的算法,一步步迭代。
关于n sum问题,leetcode上面有很多习题,可以加以练习。

two sum

写出暴力算法是简单的
伪码

	int cnt=0;
	for i in range(len(array)):
		for j in range(i+1,len(array)):
			if a[i]+a[j] == 0:
				cnt++
	return cnt

上述算法时间复杂度是 O ( N 2 ) O(N^2) O(N2)

前面介绍了,这里有两种做法,一种是引用hashmap,另一种是two point思想,这里逐一介绍

hash map

java代码如下
我们使用HashMap来存储a中元素出现的次数,然后遍历a中的元素v,如果v不等于0,并且-v存在于a中,那么最终结果+1
遍历完之后,cnt要除以2,原因是上述遍历过程中会重复计算组合数
最后我们找出a中有多少个0,然后对0求 C m 2 C_m^2 Cm2,加上cnt

    // O(N),使用hash表存储,对于0特殊处理
    public static int TwoSumFaster(int[] a) {
        HashMap<Integer, Integer> map = new HashMap<>();
        int cnt = 0;
        // a中元素以及它出现的次数
        for (Integer i : a) {
            int count = 0;
            if (map.containsKey(i))
                count = map.get(i);
            map.put(i, count + 1);
        }

        // 注意是遍历a而不是map
        for (Integer v : a) {
            // 如果存在-v使得两数之和为0
            if (map.containsKey(-v)) {
                if (v != 0)
                    cnt += map.get(-v);
            }
        }
        cnt /= 2;

        // 特殊处理0
        if (map.containsKey(0) && map.get(0) > 1)
            cnt += cmn(map.get(0), 2);
        return cnt;
    }

其中cmn函数如下

    //要求输入n不能大于m,否则会返回-1,这是无意义的
    private static long cmn(int m, int n) {
        if (n == 0)
            return 1;
        if (n == 1)
            return m;
        if (n > m / 2)
            return cmn(m, m - n);
        if (n > 1)
            return cmn(m - 1, n - 1) + cmn(m - 1, n);
        return -1; //通过编译用
    }

上述需要注意的地方就是

  1. 对于0要特殊处理
  2. 遍历是遍历a,而不是map
  3. 一次遍历之后记录的组合数是重复的,要除以2

该方法时间复杂度可以视为 O ( N ) O(N) O(N), 空间复杂度是 O ( N ) O(N) O(N)

two point

另一个方法就是two point,它有前提,就是假设a是已经有序的
在有序数组里面寻找两数之和为零,我们可以用i,j指针分别指向a的首位元素,然后判断a[i]+a[j]的和来移动i,j指针
我们知道,如果能在一个有序数组中找到了两个数之和为零,那么它们一定是在该数组的两端(相对而言)
使用i从左到右,使用j从右到左,取得a[i],a[j],根据a[i]+a[j]的结果来移动指针,特别要注意的是数组中有重复数和有0元素的情况

java代码:

    public static int TwoSumFaster_2(int[] a) {
        int i = 0, j = a.length - 1;
        int cnt = 0;
        // 如果没有全为正或者全为负就退出
        if (a[i] > 0 || a[j] < 0)
            return;
        // 在two sum里面,是求两个数之和,所以任何一个指针碰到0就退出了
        // 0 加任何不为0的数都不可能等于0,而0+0会等于0,但是我们会额外处理,所以遍历的时候碰到0就退出
        while (i < j && a[i] != 0 && a[j] != 0) {
            if (a[i] + a[j] > 0)
                j -= 1;
            else if (a[i] + a[j] < 0)
                i += 1;
            else {
                // 此时a[i]+a[j]=0,做进一步处理
                // 记录重复个数
                int di = 1, dj = 1;
                while (i+di < j && a[i+di] == a[i]) {
                    di += 1;
                }
                while (j-dj > i && a[j-dj] == a[j]) {
                    dj += 1;
                }
                // 如果没有重复数
                if (di == 1 && dj == 1) {
                    i++;
                    j--;
                    cnt++;
                } else {
                    // 组合数,其实相当于di*dj,写成这样更容易想到背后的联系
                    cnt += cmn(di, 1) * cmn(dj, 1);
                    i += di;
                    j -= dj;
                }
            }
        }
        // 记录0的个数 特殊处理0和0的组合
        int zn = (int) Arrays.stream(a).filter(v -> v == 0).count();
        if (zn > 1)
            cnt += cmn(zn, 2);
        return cnt;
    }

注意事项

  1. 数组必须有序
  2. 0是特殊处理的
  3. 需要处理数组中有重复数的情况

上述算法时间复杂度是 O ( N ) O(N) O(N),空间复杂度是 O ( 1 ) O(1) O(1),不需要额外的空间开销,特点是依据a已经排序的前提,和 two point 的思想

借助二分搜索 BinarySearch

其实还有一种做法,可以让时间复杂度降为 O ( N l o g N ) O(NlogN) O(NlogN),从logN中我们容易想到二分搜索,当然前提也是数组必须有序。
java代码:

    // O(NlogN)版本 也是书上的版本,不适用于数组中出现重复数的情况
    public static int TwoSum_NlogN(int[] a) {
        int cnt = 0;
        for (int i = 0; i < a.length; i++) {
            if (BinarySearch.indexOf(a, -a[i]) > i)
                cnt += 1;
        }
		return cnt;
    }

indexOf结果要大于i的原因是消除重复选取的情况,例如

-1 0 1 3

当i=0的时候,找到1与之配对,而当i=2的时候,找到-1与之配对,显然这两个组合是相同的。
我们可以在循环的时候就避免这种情况出现,使用indefOf>i的结果来规避
BinarySearch是《算法》上面的库,indexof代码如下

    public static int indexOf(int[] a, int key) {
        int lo = 0;
        int hi = a.length - 1;
        while (lo <= hi) {
            // Key is in a[lo..hi] or not present.
            int mid = lo + (hi - lo) / 2;
            if      (key < a[mid]) hi = mid - 1;
            else if (key > a[mid]) lo = mid + 1;
            else return mid;
        }
        return -1;
    }

上面的代码一个要求就是数组中不能出现重复数,否则会遗漏很多情况。
能把O(N^2)时间复杂度降为O(NlogN)其实是一项壮举

改进的O(NlogN)算法

我们可以在二分搜索上面稍加扩展就能处理数组中有重复数的情况了

    // O(NlogN)版本 可以处理数组中出现重复数的情况
    public static int TwoSum_NlogN_p(int[] a) {
        int cnt = 0;
        for (int i = 0; i < a.length; i++) {
            int j = lower_bound(a, -a[i]);
            if (j > i && a[j]+a[i]==0) {
                int r = 1;
                while (j < a.length && a[++j] == -a[i])
                    r++;
                cnt += r;
            }
        }
        int zn= (int) Arrays.stream(a).filter(v->v==0).count();
        if(zn>1)
            cnt+=cmn(zn,2);
        return cnt;
    }

其中lower_bound代码如下,它也是《算法》1.4.15的答案

    private static int lower_bound(int[] a, int key) {
        //返回第一个大于等于key的位置
        int lo = 0, hi = a.length;
        int mid = 0;
        while (lo < hi) {
            mid = lo + (hi - lo) / 2;
            if (a[mid] >= key) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

总结一下上述的算法

  1. hash map
  2. two point
  3. 二分搜索
  4. lower_bound 处理重复数的情况

three sum

three sum问题和two sum问题是类似的,很多技巧可以借鉴过来
首先我们也是写一下暴力算法
伪码

	int cnt=0;
	for i in range(len(array)):
		for j in range(i+1,len(array)):
			for k in range(j+1,len(array)):
				if a[i]+a[j]+a[k] == 0:
					cnt++
	return cnt

hash map

同样的我们也可以使用hash map来降低时间复杂度,但是稍微和two sum的方法有点区别
再看代码之前,先想一下,现在three sum问题要解决三个数相加为0的问题,一个直观的想法是用两个for循环找到a[i],a[j]作为前两个数,然后再在map中find -(a[i]+a[j])是否存在
大致想法是这样,那么hash map里面仍然存放key出现过的次数吗?
先写一段序列验证一下

-3 -2 -1 1 2 3 4 5

检查上面的序列,会发现-3 + 1 + 2 =0 ,而 -3 + 2 +1 也等于0,前一种情况是i=0,j=2,后一种情况是i=0,j=4,显然这两组数字都是唯一的,意味着我们重复选取了组合,而这种重复选取的结果大约会得到cnt是真实结果的3倍。但是如果直接返回cnt/3是不准确的,在数字很多的情况下会产生误差,误差的原因我还没有验证,猜测跟0有关

所以我们不希望出现这种重复选取的情况,再看上面的组合数,这种现象出现的原因是选取的第三个数下标小于j,也就是第二个数。我们规定只能选取j之后的第三个数,就可以消除以上的重复情况。
其实这个思想很像two sum里面的二分搜索法

下面是java代码:

    // 时间复杂度大于等于 O(N^2)
    // 同样借助hash表,不同的是hash表里面存放的是这个元素的下标,目的是为了不重复计算组合
    public static int ThreeSumFaster(int[] a) {
        HashMap<Integer, List<Integer>> map = new HashMap<>();
        // 记录答案
        int cnt = 0;
        // 记录下标
        for (int i = 0; i < a.length; i++) {
            if (!map.containsKey(a[i]))
                map.put(a[i], new ArrayList<>());
            map.get(a[i]).add(i);
        }

        for (int i = 0; i < a.length; i++) {
            for (int j = i + 1; j < a.length; j++) {
                int t = a[i] + a[j];
                if (map.containsKey(-t)) {
                    int finalJ = j;
                    cnt += map.get(-t).stream().filter(v -> v.compareTo(finalJ) > 0).count();
                }
            }
        }
        return cnt;
    }

时间复杂度最好情况是 O ( N 2 ) O(N^2) O(N2),此时数组中没有重复数,get操作时间复杂度是 O ( 1 ) O(1) O(1)而最差的情况时间复杂度退化为 O ( N 3 ) O(N^3) O(N3),get操作变为 O ( N ) O(N) O(N),但是这种情况很少出现,我们仍然可以认为是趋于 O ( N 2 ) O(N^2) O(N2)的时间复杂度
空间复杂度是 O ( N ) O(N) O(N)

其实适当牺牲一点空间是可以接受的,尤其是换取的空间把时间复杂度降低了一个量级,而且还降低了编码难度,更简单的编码难度意味着更少的bug和调试时间

two point方法 O(N^2)

由于是三数之和,我们没有办法直接用two point的方法来做,而且我也不会写three point,三个指针的移动太复杂了,简单的做法是在while循环外面套上一个for循环,由for来固定一个a[i],然后寻找两个数a[start],a[end]使得它们之和为0
下面是java代码

    // 同样按照two point 的思想,由于是求三数之和,我们外面再套一个for
    public static int ThreeSumFaster_2(int[] a) {
        int start = 0, end = a.length - 1;
        int cnt = 0;
        // 如果没有全为正或者全为负就退出
        if (a[start] > 0 || a[end] < 0)
            return 0;

        for (int i = 0; i < a.length; i++) {
            start = i + 1;
            end = a.length - 1;

            while (start < end) {
                // 现在的判断条件变为三数之和
                int t = a[i] + a[start] + a[end];
                if (t > 0)
                    end -= 1;
                else if (t < 0)
                    start += 1;
                else {
                    // 此时a[i]+a[j]=0,做进一步处理

                    // 跳过为0的情况,0的组合单独计算
                    if (a[start] == 0 && a[end] == 0) {
                        start++;
                        end--;
                        continue;
                    }
                    // 记录重复个数
                    int di = 1, dj = 1;

                    while (start + di < end && a[start + di] == a[start]) {
                        di += 1;
                    }

                    while (end - dj > start && a[end - dj] == a[end]) {
                        dj += 1;
                    }

                    // 如果没有重复数
                    if (di == 1 && dj == 1) {
                        start++;
                        end--;
                        cnt++;
                    } else {
                        // 组合数
                        cnt += cmn(di, 1) * cmn(dj, 1);
                        start += di;
                        end -= dj;
                    }
                }
            }
        }
        // 记录0的个数 特殊处理0和0的组合
        int zn = (int) Arrays.stream(a).filter(value -> value == 0).count();
        if (zn > 2)
            cnt += cmn(zn, 3);
        return cnt;
    }

你可能注意到了它的while和two sum里面的while并不同
这里的while遇到0的时候并不会直接退出,two sum里while碰到0就退出的理由已经阐述过了,这里不能这么做的原因也很简单。
-3 -1 0 1 3 2
当i=0的时候,选取了-3,很显然可以找到0和3来使得它们之和为0。所以我们不能在遇到0的时候直接退出了
另外我们仍然不想在while里面处理0的情况,所以当a[start]+a[end]==0的时候,直接跳过,把所有0的组合在while循环外面计算

该算法的时间复杂度是 O ( N 2 ) O(N^2) O(N2)

题外话:three sum问题有 O ( N l o g N ) O(NlogN) O(NlogN)的解法吗?
对于 3-sum,回答是不知道,不过专家们相信 3-sum 可能的最优算法是平方级别的(《算法》上面的一段话)

总结

总结了上面的解法,另外找到了leetcode上面一些相似题目
其实这种n sum问题还有很多很多衍生问题,例如找到一组数字使得它们差值的绝对值最大/最小,找到一组数使得它们的和最接近target等等问题
一想到这些问题头都大了

上面的代码已经上传到github上了
https://github.com/hhmy27/Alg4_Code/blob/master/src/ch01/part4/ex_1_4_15.java

参考

参考了一下three sum问题的解法,这位作者写的代码更工整
https://github.com/reneargento/algorithms-sedgewick-wayne/blob/f31578352b7774e857a664bf8768bca2200e75e0/src/chapter1/section4/Exercise15_1_TwoSumFaster.java

https://algs4.cs.princeton.edu/14analysis/ThreeSumFast.java.html

《算法》

你可能感兴趣的:(算法)