给定一棵 n 个点的树,每条边有一种颜色(共 m 种颜色,第 i 种颜色有权值 $c_i$)。对于一条路径,把它经过的边的颜色依次写出,得到一个颜色序列,再将该序列划分成若干极长的同色颜色段;定义路径的权值为各颜色段所对应颜色的权值之和(当所有 $c_i = 1$ 时即为颜色段数)。
求树中所有边数在 $[l, r]$ 之间的路径的最大权值。
Data Constraint
$n \le 2 \times 10^5$
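考虑点分治(点剖):对于当前的分治重心,将它的各个儿子按照连向重心的那条边的颜色排序,使同色的儿子相邻。
然后维护两棵以深度为下标、存储路径权值最大值的线段树即可:一棵存放与当前儿子颜色不同的已处理子树,另一棵存放与当前儿子同色的已处理子树(拼接时该颜色段被计算了两次,需减去一次该颜色的权值)。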
时间复杂度: $O(n \log^2 n)$
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
using namespace std ;
// Capacity of the per-vertex arrays: the maximum n (2e5) plus slack.
// This is a typed constant rather than the original `#define N 200000 + 10`:
// the unparenthesised macro made `4*N` expand to `4*200000 + 10` instead of
// `4*(200000 + 10)`, so the scaled array bounds did not include the slack.
const int N = 200000 + 10 ;
// Sentinel meaning "no value". Real path sums must stay well above -inf, so
// code must never add two sentinel-sized operands together.
const int inf = 2e9 ;
// Two max segment trees over depths 1..n (4*N nodes each).
// T[0]: best path sum per depth among earlier children of the centroid whose
//       edge colour differs from the current child's.
// T[1]: the same for earlier children in the current colour group.
int T[2][4*N] ;
// vis[v]: v has already been used as a centroid.
// Clear[t][v]: lazy flag; the children of node v in tree t still hold stale
//              data and must be reset to -inf before they are visited.
bool vis[N] , Clear[2][4*N] ;
// Adjacency list. Every undirected edge is stored twice, hence 2*N entries.
// Node = edge target, Next = next edge of the same source, Col = edge colour,
// Head = first edge of each vertex, tot = number of stored edges.
int Node[2*N] , Next[2*N] , Col[2*N] , Head[N] , tot ;
// C[c]: weight of colour c (indexed 1..m; assumes m < N).
// Size / Maxs: subtree size and largest part, used by the centroid search.
// Son: edge ids from the current centroid to its children; Son[0] is the count.
// MaxDeep / flag: best sum per depth in the current colour group; flag[d] != tag
//                 marks MaxDeep[d] as stale (lazy reset).
// D[1..Cnt]: (depth, sum) pairs collected by the last DFS.
int C[N] , Size[N] , Maxs[N] , Son[N] , MaxDeep[N] , flag[N] , D[N][2] ;
// Input: vertex count, colour count, and the allowed path-length range [L, R].
int n , m , L , R ;
// Root / All / Minv: centroid-search state. Cnt: number of D entries.
// MaxDep: deepest depth seen in the current colour group. tag: timestamp for flag.
int Root , All , Minv , Cnt , MaxDep , tag ;
// ans: best path weight found so far.
// ret / nowv: output and input of the segment-tree query Search().
int ans = -inf , ret , nowv ;
// Reads the next integer from stdin.
// Any non-digit characters before the number are skipped; a '-' anywhere
// among those skipped characters makes the result negative.
// Returns 0 if the end of input is reached before any digit. The original
// looped forever in that case, because EOF never passes the digit test.
int Read() {
    int ret = 0 , sign = 1 ;
    int ch = getchar() ;   // int, not char, so EOF stays distinguishable from byte 0xFF
    while ( ch < '0' || ch > '9' ) {
        if ( ch == EOF ) return 0 ;
        if ( ch == '-' ) sign = -1 ;
        ch = getchar() ;
    }
    while ( ch >= '0' && ch <= '9' ) {
        ret = ret * 10 + ch - '0' ;
        ch = getchar() ;
    }
    return sign * ret ;
}
// Adds the directed edge u -> v with colour w. The new edge is prepended to
// u's adjacency list, so Head[u] always refers to the most recently added edge.
void link( int u , int v , int w ) {
    const int id = ++tot ;
    Node[id] = v ;
    Col[id] = w ;
    Next[id] = Head[u] ;
    Head[u] = id ;
}
// Strict weak ordering on edge ids by the colour of the referenced edge.
// Solve() uses it to make equal-colour children of a centroid consecutive.
bool cmp( int a , int b ) {
    const int colourA = Col[a] ;
    const int colourB = Col[b] ;
    return colourB > colourA ;
}
// Recursively fills Size[] (subtree size) and Maxs[] (the larger of 1 and the
// largest child subtree size) for x's subtree, treating F as x's parent.
// Vertices already removed as centroids (vis[] set) are not entered.
void GetSize( int x , int F ) {
    Size[x] = 1 ;
    Maxs[x] = 1 ;
    for (int e = Head[x] ; e != 0 ; e = Next[e] ) {
        const int y = Node[e] ;
        if ( y != F && !vis[y] ) {
            GetSize( y , x ) ;
            Size[x] += Size[y] ;
            Maxs[x] = ( Size[y] > Maxs[x] ) ? Size[y] : Maxs[x] ;
        }
    }
}
// Searches x's subtree (parent F) for the centroid of the component rooted at
// All. Size[]/Maxs[] must have been filled by GetSize(All, 0) beforehand.
// On return, Root is the vertex that minimises the largest remaining part when
// it is removed, and Minv is that minimum.
void GetRoot( int x , int F ) {
    // The largest part left after removing x is either its largest child
    // subtree (Maxs[x]) or everything outside x's subtree (Size[All] - Size[x]).
    // The original used Size[All] - Maxs[x], which overestimates the outer part,
    // can pick a non-centroid, and so loses the O(log n) recursion depth.
    Maxs[x] = max( Maxs[x] , Size[All] - Size[x] ) ;
    if ( Maxs[x] < Minv ) Root = x , Minv = Maxs[x] ;
    for (int p = Head[x] ; p ; p = Next[p] ) {
        if ( Node[p] == F || vis[Node[p]] ) continue ;
        GetRoot( Node[p] , x ) ;
    }
}
// Walks the part of the tree below x, where x was reached from F by an edge of
// colour `last`. `deep` is x's edge distance from the current centroid and
// `sum` is the weight of the centroid-to-x path. For every visited vertex it:
//   * keeps the best sum per depth for the current colour group in MaxDeep;
//     entries are reset lazily via the flag/tag timestamp, and MaxDep tracks
//     the deepest depth seen;
//   * updates ans directly when the single-arm path length lies in [L, R];
//   * appends (deep, sum) to D[1..Cnt].
// A child edge adds its colour's weight only when it starts a new colour
// segment, i.e. when its colour differs from `last`.
void DFS( int x , int F , int last , int sum , int deep ) {
    const bool firstVisitAtDepth = ( flag[deep] != tag ) ;
    if ( firstVisitAtDepth ) {
        flag[deep] = tag ;
        MaxDeep[deep] = -inf ;
        if ( deep > MaxDep ) MaxDep = deep ;
    }
    if ( sum > MaxDeep[deep] ) MaxDeep[deep] = sum ;
    if ( L <= deep && deep <= R && sum > ans ) ans = sum ;
    ++Cnt ;
    D[Cnt][0] = deep ;
    D[Cnt][1] = sum ;
    for (int e = Head[x] ; e ; e = Next[e] ) {
        const int y = Node[e] ;
        if ( y == F || vis[y] ) continue ;
        const int colour = Col[e] ;
        const int gain = ( colour == last ) ? 0 : C[colour] ;
        DFS( y , x , colour , sum + gain , deep + 1 ) ;
    }
}
// Point update on max segment tree `type`: raises the value at position x to
// at least val. Node v covers [l, r].
// If Clear[type][v] is set, v's children hold stale data; they are reset to
// -inf and marked for clearing themselves before the walk descends.
void Insert( int type , int v , int l , int r , int x , int val ) {
    if ( l == r ) {
        if ( val > T[type][v] ) T[type][v] = val ;
        return ;
    }
    const int lc = v << 1 ;
    const int rc = lc | 1 ;
    if ( Clear[type][v] ) {
        Clear[type][v] = 0 ;
        T[type][lc] = -inf ;
        T[type][rc] = -inf ;
        Clear[type][lc] = 1 ;
        Clear[type][rc] = 1 ;
    }
    const int mid = (l + r) >> 1 ;
    if ( x > mid ) Insert( type , rc , mid + 1 , r , x , val ) ;
    else Insert( type , lc , l , mid , x , val ) ;
    T[type][v] = ( T[type][lc] > T[type][rc] ) ? T[type][lc] : T[type][rc] ;
}
// Range-max query on segment tree `type` over [x, y]; node v covers [l, r] and
// the caller guarantees l <= x <= y <= r. The best value found is accumulated
// into the global `ret`, which the caller initialises to -inf.
// A branch is pruned when its maximum cannot beat `ret`, or when
// nowv + (its maximum) cannot beat the global answer `ans`.
// Applies the same lazy clearing as Insert on the way down.
void Search( int type , int v , int l , int r , int x , int y ) {
    // Test T <= ret first. ret starts at -inf, so an empty (-inf) subtree is
    // rejected before nowv + T is evaluated. The original evaluated the sum
    // first, which can overflow int when nowv is very negative and T == -inf.
    if ( T[type][v] <= ret || nowv + T[type][v] <= ans ) return ;
    if ( l == x && r == y ) {
        ret = max( ret , T[type][v] ) ;
        return ;
    }
    if ( Clear[type][v] ) {
        // Push the lazy reset down to the children before reading them.
        T[type][v+v] = -inf ;
        T[type][v+v+1] = -inf ;
        Clear[type][v+v] = Clear[type][v+v+1] = 1 ;
        Clear[type][v] = 0 ;
    }
    int mid = (l + r) >> 1 ;
    if ( y <= mid ) Search( type , v + v , l , mid , x , y ) ;
    else if ( x > mid ) Search( type , v + v + 1 , mid + 1 , r , x , y ) ;
    else {
        Search( type , v + v , l , mid , x , mid ) ;
        Search( type , v + v + 1 , mid + 1 , r , mid + 1 , y ) ;
    }
    T[type][v] = max( T[type][v+v] , T[type][v+v+1] ) ;
}
// Centroid-decomposition driver for the component containing x.
// It finds the component's centroid, combines every pair of centroid arms
// whose total length lies in [L, R], then recurses into the remaining
// components. Single-arm paths are handled inside DFS.
//
// The centroid's child edges are sorted by colour so that equal-colour
// children are consecutive. Two segment trees indexed by depth are used:
//   T[0]: per-depth best sums of earlier children with a different colour.
//         It is filled from MaxDeep whenever a colour group ends.
//   T[1]: per-depth best sums of earlier children in the same colour group.
//         Joining two such arms counts the shared first segment twice, so
//         C[colour] is subtracted once.
void Solve( int x ) {
    Root = All = x , Minv = inf ;
    GetSize( x , 0 ) ;
    GetRoot( x , 0 ) ;
    vis[Root] = 1 ;
    Son[0] = 0 ;
    for (int p = Head[Root] ; p ; p = Next[p] ) {
        if ( vis[Node[p]] ) continue ;
        Son[++Son[0]] = p ;
    }
    sort( Son + 1 , Son + Son[0] + 1 , cmp ) ;
    Clear[0][1] = 1 ;
    T[0][1] = -inf ;
    MaxDep = 0 ;
    for (int i = 1 ; i <= Son[0] ; i ++ ) {
        int p = Son[i] ;
        int colour = Col[p] ;
        bool sameAsPrev = ( i > 1 && colour == Col[Son[i-1]] ) ;
        if ( !sameAsPrev ) {
            // New colour group: move the previous group's per-depth maxima
            // into T[0] and reset T[1] (lazily, via the root's Clear flag).
            Clear[1][1] = 1 , tag ++ ;
            T[1][1] = -inf ;
            for (int k = 1 ; k <= MaxDep ; k ++ ) {
                Insert( 0 , 1 , 1 , n , k , MaxDeep[k] ) ;
            }
            MaxDep = 0 ;
        }
        Cnt = 0 ;
        DFS( Node[p] , Root , colour , C[colour] , 1 ) ;
        for (int k = 1 ; k <= Cnt ; k ++ ) {
            int deep = D[k][0] , val = D[k][1] ;
            // Skip when there is no earlier child, or no room left for a partner arm.
            if ( i == 1 || R <= deep ) continue ;
            // The partner arm's depth must lie in [L - deep, R - deep] and
            // within the trees' domain [1, n].
            int lo = max( L - deep , 1 ) ;
            int hi = R - deep ;
            if ( hi > n ) hi = n ;
            if ( lo > hi ) continue ;
            // Each `> -inf` guard skips arithmetic on the sentinel. Adding -inf
            // to a very negative val overflows int and could yield a bogus answer.
            if ( T[0][1] > -inf && val + T[0][1] > ans ) {
                ret = -inf ;
                nowv = val ;
                Search( 0 , 1 , 1 , n , lo , hi ) ;
                if ( ret > -inf ) ans = max( ans , val + ret ) ;
            }
            if ( sameAsPrev && T[1][1] > -inf && val + T[1][1] - C[colour] > ans ) {
                ret = -inf ;
                nowv = val - C[colour] ;
                Search( 1 , 1 , 1 , n , lo , hi ) ;
                if ( ret > -inf ) ans = max( ans , val + ret - C[colour] ) ;
            }
        }
        // Only after querying, make this child visible to later same-colour children.
        for (int k = 1 ; k <= Cnt ; k ++ ) Insert( 1 , 1 , 1 , n , D[k][0] , D[k][1] ) ;
    }
    for (int p = Head[Root] ; p ; p = Next[p] ) {
        if ( vis[Node[p]] ) continue ;
        Solve( Node[p] ) ;
    }
}
// Program entry point.
// Reads n, m, L, R, then the m colour weights, then the n - 1 coloured edges
// from journey.in. It runs the centroid decomposition starting from vertex
// (n + 1) / 2 and writes the best path weight to journey.out.
int main() {
    freopen( "journey.in" , "r" , stdin ) ;
    freopen( "journey.out" , "w" , stdout ) ;
    scanf( "%d%d%d%d" , &n , &m , &L , &R ) ;
    for (int colour = 1 ; colour <= m ; ++colour ) C[colour] = Read() ;
    for (int edge = 1 ; edge < n ; ++edge ) {
        const int u = Read() ;
        const int v = Read() ;
        const int w = Read() ;
        // Undirected edge: store both directions.
        link( u , v , w ) ;
        link( v , u , w ) ;
    }
    const int start = (n + 1) / 2 ;
    Solve( start ) ;
    printf( "%d\n" , ans ) ;
    return 0 ;
}
以上.