传送门:【HDOJ】4718 The LCIS on the Tree
题目分析:昨晚调了半天。。最后还是在床上找到错误的地方,用手机A掉的# _ #
树上的LCIS,首先还是树链剖分,然后就是和线段树一样维护一个区间的左右端点以及区间左最长递增,右最长递增,区间内最长递增。因为是在树上,所以会有两个相反的方向,所以我们同时还要保存递减的最长序列。查询的时候还要保存之前左最长以及右最长的长度。
其他的么。。。就是需要足够的细心。。。。题目不难
代码如下:
#include <cstdio> #include <cstring> #include <algorithm> using namespace std ; typedef long long LL ; #define travel( e , H , u ) for ( Edge* e = H[u] ; e ; e = e -> next ) #define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i ) #define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i ) #define FOR( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i ) #define clr( a , x ) memset ( a , x , sizeof a ) #define cpy( a , x ) memcpy ( a , x , sizeof a ) #define ls ( o << 1 ) #define rs ( o << 1 | 1 ) #define lson ls , l , m #define rson rs , m + 1 , r #define rt o , l , r #define root 1 , 1 , n #define mid ( ( l + r ) >> 1 ) const int MAXN = 100005 ; const int MAXE = 100005 ; const int INF = 0x3f3f3f3f ; struct Edge { int v ; Edge* next ; } E[MAXE] , *H[MAXN] , *edge ; int Lmax[2][MAXN << 2] ; int Rmax[2][MAXN << 2] ; int maxv[2][MAXN << 2] ; int Lnum[MAXN << 2] ; int Rnum[MAXN << 2] ; int siz[MAXN] ; int pos[MAXN] ; int val[MAXN] ; int top[MAXN] ; int son[MAXN] ; int pre[MAXN] ; int dep[MAXN] ; int idx[MAXN] ; int tree_idx ; int n , q ; void clear () { edge = E ; tree_idx = 0 ; clr ( H , 0 ) ; siz[0] = 0 ; pre[1] = 0 ; dep[0] = 0 ; } void addedge ( int u , int v ) { edge -> v = v ; edge -> next = H[u] ; H[u] = edge ++ ; } void dfs ( int u ) { siz[u] = 1 ; son[u] = 0 ; travel ( e , H , u ) { int v = e -> v ; pre[v] = u ; dep[v] = dep[u] + 1 ; dfs ( v ) ; siz[u] += siz[v] ; if ( siz[v] > siz[son[u]] ) son[u] = v ; } } void rewrite ( int u , int top_element ) { top[u] = top_element ; pos[u] = ++ tree_idx ; idx[tree_idx] = u ; if ( son[u] ) rewrite ( son[u] , top_element ) ; travel ( e , H , u ) { int v = e -> v ; if ( v != pre[u] && v != son[u] ) { rewrite ( v , v ) ; } } } void pushup ( int o , int l , int r ) { int m = mid ; Lnum[o] = Lnum[ls] ; Rnum[o] = Rnum[rs] ; Lmax[0][o] = Lmax[0][ls] ; Lmax[1][o] = Lmax[1][ls] ; Rmax[0][o] = Rmax[0][rs] ; Rmax[1][o] = Rmax[1][rs] ; maxv[0][o] = max ( maxv[0][ls] , maxv[0][rs] ) ; maxv[1][o] = max ( maxv[1][ls] , maxv[1][rs] ) ; if ( Rnum[ls] < Lnum[rs] ) { if ( Lmax[0][o] == m - l + 1 ) Lmax[0][o] += Lmax[0][rs] ; if ( Rmax[0][o] == r - m ) Rmax[0][o] += Rmax[0][ls] ; maxv[0][o] = max ( maxv[0][o] , Rmax[0][ls] + Lmax[0][rs] ) ; } if ( Rnum[ls] > Lnum[rs] ) { if ( Lmax[1][o] == m - l + 1 ) Lmax[1][o] += Lmax[1][rs] ; if ( Rmax[1][o] == r - m ) Rmax[1][o] += Rmax[1][ls] ; maxv[1][o] = max ( maxv[1][o] , Rmax[1][ls] + Lmax[1][rs] ) ; } } void build ( int o , int l , int r ) { if ( l == r ) { Lnum[o] = Rnum[o] = val[idx[l]] ; Lmax[0][o] = maxv[0][o] = Rmax[0][o] = 1 ; Lmax[1][o] = maxv[1][o] = Rmax[1][o] = 1 ; return ; } int m = mid ; build ( lson ) , build ( rson ) ; pushup ( rt ) ; } int query_Lmax ( int L , int R , int v , int o , int l , int r ) { if ( L <= l && r <= R ) return Lmax[v][o] ; int m = mid ; if ( R <= m ) return query_Lmax ( L , R , v , lson ) ; if ( m < L ) return query_Lmax ( L , R , v , rson ) ; if ( ( !v && Rnum[ls] < Lnum[rs] || v && Rnum[ls] > Lnum[rs] ) && m - L + 1 <= Rmax[v][ls] ) { return m - L + 1 + min ( Lmax[v][rs] , R - m ) ; } else return query_Lmax ( L , R , v , lson ) ; } int query_Rmax ( int L , int R , int v , int o , int l , int r ) { if ( L <= l && r <= R ) return Rmax[v][o] ; int m = mid ; if ( R <= m ) return query_Rmax ( L , R , v , lson ) ; if ( m < L ) return query_Rmax ( L , R , v , rson ) ; if ( ( !v && Rnum[ls] < Lnum[rs] || v && Rnum[ls] > Lnum[rs] ) && R - m <= Lmax[v][rs] ) { return min ( m - L + 1 , Rmax[v][ls] ) + R - m ; } else return query_Rmax ( L , R , v , rson ) ; } int query ( int L , int R , bool v , int o , int l , int r ) { if ( L <= l && r <= R ) return maxv[v][o] ; int m = mid ; if ( R <= m ) return query ( L , R , v , lson ) ; if ( m < L ) return query ( L , R , v , rson ) ; int res = max ( query ( L , R , v , lson ) , query ( L , R , v , rson ) ) ; if ( !v && Rnum[ls] < Lnum[rs] || v && Rnum[ls] > Lnum[rs] ) { res = max ( res , min ( m - L + 1 , Rmax[v][ls] ) + min ( R - m , Lmax[v][rs] ) ) ; } return res ; } int Query ( int x , int y ) { int res = 0 ; int Xnum = 0 , Ynum = 0 ; int Xmax = 0 , Ymax = 0 ; while ( top[x] != top[y] ) { if ( dep[top[x]] > dep[top[y]] ) { int L = pos[top[x]] , R = pos[x] ; int tmp = query ( L , R , 1 , root ) ; res = max ( res , tmp ) ; if ( Xnum && val[x] > val[Xnum] ) res = max ( res , query_Rmax ( L , R , 1 , root ) + Xmax ) ; if ( tmp == R - L + 1 && Xnum && val[x] > val[Xnum] ) Xmax += tmp ; else Xmax = query_Lmax ( L , R , 1 , root ) ; res = max ( res , Xmax ) ; Xnum = top[x] ; x = pre[top[x]] ; } else { int L = pos[top[y]] , R = pos[y] ; int tmp = query ( L , R , 0 , root ) ; res = max ( res , tmp ) ; if ( Ynum && val[y] < val[Ynum] ) res = max ( res , query_Rmax ( L , R , 0 , root ) + Ymax ) ; if ( tmp == R - L + 1 && Ynum && val[y] < val[Ynum] ) Ymax += tmp ; else Ymax = query_Lmax ( L , R , 0 , root ) ; res = max ( res , Ymax ) ; Ynum = top[y] ; y = pre[top[y]] ; } } if ( dep[x] > dep[y] ) { int L = pos[y] , R = pos[x] ; int tmp = query ( L , R , 1 , root ) ; res = max ( res , tmp ) ; if ( tmp == R - L + 1 ) { if ( Xnum && val[x] > val[Xnum] ) tmp += Xmax ; if ( Ynum && val[y] < val[Ynum] ) tmp += Ymax ; res = max ( res , tmp ) ; } else { if ( Xnum && val[x] > val[Xnum] ) res = max ( res , Xmax + query_Rmax ( L , R , 1 , root ) ) ; if ( Ynum && val[y] < val[Ynum] ) res = max ( res , Ymax + query_Lmax ( L , R , 1 , root ) ) ; } } else { int L = pos[x] , R = pos[y] ; int tmp = query ( L , R , 0 , root ) ; res = max ( res , tmp ) ; if ( tmp == R - L + 1 ) { if ( Xnum && val[x] > val[Xnum] ) tmp += Xmax ; if ( Ynum && val[y] < val[Ynum] ) tmp += Ymax ; res = max ( res , tmp ) ; } else { if ( Xnum && val[x] > val[Xnum] ) res = max ( res , Xmax + query_Lmax ( L , R , 0 , root ) ) ; if ( Ynum && val[y] < val[Ynum] ) res = max ( res , Ymax + query_Rmax ( L , R , 0 , root ) ) ; } } return res ; } void solve () { int x , y ; clear () ; scanf ( "%d" , &n ) ; FOR ( i , 1 , n ) scanf ( "%d" , &val[i] ) ; FOR ( i , 2 , n ) { scanf ( "%d" , &x ) ; addedge ( x , i ) ; } dfs ( 1 ) ; rewrite ( 1 , 1 ) ; build ( root ) ; scanf ( "%d" , &q ) ; while ( q -- ) { scanf ( "%d%d" , &x , &y ) ; printf ( "%d\n" , Query ( x , y ) ) ; } } int main () { int T , cas = 0 ; scanf ( "%d" , &T ) ; while ( T -- ) { printf ( "Case #%d:\n" , ++ cas ) ; solve () ; if ( T ) printf ( "\n" ) ; } return 0 ; }