我不会说我是为了学博弈才去做这题的……
题意:
给定一棵N个点的树,1号点为根,每个节点是白色或者黑色。双方轮流操作,每次选择一个白色节点,将从这个点到根的路径上的点全部染成黑色。问先手是否必胜,以及第一步可选节点有哪些。N<=100000。
分析:
首先是博弈方面的分析。令SG[x]为,只考虑以x为根的子树时的SG值。令g[x]为,只考虑以x为根的子树时,所有后继局面的SG值的集合。那么SG[x]=mex{g[x]}。
我们考虑怎么计算g[x]。假设x的儿子为v1,v2,...,vk,令sum[x]=SG[v1] xor SG[v2] xor .. xor SG[vk]。考虑两种情况:
1、x为黑色。不难发现以x的每个儿子为根的子树是互相独立的。假设这一步选择了vi子树的某一个节点,那么转移到的局面的SG值就是sum[x] xor SG[vi] xor (在g[vi]中的某个值)。那么我们只需将每个g[vi]整体xor上sum[x] xor SG[vi]再合并到g[x]即可。
2、x为白色。这时候我们多了一种选择,即选择x点。可以发现,选择x点之后x点变成黑色,所有子树仍然独立,而转移到的局面的SG值就是sum[x]。如果此时不选择x而是选择x子树里的某个白色节点,那么x一样会被染成黑色,所有子树依然独立。所以x为白色时只是要向g[x]中多插入一个值sum[x]。
这样我们就有一个自底向上的DP了。朴素的复杂度是O(N^2)的。
接下来再考虑第一步可选的节点。我们要考虑选择哪些节点之后整个局面的SG值会变成0。假设我们选择了x点,那么从x到根的路径都会被染黑,将原来的树分成了一堆森林。我们令up[x]为,不考虑以x为根的子树,将从x到根的路径染黑,剩下的子树的SG值的xor和。那么up[x]=up[fa[x]] xor sum[fa[x]] xor sg[x],其中fa[x]为x的父亲节点编号。那么如果点x初始颜色为白色且up[x] xor sum[x]=0,那么这个点就是第一步可选的节点。这一步是O(N)的。
剩下的就是优化求SG了。我们需要一个可以快速整体xor并合并的数据结构。整体xor可以用二进制Trie打标记实现,至于合并,用启发式合并是O(Nlog^2N)的,而用线段树合并的方法可以做到O(NlogN)。不过还需要注意各种常数的问题……比如不要用指针,Trie的节点不用记大小,只要记是否满了……
做这题的时候先去膜拜了主席的题解……然后又去膜拜了主席冬令营的讲课……最后还去膜拜了翱犇的代码……然后几乎是照着抄了一遍……
代码:(SPOJ上排到了倒数第三……)
//SPOJ11414; COT3; Game Theory + Trie Merging #include <cstdio> #include <iostream> #include <algorithm> #include <ctime> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef unsigned int uint; typedef long double ld; #define pair(x, y) make_pair(x, y) #define runtime() ((double)clock() / CLOCKS_PER_SEC) #define N 100000 #define LOG 17 struct edge { int next, node; } e[N << 1 | 1]; int head[N + 1], tot = 0; inline void addedge(int a, int b) { e[++tot].next = head[a]; head[a] = tot, e[tot].node = b; } #define SIZE 2000000 struct Node { int l, r; bool full; int d; } tree[SIZE + 1]; #define l(x) tree[x].l #define r(x) tree[x].r #define d(x) tree[x].d #define full(x) tree[x].full int root[N + 1], tcnt = 0; int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1]; bool v[N + 1]; inline int newnode() { return ++tcnt; } inline void update(int x) { full(x) = full(l(x)) && full(r(x)); } inline void push(int x) { if (d(x)) { if (l(x)) d(l(x)) ^= d(x) >> 1; if (r(x)) d(r(x)) ^= d(x) >> 1; if (d(x) & 1) swap(l(x), r(x)); d(x) = 0; } } int merge(int l, int r) { if (!l || full(r)) return r; if (!r || full(l)) return l; push(l), push(r); int ret = newnode(); l(ret) = merge(l(l), l(r)); r(ret) = merge(r(l), r(r)); update(ret); return ret; } inline int rev(int x) { int r = 0; for (int i = LOG; i > 0; --i) if (x >> i - 1 & 1) r += 1 << LOG - i; return r; } void insert(int x, int v, int p) { push(x); if (v >> p - 1 & 1) { if (!r(x)) r(x) = newnode(); if (p != 1) insert(r(x), v, p - 1); else full(r(x)) = true; } else { if (!l(x)) l(x) = newnode(); if (p != 1) insert(l(x), v, p - 1); else full(l(x)) = true; } update(x); } int mex(int x) { int r = 0; for (int i = LOG; x; --i) { push(x); if (full(l(x))) r += 1 << i - 1, x = r(x); else x = l(x); } return r; } void calc(int x) { v[x] = true; int xorsum = 0; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; calc(node); v[node] = false; xorsum ^= sg[node]; } for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; d(root[node]) ^= rev(xorsum ^ sg[node]); root[x] = merge(root[x], root[node]); } if (!col[x]) insert(root[x], xorsum, LOG); sg[x] = mex(root[x]); sum[x] = xorsum; } int ans[N + 1], cnt = 0; void find(int x) { v[x] = true; if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; up[node] = up[x] ^ sum[x] ^ sg[node]; find(node); } } int main(int argc, char* argv[]) { #ifdef KANARI freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); #endif scanf("%d", &n); for (int i = 1; i <= n; ++i) scanf("%d", col + i); for (int i = 1; i < n; ++i) { static int x, y; scanf("%d%d", &x, &y); addedge(x, y), addedge(y, x); } for (int i = 1; i <= n; ++i) root[i] = newnode(); calc(1); for (int i = 1; i <= n; ++i) v[i] = false; find(1); if (cnt == 0) printf("-1\n"); else { sort(ans + 1, ans + cnt + 1); for (int i = 1; i <= cnt; ++i) printf("%d\n", ans[i]); } // cerr << runtime() << endl; // for (int i = 1; i <= n; ++i) printf("%d ", sg[i]); fclose(stdin); fclose(stdout); return 0; }
// #include <cstdio> #include <cstdlib> #include <cstring> #include <iostream> #include <algorithm> #include <climits> #include <cmath> #include <utility> #include <set> #include <map> #include <queue> #include <ios> #include <iomanip> #include <ctime> #include <numeric> #include <functional> #include <fstream> #include <sstream> #include <string> #include <vector> #include <bitset> #include <cstdarg> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef unsigned int uint; typedef long double ld; #define pair(x, y) make_pair(x, y) #define runtime() ((double)clock() / CLOCKS_PER_SEC) inline int read() { static int r; static char c; r = 0, c = getchar(); while (c < '0' || c > '9') c = getchar(); while (c >= '0' && c <= '9') r = r * 10 + (c - '0'), c = getchar(); return r; } template <typename T> inline void print(T *a, int n) { for (int i = 1; i < n; ++i) cout << a[i] << " "; cout << a[n] << endl; } #define PRINT(__l, __r, __begin, __end) { for (int __i = __begin; __i != __end; ++__i) cout << __l __i __r << " "; cout << endl; } #define N 100000 #define LOG 17 struct edge { int next, node; } e[N << 1 | 1]; int head[N + 1], tot = 0; inline void addedge(int a, int b) { e[++tot].next = head[a]; head[a] = tot, e[tot].node = b; } struct Node { Node *l, *r; bool full; int d; Node() { l = r = NULL, full = false, d = 0; } } *root[N + 1]; int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1]; bool v[N + 1]; inline void update(Node *x) { if (x->l && x->r) x->full = x->l->full && x->r->full; else x->full = false; } inline void applyDelta(Node *x, int v) { x->d ^= v; } inline void push(Node *x) { if (x->d) { if (x->l) applyDelta(x->l, x->d >> 1); if (x->r) applyDelta(x->r, x->d >> 1); if (x->d & 1) swap(x->l, x->r); x->d = 0; } } Node* merge(Node *l, Node *r) { if (l == NULL || (r != NULL && r->full)) return r; if (r == NULL || (l != NULL && l->full)) return l; push(l), push(r); Node *ret = new Node(); ret->l = merge(l->l, r->l); ret->r = merge(l->r, r->r); update(ret); return ret; } inline int rev(int x) { int r = 0; for (int i = LOG; i > 0; --i) if (x >> i - 1 & 1) r += 1 << LOG - i; return r; } void insert(Node *x, int v, int p) { push(x); if (v >> p - 1 & 1) { if (x->r == NULL) x->r = new Node(); if (p != 1) insert(x->r, v, p - 1); else x->r->full = true; } else { if (x->l == NULL) x->l = new Node(); if (p != 1) insert(x->l, v, p - 1); else x->l->full = true; } update(x); } int mex(Node *x) { int r = 0; for (int i = LOG; x != NULL; --i) { push(x); if (x->l && x->l->full) r += 1 << i - 1, x = x->r; else x = x->l; } return r; } void calc(int x) { v[x] = true; int xorsum = 0; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; calc(node); v[node] = false; xorsum ^= sg[node]; } for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; applyDelta(root[node], rev(xorsum ^ sg[node])); root[x] = merge(root[x], root[node]); } if (!col[x]) insert(root[x], xorsum, LOG); sg[x] = mex(root[x]); sum[x] = xorsum; } int ans[N + 1], cnt = 0; void find(int x) { v[x] = true; if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x; for (int i = head[x]; i; i = e[i].next) { int node = e[i].node; if (v[node]) continue; up[node] = up[x] ^ sum[x] ^ sg[node]; find(node); } } int main(int argc, char* argv[]) { #ifdef KANARI freopen("input.txt", "r", stdin); freopen("output.txt", "w", stdout); #endif scanf("%d", &n); for (int i = 1; i <= n; ++i) scanf("%d", col + i); for (int i = 1; i < n; ++i) { static int x, y; scanf("%d%d", &x, &y); addedge(x, y), addedge(y, x); } for (int i = 1; i <= n; ++i) root[i] = new Node(); calc(1); for (int i = 1; i <= n; ++i) v[i] = false; find(1); if (cnt == 0) printf("-1\n"); else { sort(ans + 1, ans + cnt + 1); for (int i = 1; i <= cnt; ++i) printf("%d\n", ans[i]); } // for (int i = 1; i <= n; ++i) printf("%d ", sg[i]); fclose(stdin); fclose(stdout); return 0; }