HDU 6059 Kanade's trio 字典树 统计 容斥

链接

http://acm.hdu.edu.cn/showproblem.php?pid=6059

题意

给出 A[1..n] (1<=n<=5105) (0<=A[i]<230) ,要求统计三元组 (i,j,k) 的个数使其满足 i<j<k 并且 (A[i]xorA[j])<(A[j]xorA[k])

思路

事先把所有数字插入字典树中,用字典树维护 A[k] 的信息,接着对每一个 A[i] ,枚举其二进制最高位小于 A[k] 的位数,考虑这样一个情况:若当前枚举到了 A[i] 的二进制第5位比 A[k] 小,那么 A[i] A[k] 的第30位到第6位都是相同的,此时就不用考虑 A[j] 的第30位到第6位如何,只考虑第5位的情况就好。
考虑当前位置 A[i] 的情况,若当前位置 A[i] 为0,那么 A[j] 的相同位置要为0才能使两者异或值为0,此时 A[k] 为1,这时满足条件的 A[j] A[k] 对数可以计入答案。当前位置 A[i] 为1的情况同理( A[j] 为1, A[k] 为0)。

对于 A[j] A[k] 对数的统计,在插入 A[k] 时,之前插入的数就都成了 A[j] 。因此用一个cnt[i][j]数组记录下第i位为j的数之前出现了几次,那么在插入时,对于这一位置 A[k] 为0的情况,之前有多少的 A[j] 在这一位为1,就是此时满足条件的 A[j] 的个数。代码里的cnt[i][nxt ^ 1]就是此时符合条件的 A[j] 个数。

当我们把一个数从字典树中去掉时,也要考虑去掉这个数留下来的统计值。
这题特殊的地方在于,插入是连续的,之后是连续的删除,所以在插入完成后可以把cnt[i][j]数组清空一遍,用来记录第i位为j的数被删除了几次。
考虑两个方面:
- 一个是这个数作为 A[k] 直接被去掉带来的影响,像之前一样减去其前面已经被删去的 A[j] 的个数(依然是cnt[now][nxt ^ 1])就好。(这一步操作在Trie::Insert()里面,与插入时的操作类似)
- 还有一个是这个数作为 A[j] 带来的影响,因为这个数已经不能和后面的 A[k] 组合产生贡献了,考虑到在统计时,当前位的 A[k] 已经把可以与其组合的 A[j] 个数统计在了sum[tmp]中,这里面还需去掉被删去的 A[j] ,被删去的 A[j] 已经被统计在了cnt[i][nxt]中,现有的 A[k] 被存在了val[tmp]中,这一部分不能被计入答案,相乘,减去。(sum[tmp] - val[tmp] * cnt[i][nxt]这一步在函数solve()里)

希望思路说清楚了,详见代码

代码

#include 
#include 
#include 
#include 

using namespace std;

#define MS(x, y) memset(x, y, sizeof(x))

typedef long long LL;
const int MAXN = 5e5 + 5;
int bits[32];

struct Trie {
  int tot, root;
  int val[MAXN * 30], ch[MAXN * 30][2];
  LL sum[MAXN * 30], cnt[MAXN][2];

  int newnode() {
    val[tot] = sum[tot] = 0;
    ch[tot][0] = ch[tot][1] = -1;
    return tot++;
  }

  void init() {
    tot = 0;
    root = newnode();
    MS(cnt, 0);
  }

  void Insert(int x, int v) {
    int now = root, nxt, tmp;
    for (int i = 30; i >= 0; --i) {
      nxt = !!(x & bits[i]);
      if (ch[now][nxt] == -1) ch[now][nxt] = newnode();
      now = ch[now][nxt];
      ++cnt[i][nxt];
      sum[now] += v * cnt[i][nxt ^ 1];
      val[now] += v;
    }
  }

  LL solve(int x) {
    LL ret = 0;
    int now = root, tmp, nxt;
    for (int i = 30; i >= 0; --i) {
      nxt = !!(x & bits[i]);
      tmp = ch[now][nxt ^ 1];
      now = ch[now][nxt];
      if (tmp != -1) {
        ret += sum[tmp] - val[tmp] * cnt[i][nxt];
      }
      if (now == -1) break;
    }
    return ret;
  }
};

int n;
int a[MAXN];
LL ans;
Trie trie;

int main() {
  bits[0] = 1;
  for (int i = 1; i < 32; ++i) bits[i] = bits[i - 1] << 1;
  int T;
  scanf("%d", &T);
  while (T--) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    ans = 0;
    trie.init();
    for (int i = 1; i <= n; ++i) trie.Insert(a[i], 1);
    MS(trie.cnt, 0);
    for (int i = 1; i < n; ++i) {
      trie.Insert(a[i], -1);
      ans += trie.solve(a[i]);
    }
    printf("%I64d\n", ans);
  }
}

你可能感兴趣的:(容斥,字典树)