[APIO2012]派遣 (平衡树启发式合并)

[APIO2012]派遣 (平衡树启发式合并)
题目大意:大概是这样的,一棵树n个点,每个点有点权val[i]和cost[i],给定一个m,对于每颗子树,计算出一个w值,w的计算方法为(val[i]*k),其中k为i子树下,最多能取出的使得cost的和小于等于m点的个数。
解题思路:假如,我们能将一颗子树的每个点,以cost为key,建成平衡树,那么计算答案想必还是比较简单的吧。但是我们不能给每颗子树建一棵平衡树,那么我们就从叶子节点开始计算,然后往上合并,合并两棵树时,将小的往大的里面一个个的插。
代码:
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define lowbit(x) (x&(-x))
#define ll long long
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define ls son[0][rt]
#define rs son[1][rt]
#define new_edge(a,b,c) edge[tot].t = b , edge[tot].v = c , edge[tot].next = head[a] , head[a] = tot ++
using namespace std;

const int maxn = 111111 ;
int son[2][maxn] , fa[maxn] , size[maxn] ;
ll val[maxn] , sum[maxn] ;
int pos[maxn] ;

int new_node ( int _val , int rt ) {
    if ( !rt ) return 0 ;
    sum[rt] = val[rt] = _val ;
    size[rt] = 1 ;
    fa[rt] = son[0][rt] = son[1][rt] = 0 ;
    return rt ;
}

struct Edge {
    int t , next , v ;
} edge[maxn<<1] ;
int head[maxn] , tot ;
ll m , num[maxn] ;

void push_up ( int rt ) {
    size[rt] = size[ls] + size[rs] + 1 ;
    sum[rt] = val[rt] + sum[ls] + sum[rs] ;
}

void rot ( int rt , int c ) {
    int y = fa[rt] , z = fa[y] ;
    son[!c][y] = son[c][rt] , fa[son[c][rt]] = y ;
    son[c][rt] = y , fa[y] = rt ;
    fa[rt] = z ;
    son[y==son[1][z]][z] = rt ;
    push_up ( y ) ;
}

void splay ( int rt ) {
    while ( fa[rt] ) {
        int y = fa[rt] , z = fa[y] ;
        if ( !fa[y] ) rot ( rt , rt == son[0][y] ) ;
        else {
            int c = ( rt == son[0][y] ) , d = ( y == son[0][z] ) ;
            if ( c == d ) rot ( y , c ) , rot ( rt , c ) ;
            else rot ( rt , c ) , rot ( rt , d ) ;
        }
    }
    push_up ( rt ) ;
}

void insert ( int rt , int y ) {
    if ( val[rt] <= val[y] ) {
        if ( !rs ) {
            rs = y , fa[y] = rt ;
            push_up ( rt ) ;
            return ;
        }
        insert ( rs , y ) ;
    }
    else {
        if ( !ls ) {
            ls = y , fa[y] = rt ;
            push_up ( rt ) ;
            return ;
        }
        insert ( ls , y ) ;
    }
    push_up ( rt ) ;
}

void print ( int rt ) {
    if ( !rt ) return ;
    printf ( "rt = %d , fa = %d , sum = %I64d\n" , rt , fa[rt] , sum[rt] ) ;
    printf ( "ls = %d , rs = %d , val = %I64d , size = %d\n" , ls , rs , val[rt] , size[rt] ) ;
    print ( ls ) ;
    print ( rs ) ;
}

void join ( int& x , int y ) {
    if ( son[0][y] ) join ( x , son[0][y] ) ;
    if ( son[1][y] ) join ( x , son[1][y] ) ;
    new_node ( val[y] , y ) ;
    insert ( x , y ) ;
    splay ( y ) ;
    x = y ;
}

ll ans = 0 ;

int cnt ( int rt , ll now , int k ) {
    if ( now + sum[ls] + val[rt] <= m ) {
        if ( !rs ) return k + size[rt] ;
        else return cnt ( rs , now + sum[ls] + val[rt] , k + size[ls] + 1 ) ;
    }
    else {
        if ( !ls ) return k ;
        else return cnt ( ls , now , k ) ;
    }
}

void dfs ( int u ) {
    int i ;
    int temp = u ;
    for ( i = head[u] ; i != -1 ; i = edge[i].next ) {
        int v = edge[i].t ;
        dfs ( v ) ;
        v = pos[v] ;
        if ( size[temp] > size[v] ) join ( temp , v ) ;
        else join ( v , temp ) , temp = v ;
    }
    int k = cnt ( temp , 0 , 0 ) ;
    ans = max ( ans , num[u] * k ) ;
    pos[u] = temp ;
}

int main() {
    int n , i , j , k ;
    while ( scanf ( "%d%d" , &n , &m ) != EOF ) {
        memset ( head , -1 , sizeof ( head ) ) ;
        ans = tot = 0 ;
        int rt ;
        for ( i = 1 ; i <= n ; i ++ ) {
            int a , b , c ;
            pos[i] = i ;
            scanf ( "%d%d%d" , &a , &b , &c ) ;
            if ( a == 0 ) rt = i ;
            new_edge ( a , i , 0 ) ; num[i] = c ;
            new_node ( b , i ) ;
        }
        dfs ( rt ) ;
        printf ( "%lld\n" , ans ) ;
    }
    return 0;
}


你可能感兴趣的:(平衡树)