题目是LeetCode第188场周赛的第二题,链接:1442. 形成两个异或相等数组的三元组数目。具体描述为:给你一个整数数组arr
。现需要从数组中取三个下标i
、j
和k
,其中 (0 <= i < j <= k < arr.length
) 。a
和b
定义如下:
a = arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
b = arr[j] ^ arr[j + 1] ^ ... ^ arr[k]
注意:^
表示按位异或操作。
请返回能够令a == b
成立的三元组(i, j , k)
的数目。
示例1:
输入:arr = [2,3,1,6,7]
输出:4
解释:满足题意的三元组分别是 (0,1,2), (0,2,2), (2,3,4) 以及 (2,4,4)
示例2:
输入:arr = [1,1,1,1,1]
输出:10
示例3:
输入:arr = [2,3]
输出:0
示例4:
输入:arr = [1,3,5,7,9]
输出:3
示例4:
输入:arr = [7,11,12,9,5,2,7,17,22]
输出:8
这道题的关键在于知道异或操作的一些特性,比如0^x=x
,x^x=0
。我们可以先用一个数组curXor
保存累积异或结果,也就是curXor[i]=arr[0]^arr[1]^...^arr[i]
,然后当要求题目里的比如a= arr[i] ^ arr[i + 1] ^ ... ^ arr[j - 1]
时,利用curXor[i-1]^curXor[i-1]=0
,我们可以简单地求得a=curXor[j-1]^curXor[i-1]
,从而避免一层循环。
所以现在可以简单地用三层循环来解决这个问题了。时间复杂度为 O ( n 3 ) O(n^{3}) O(n3),空间复杂度为 O ( n ) O(n) O(n)。
JAVA版代码如下:
class Solution {
public int countTriplets(int[] arr) {
int result = 0;
int[] curXor = new int[arr.length];
curXor[0] = arr[0];
for (int i = 1; i < arr.length; ++i) {
curXor[i] = arr[i] ^ curXor[i - 1];
}
for (int i = 0; i < arr.length - 1; ++i) {
for (int j = i + 1; j < arr.length; ++j) {
for (int k = j; k < arr.length; ++k) {
int a = curXor[j - 1] ^ (i == 0 ? 0 : curXor[i - 1]);
int b = curXor[k] ^ curXor[j - 1];
if (a == b) {
++result;
}
}
}
}
return result;
}
}
提交结果如下:
接着可以进行优化,注意到a==b
意味着a^b==0
,那么我们只要找到符合arr[i]^...^a[j-1]^a[j]^...arr[k]==0
的(i,k)
对,就可以得到k-i
对符合条件的(i,j,k)
了。时间复杂度降为 O ( n 2 ) O(n^{2}) O(n2),空间复杂度为 O ( n ) O(n) O(n)。
JAVA版代码如下:
class Solution {
public int countTriplets(int[] arr) {
int result = 0;
int[] curXor = new int[arr.length];
curXor[0] = arr[0];
for (int i = 1; i < arr.length; ++i) {
curXor[i] = arr[i] ^ curXor[i - 1];
}
for (int i = 0; i < arr.length - 1; ++i) {
for (int k = i + 1; k < arr.length; ++k) {
int num = curXor[k] ^ (i == 0 ? 0 : curXor[i - 1]);
if (num == 0) {
result += k - i;
}
}
}
return result;
}
}
提交结果如下:
其实上面的还可以再继续优化,上面我们在求curXor[k]^curXor[i-1]==
意味着curXor[k]==curXor[i-1]
,所以有点类似之前的两数之和,用一个Map来保存{curXor[i]->[..., i]}
,遍历curXor
的过程中,遇到一个curXor[i]
,如果存在于Map中,先取得一个链表lst=map.get(curXor[i])
,然后遍历这个链表中存有的索引idx
,每个都可以与当前索引i
组成一对(idx, i)
,生成i-idx-1
个符合条件的组合。时间复杂度最好的情况下为 O ( n ) O(n) O(n),最坏的情况下还是 O ( n 2 ) O(n^{2}) O(n2)(当curXor
中的数全一样的时候),空间复杂度为 O ( n ) O(n) O(n)。
JAVA版代码如下:
class Solution {
public int countTriplets(int[] arr) {
int result = 0;
int[] curXor = new int[arr.length];
curXor[0] = arr[0];
for (int i = 1; i < arr.length; ++i) {
curXor[i] = arr[i] ^ curXor[i - 1];
}
Map> map = new HashMap<>();
List lst = new LinkedList<>();
lst.add(-1);
map.put(0, lst);
for (int i = 0; i < curXor.length; ++i) {
List idxs = map.getOrDefault(curXor[i], new LinkedList<>());
if (map.containsKey(curXor[i])) {
for (int idx : idxs) {
result += i - idx - 1;
}
idxs.add(i);
}
else {
idxs.add(i);
}
map.put(curXor[i], idxs);
}
return result;
}
}
提交结果如下:
继续对上面的方法进行优化就可以得到最终的 O ( n ) O(n) O(n)解法了,上面我们在遇到一个curXor[i]
的时候都需要到Map中取得其对应的一个列表,然后需要遍历列表的各个元素(记为idx[j]
),然后进行求和得到 a d d = ∑ j i − ( i d x [ j ] + 1 ) add=\sum_{j}i-(idx[j]+1) add=∑ji−(idx[j]+1),假设我们记录某个curXor[i]
出现次数count
,其实上面的可以化简为 a d d = c o u n t ∗ i − ∑ j ( i d x [ j ] + 1 ) add=count*i-\sum_{j}(idx[j]+1) add=count∗i−∑j(idx[j]+1),所以又可以用一个累加变量curSum
记录 ∑ j ( i d x [ j ] + 1 ) \sum_{j}(idx[j]+1) ∑j(idx[j]+1),从而我们只需计算 a d d = c o u n t ∗ i − c u r S u m add=count*i-curSum add=count∗i−curSum。这么做避免了内嵌循环,将时间复杂度真正降为 O ( n ) O(n) O(n),空间复杂度仍为 O ( n ) O(n) O(n)。
JAVA版代码如下:
class Solution {
public int countTriplets(int[] arr) {
int result = 0;
int curXor = 0;
// int[0]:出现次数,int[1]:多个索引(i+1)累加和
Map map = new HashMap<>();
map.put(0, new int[] {1, 0});
for (int i = 0; i < arr.length; ++i) {
curXor ^= arr[i];
int[] countAndSum = map.getOrDefault(curXor, new int[2]);
if (map.containsKey(curXor)) {
result += countAndSum[0] * i - countAndSum[1];
++countAndSum[0];
countAndSum[1] += i + 1;
}
else {
countAndSum[0] = 1;
countAndSum[1] = i + 1;
}
map.put(curXor, countAndSum);
}
return result;
}
}
提交结果如下:
Python版代码如下:
class Solution:
def countTriplets(self, arr: List[int]) -> int:
curXor = 0
result = 0
xor2countAndSum = {0 : [1, 0]}
for i in range(len(arr)):
curXor ^= arr[i]
if curXor in xor2countAndSum:
countAndSum = xor2countAndSum[curXor]
result += countAndSum[0] * i - countAndSum[1]
countAndSum[0] += 1
countAndSum[1] += i + 1
xor2countAndSum[curXor] = countAndSum
else:
countAndSum = [1, i + 1]
xor2countAndSum[curXor] = countAndSum
return result
提交结果如下: