首先如果给定一些数,询问这些数中哪个数^给定的数的值最大的话,我们可以建立一颗trie树,根连接的两条边分别为0,1,表示二进制下第15位,那么我们可以建立一颗trie树,每一条从根到叶子节点的链表示一个2^16以内的数,开始每个节点的cnt都是0,那么每插入一个元素,将表示这个值的链上所有位置的cnt++,那么对于一个值要求使得^最大,如果这个值的某一位是1,那么我们最好要找到一个为0的数来和他^,那么判断下0儿子的cnt是不是大于0,然后做就好了。
那么对于这棵树,我们可以先将1为根,然后对于两个点x,y之间的链拆成x,lca和y,lca的两条链,现在问题就转化为了求一个深度递增的链上所有值和给定值的^值最大,那么我们可以建立可持久化trie,每个节点继承父节点的trie树,我们只需要用x的trie树减去lca father的trie树做开始的问题就好了。
反思:调试的时候输出调试的,然后答案更新的只按照一部分更新的,忘了改回去了。因为这个题没看题,是别人讲的题意,所以没看到多组数据,在这儿一直错= =。
//By BLADEVIL #include <cstdio> #include <cstring> #include <algorithm> #define maxn 200010 using namespace std; struct ww { int son[2]; int cnt; ww() { cnt=0; memset(son,0,sizeof son); } }t[maxn<<5]; int n,m,l,tot; int a[maxn],pre[maxn<<1],other[maxn<<1],last[maxn],que[maxn],dis[maxn],jump[maxn][20]; void connect(int x,int y) { pre[++l]=last[x]; last[x]=l; other[l]=y; } int lca(int x,int y) { if (dis[x]>dis[y]) swap(x,y); int dep=dis[y]-dis[x]; for (int i=0;i<=18;i++) if (dep&(1<<i)) y=jump[y][i]; if (x==y) return x; for (int i=18;i>=0;i--) if (jump[x][i]!=jump[y][i]) x=jump[x][i],y=jump[y][i]; return jump[x][0]; } void build(int &x,int dep) { if (!x) x=++tot; if (dep<0) return ; build(t[x].son[0],dep-1); build(t[x].son[1],dep-1); } void insert(int &x,int rot,int dep,int key) { if (!x) x=++tot; if (dep==-2) return ; if (key&(1<<dep)) { insert(t[x].son[1],t[rot].son[1],dep-1,key); t[x].son[0]=t[rot].son[0]; } else { insert(t[x].son[0],t[rot].son[0],dep-1,key); t[x].son[1]=t[rot].son[1]; } t[x].cnt+=t[rot].cnt+1; //printf("|%d %d\n",t[x].cnt,x); } int query(int rx,int lx,int key,int dep) { if (dep==-2) return 0; //printf("%d %d %d %d\n",t[rx].son[1],t[rx].son[0],t[t[rx].son[1]].cnt,t[t[rx].son[0]].cnt); int ans=0; if (key&(1<<dep)) { if (t[t[rx].son[0]].cnt-t[t[lx].son[0]].cnt) { ans=1<<dep; ans+=query(t[rx].son[0],t[lx].son[0],key,dep-1); } else ans=query(t[rx].son[1],t[lx].son[1],key,dep-1); } else { if (t[t[rx].son[1]].cnt-t[t[lx].son[1]].cnt) { ans=1<<dep; ans+=query(t[rx].son[1],t[lx].son[1],key,dep-1); } else ans=query(t[rx].son[0],t[lx].son[0],key,dep-1); } //printf("%d\n",ans); return ans; } void work() { memset(t,0,sizeof t); memset(last,0,sizeof last); memset(dis,0,sizeof dis); tot=n; l=0; for (int i=1;i<=n;i++) scanf("%d",&a[i]); for (int i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); connect(x,y); connect(y,x); } int head=0,tail=1; que[1]=1; dis[1]=1; while (head<tail) { int cur=que[++head]; for (int p=last[cur];p;p=pre[p]) { if (dis[other[p]]) continue; que[++tail]=other[p]; dis[other[p]]=dis[cur]+1; } } //for (int i=1;i<=n;i++) printf(i==n?"%d\n":"%d ",que[i]); jump[1][0]=++tot; for (int i=1;i<=n;i++) for (int p=last[que[i]];p;p=pre[p]) if (dis[other[p]]>dis[que[i]]) jump[other[p]][0]=que[i]; for (int i=1;i<=18;i++) for (int j=1;j<=n;j++){ int cur=que[j]; jump[cur][i]=jump[jump[cur][i-1]][i-1]; } build(jump[1][0],15); for (int i=1;i<=n;i++) insert(que[i],jump[que[i]][0],15,a[que[i]]); //for (int i=1;i<=tot;i++) printf("%d %d %d %d\n",i,t[i].son[0],t[i].son[1],t[i].cnt); //int x,y; scanf("%d%d",&x,&y); printf("%d\n",lca(x,y)); while (m--) { int x,y,z; scanf("%d%d%d",&x,&y,&z); int root=lca(x,y); int ans=0; ans=max(query(x,jump[root][0],z,15),query(y,jump[root][0],z,15)); //ans=query(y,jump[root][0],z,15); printf("%d\n",ans); } } int main() { while (scanf("%d%d",&n,&m)!=EOF) work(); return 0; }