传送门:【Tsinsen】A1486. 树
题目分析:点分治里面套字典树。暴力更新,得到从重心出发的一个子树的所有路径,在字典树中能反着走就反着走,返回能构成的最大异或值。查询完后将这个子树的所有路径插入到字典树中。
因为每一层N个点,一共logN层,每个点在字典树中是31个点,所以最多即NlogN*31个节点。
这次总算是空间复杂度没考虑错了……
代码如下:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std ; typedef long long LL ; #define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i ) #define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i ) #define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i ) #define clr( a , x ) memset ( a , x , sizeof a ) #define ls ( o << 1 ) #define rs ( o << 1 | 1 ) #define lson l , m #define rson m + 1 , r #define mid ( ( l + r ) >> 1 ) const int MAXN = 100005 ; const int MAXE = 200005 ; struct Dictionary_Tree { int next[MAXN * 32][2] ; int pre[MAXN * 32] ; int maxv[MAXN * 32] ; int p ; int root ; int newnode () { next[p][0] = next[p][1] = 0 ; maxv[p] = 0 ; return p ++ ; } void init () { p = 0 ; root = newnode () ; } void insert ( int num , int v ) { int now = root ; rev ( i , 30 , 0 ) { int x = ( num >> i ) & 1 ; if ( !next[now][x] ) next[now][x] = newnode () ; pre[next[now][x]] = now ; now = next[now][x] ; } maxv[now] = max ( maxv[now] , v ) ; //printf ( "%d\n" , maxv[now] ) ; while ( now ) { maxv[now] = max ( maxv[now] , max ( maxv[next[now][0]] , maxv[next[now][1]] ) ) ; now = pre[now] ; } } int find ( int num , int v ) { int now = root , res = 0 ; rev ( i , 30 , 0 ) { int x = ( num >> i ) & 1 ; if ( next[now][x ^ 1] && maxv[next[now][x ^ 1]] >= v ) { now = next[now][x ^ 1] ; res |= ( 1 << i ) ; } else if ( next[now][x] && maxv[next[now][x]] >= v ) { now = next[now][x] ; } else return -1 ; } return res ; } } ; struct Edge { int v , n ; Edge () {} Edge ( int v , int n ) : v ( v ) , n ( n ) {} } ; struct Node { int d , c ; Node () {} Node ( int d , int c ) : d ( d ) , c ( c ) {} } ; Dictionary_Tree T ; Node S[MAXN] ; Edge E[MAXE] ; int H[MAXN] , cntE ; int love[MAXN] ; int val[MAXN] ; int n , k ; int Q[MAXN] , head , tail ; int siz[MAXN] ; int pre[MAXN] ; int vis[MAXN] , Time ; int dis[MAXN] ; int num[MAXN] ; int top ; int ans ; void clear () { ans = -1 ; ++ Time ; cntE = 0 ; clr ( H , -1 ) ; } void addedge ( int u , int v ) { E[cntE] = Edge ( v , H[u] ) ; H[u] = cntE ++ ; } int get_root ( int src ) { head = tail = 0 ; Q[tail ++] = src ; pre[src] = 0 ; for ( ; head < tail ; ++ head ) { int u = Q[head] ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( vis[v] == Time || v == pre[u] ) continue ; pre[v] = u ; Q[tail ++] = v ; } } int root = src , root_size = tail , tot_size = tail ; for ( -- head ; head >= 0 ; -- head ) { int u = Q[head] ; int cnt = 0 ; siz[u] = 1 ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( vis[v] == Time || v == pre[u] ) continue ; if ( siz[v] > cnt ) cnt = siz[v] ; siz[u] += siz[v] ; } cnt = max ( cnt , tot_size - siz[u] ) ; if ( cnt < root_size ) { root_size = cnt ; root = u ; } } return root ; } void get_dis ( int src , int init_val , int init_cnt ) { head = tail = 0 ; Q[tail ++] = src ; pre[src] = 0 ; dis[src] = init_val ; num[src] = init_cnt ; top = 0 ; for ( ; head < tail ; ++ head ) { int u = Q[head] ; S[top ++] = Node ( dis[u] , num[u] ) ; for ( int i = H[u] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( vis[v] == Time || v == pre[u] ) continue ; pre[v] = u ; dis[v] = dis[u] ^ val[v] ; num[v] = num[u] + love[v] ; Q[tail ++] = v ; } } } void dfs ( int u ) { int root = get_root ( u ) ; vis[root] = Time ; if ( love[root] >= k ) ans = max ( ans , val[root] ) ; T.init () ; for ( int i = H[root] ; ~i ; i = E[i].n ) { int v = E[i].v ; if ( vis[v] == Time ) continue ; get_dis ( v , val[v] , love[v] ) ; rep ( j , 0 , top ) { int tmp = T.find ( S[j].d ^ val[root] , k - S[j].c - love[root] ) ; ans = max ( ans , tmp ) ; if ( S[j].c + love[root] >= k ) ans = max ( ans , S[j].d ^ val[root] ) ; if ( S[j].c >= k && S[j].d > ans ) ans = S[j].d ; } rep ( j , 0 , top ) T.insert ( S[j].d , S[j].c ) ; } for ( int i = H[root] ; ~i ; i = E[i].n ) if ( vis[E[i].v] != Time ) dfs ( E[i].v ) ; } void solve () { int u , v ; clear () ; For ( i , 1 , n ) scanf ( "%d" , &love[i] ) ; For ( i , 1 , n ) scanf ( "%d" , &val[i] ) ; rep ( i , 1 , n ) { scanf ( "%d%d" , &u , &v ) ; addedge ( u , v ) ; addedge ( v , u ) ; } dfs ( 1 ) ; printf ( "%d\n" , ans ) ; } int main () { Time = 0 ; clr ( vis , 0 ) ; while ( ~scanf ( "%d%d" , &n , &k ) ) solve () ; return 0 ; }