leetcode--形成两个异或相等数组的三元组数目,从O(N^3)优化到O(N)

 题目是LeetCode第188场周赛的第二题,链接:1442. 形成两个异或相等数组的三元组数目。具体描述为:给你一个整数数组arr。现需要从数组中取三个下标ijk,其中 (0 <= i < j <= k < arr.length) 。ab定义如下:

  • a = arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
  • b = arr[j] ^ arr[j + 1] ^ ... ^ arr[k]

 注意:^表示按位异或操作。

 请返回能够令a == b成立的三元组(i, j , k)的数目。

 示例1:

输入:arr = [2,3,1,6,7]
输出:4
解释:满足题意的三元组分别是 (0,1,2), (0,2,2), (2,3,4) 以及 (2,4,4)

 示例2:

输入:arr = [1,1,1,1,1]
输出:10

 示例3:

输入:arr = [2,3]
输出:0

 示例4:

输入:arr = [1,3,5,7,9]
输出:3

 示例4:

输入:arr = [7,11,12,9,5,2,7,17,22]
输出:8

 这道题的关键在于知道异或操作的一些特性,比如0^x=xx^x=0。我们可以先用一个数组curXor保存累积异或结果,也就是curXor[i]=arr[0]^arr[1]^...^arr[i],然后当要求题目里的比如a= arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]时,利用curXor[i-1]^curXor[i-1]=0,我们可以简单地求得a=curXor[j-1]^curXor[i-1],从而避免一层循环。

 所以现在可以简单地用三层循环来解决这个问题了。时间复杂度为 O ( n 3 ) O(n^{3}) O(n3),空间复杂度为 O ( n ) O(n) O(n)

 JAVA版代码如下:

class Solution {
    public int countTriplets(int[] arr) {
        int result = 0;
        int[] curXor = new int[arr.length];
        curXor[0] = arr[0];
        for (int i = 1; i < arr.length; ++i) {
            curXor[i] = arr[i] ^ curXor[i - 1];
        }
        for (int i = 0; i < arr.length - 1; ++i) {
            for (int j = i + 1; j < arr.length; ++j) {
                for (int k = j; k < arr.length; ++k) {
                    int a = curXor[j - 1] ^ (i == 0 ? 0 : curXor[i - 1]);
                    int b = curXor[k] ^ curXor[j - 1];
                    if (a == b) {
                        ++result;
                    }
                }
            }
        }
        return result;
    }
}

 提交结果如下:


 接着可以进行优化,注意到a==b意味着a^b==0,那么我们只要找到符合arr[i]^...^a[j-1]^a[j]^...arr[k]==0(i,k)对,就可以得到k-i对符合条件的(i,j,k)了。时间复杂度降为 O ( n 2 ) O(n^{2}) O(n2),空间复杂度为 O ( n ) O(n) O(n)

 JAVA版代码如下:

class Solution {
    public int countTriplets(int[] arr) {
        int result = 0;
        int[] curXor = new int[arr.length];
        curXor[0] = arr[0];
        for (int i = 1; i < arr.length; ++i) {
            curXor[i] = arr[i] ^ curXor[i - 1];
        }
        for (int i = 0; i < arr.length - 1; ++i) {
            for (int k = i + 1; k < arr.length; ++k) {
                int num = curXor[k] ^ (i == 0 ? 0 : curXor[i - 1]);
                if (num == 0) {
                    result += k - i;
                }
            }
        }
        return result;
    }
}

 提交结果如下:


 其实上面的还可以再继续优化,上面我们在求curXor[k]^curXor[i-1]==意味着curXor[k]==curXor[i-1],所以有点类似之前的两数之和,用一个Map来保存{curXor[i]->[..., i]},遍历curXor的过程中,遇到一个curXor[i],如果存在于Map中,先取得一个链表lst=map.get(curXor[i]),然后遍历这个链表中存有的索引idx,每个都可以与当前索引i组成一对(idx, i),生成i-idx-1个符合条件的组合。时间复杂度最好的情况下为 O ( n ) O(n) O(n),最坏的情况下还是 O ( n 2 ) O(n^{2}) O(n2)(当curXor中的数全一样的时候),空间复杂度为 O ( n ) O(n) O(n)

 JAVA版代码如下:

class Solution {
    public int countTriplets(int[] arr) {
        int result = 0;
        int[] curXor = new int[arr.length];
        curXor[0] = arr[0];
        for (int i = 1; i < arr.length; ++i) {
            curXor[i] = arr[i] ^ curXor[i - 1];
        }
        Map> map = new HashMap<>();
        List lst = new LinkedList<>();
        lst.add(-1);
        map.put(0, lst);
        for (int i = 0; i < curXor.length; ++i) {
            List idxs = map.getOrDefault(curXor[i], new LinkedList<>());
            if (map.containsKey(curXor[i])) {
                for (int idx : idxs) {
                    result += i - idx - 1;
                }
                idxs.add(i);
            }
            else {
                idxs.add(i);
            }
            map.put(curXor[i], idxs);
        }
        return result;
    }
}

 提交结果如下:


 继续对上面的方法进行优化就可以得到最终的 O ( n ) O(n) O(n)解法了,上面我们在遇到一个curXor[i]的时候都需要到Map中取得其对应的一个列表,然后需要遍历列表的各个元素(记为idx[j]),然后进行求和得到 a d d = ∑ j i − ( i d x [ j ] + 1 ) add=\sum_{j}i-(idx[j]+1) add=ji(idx[j]+1),假设我们记录某个curXor[i]出现次数count,其实上面的可以化简为 a d d = c o u n t ∗ i − ∑ j ( i d x [ j ] + 1 ) add=count*i-\sum_{j}(idx[j]+1) add=countij(idx[j]+1),所以又可以用一个累加变量curSum记录 ∑ j ( i d x [ j ] + 1 ) \sum_{j}(idx[j]+1) j(idx[j]+1),从而我们只需计算 a d d = c o u n t ∗ i − c u r S u m add=count*i-curSum add=counticurSum。这么做避免了内嵌循环,将时间复杂度真正降为 O ( n ) O(n) O(n),空间复杂度仍为 O ( n ) O(n) O(n)

 JAVA版代码如下:

class Solution {
    public int countTriplets(int[] arr) {
        int result = 0;
        int curXor = 0;
        // int[0]:出现次数,int[1]:多个索引(i+1)累加和
        Map map = new HashMap<>();
        map.put(0, new int[] {1, 0});
        for (int i = 0; i < arr.length; ++i) {
            curXor ^= arr[i];
            int[] countAndSum = map.getOrDefault(curXor, new int[2]);
            if (map.containsKey(curXor)) {
                result += countAndSum[0] * i - countAndSum[1];
                ++countAndSum[0];
                countAndSum[1] += i + 1;
            }
            else {
                countAndSum[0] = 1;
                countAndSum[1] = i + 1;
            }
            map.put(curXor, countAndSum);
        }
        return result;
    }
}

 提交结果如下:


 Python版代码如下:

class Solution:
    def countTriplets(self, arr: List[int]) -> int:
        curXor = 0
        result = 0
        xor2countAndSum = {0 : [1, 0]}
        for i in range(len(arr)):
            curXor ^= arr[i]
            if curXor in xor2countAndSum:
                countAndSum = xor2countAndSum[curXor]
                result += countAndSum[0] * i - countAndSum[1]
                countAndSum[0] += 1
                countAndSum[1] += i + 1
                xor2countAndSum[curXor] = countAndSum
            else:
                countAndSum = [1, i + 1]
                xor2countAndSum[curXor] = countAndSum
        return result

 提交结果如下:


你可能感兴趣的:(LeetCode)