hdu 2871 Memory Control(伸展树splay tree)

hdu 2871 Memory Control

题意:对一段长度为 n 的内存(单元编号 1..n)进行四种操作:New x,占据最左边的连续 x 个空闲单元;Free x,释放包含单元 x 的那块已占用区间;Get x,输出从左往右第 x 块已占用区间的起点;Reset,清空整个内存。

解题思路:这个题就是一个区间合并问题,以前用线段树写过,这次拿来练练 splay。需要维护的是区间内最长的连续空闲段;为了维护这个最值,还需要两个辅助值:区间左端开始的最长连续空闲段和右端结束的最长连续空闲段。更新(push_up)时要仔细,其他都是 splay 的常规操作。已占据的区间按地址有序地存放在一个 vector 里,查找和插入位置都用二分。清空整个区间时直接用 update 打标记,不要重建一棵树。

(看似简单,我调了一天啊。。起初是push_up写不好,调很久没弄出来,后来过了样例,一直TLE。。还以为是效率不够高,第二天才发现是内存池的tot没清零)

 

#pragma comment(linker, "/STACK:1024000000,1024000000")

#include<stdio.h>

#include<string.h>

#include<algorithm>

#include<vector>

using namespace std ;



// Currently allocated blocks as (first unit, last unit) pairs, kept sorted by
// address (main inserts at bin2(last) and looks blocks up with bin()).
vector< pair<int,int> > vec ;

// Capacity of the node pool; build() allocates n + 2 nodes, so this assumes n + 2 <= maxn.
const int maxn = 55555 ;

// tot: next unused node slot in the pool (reset to 0 per test case in main)
// n: number of memory units; m: number of commands in the current test case
int tot , n , m ;

// Per-node aggregates, indexed by pool slot:
//   lm / rm / mm : longest free run starting at the subtree's left edge /
//                  ending at its right edge / anywhere inside the subtree
//   col          : lazy "assign whole subtree" tag (1 = free, 0 = occupied, -1 = none)
//   val          : state of the node's own unit (1 = free, 0 = occupied or sentinel)
int lm[maxn] , rm[maxn] , mm[maxn] , col[maxn] , val[maxn] ;

// son[0] / son[1]: left / right child (-1 = none); fa: parent (-1 = none);
// size: number of nodes in the subtree; pos[i]: node holding in-order rank i
// (i.e. unit i - 1), filled once right after build().
// NOTE(review): with `using namespace std` under C++17, the unqualified global
// `size` can be ambiguous with std::size; renaming it (e.g. to `sz`) avoids that.
int son[2][maxn] , fa[maxn] , size[maxn] , pos[maxn] ;



// Allocates pool slot `tot` for the subtree that will cover units [l, r]
// (its own unit is mid = (l + r) / 2) and initialises its aggregates as if the
// whole range were free, except for the sentinel units 0 and n + 1, which are
// always occupied. build() adds the children's sizes afterwards.
void new_node ( int l , int r ) {
    const int len = r - l + 1 ;
    const bool has_left_sentinel = ( l == 0 ) ;
    const bool has_right_sentinel = ( r == n + 1 ) ;
    const int mid = ( l + r ) >> 1 ;

    size[tot] = 1 ;
    son[0][tot] = son[1][tot] = fa[tot] = -1 ;

    // A range containing a sentinel is not uniformly free, so it gets no lazy tag.
    col[tot] = ( has_left_sentinel || has_right_sentinel ) ? -1 : 1 ;
    // The node's own unit is occupied only if it is a sentinel.
    val[tot] = ( mid == 0 || mid == n + 1 ) ? 0 : 1 ;

    // Each field is computed directly so that the root range [0, n + 1], which
    // holds both sentinels, gets lm = rm = 0 (previously lm ended up as -1).
    lm[tot] = has_left_sentinel ? 0 : len - ( has_right_sentinel ? 1 : 0 ) ;
    rm[tot] = has_right_sentinel ? 0 : len - ( has_left_sentinel ? 1 : 0 ) ;
    mm[tot] = len - ( has_left_sentinel ? 1 : 0 ) - ( has_right_sentinel ? 1 : 0 ) ;

    tot ++ ;
}



void push_down ( int rt ) {

    if ( col[rt] != -1 ) {

        int ls = son[0][rt] , rs = son[1][rt] ;

        val[ls] = val[rs] = col[rt] ;

        if ( ls != -1 ) {

            col[ls] = col[rt] ;

            lm[ls] = rm[ls] = mm[ls] = size[ls] * col[rt] ;

        }

        if ( rs != -1 ) {

            col[rs] = col[rt] ;

            lm[rs] = rm[rs] = mm[rs] = size[rs] * col[rt] ;

        }

        col[rt] = -1 ;

    }

}



void push_up ( int rt ) {

    size[rt] =1 ;

    mm[rt] = lm[rt] = rm[rt] = val[rt] ;

    int ls = son[0][rt] , rs = son[1][rt] ;

    if ( val[rt] ) {

        if ( ls == -1 ) lm[rt] = 1 + ( rs == -1 ? 0 : lm[rs] ) ;

        if ( rs == -1 ) rm[rt] = 1 + ( ls == -1 ? 0 : rm[ls] ) ;

    }

    if ( ls != -1 ) {

        lm[rt] = lm[ls] ;

        mm[rt] = max ( mm[ls] , rm[ls] + val[rt] ) ;

        if ( lm[ls] == size[ls] ) lm[rt] += val[rt] , mm[rt] += val[rt] ;

        size[rt] += size[ls] ;

    }

    if ( rs != -1 ) {

        rm[rt] = rm[rs] ;

        mm[rt] = max ( mm[rt] , max ( mm[rs] , lm[rs] + val[rt] ) ) ;

        if ( rm[rs] == size[rs] ) rm[rt] += val[rt] , mm[rt] = max ( mm[rt] , rm[rs] + val[rt] ) ;

        size[rt] += size[rs] ;

    }

    if ( ls != -1 && rs != -1 && val[rt] ) {

        mm[rt] = max ( mm[rt] , rm[ls] + lm[rs] + 1 ) ;

        if ( lm[ls] == size[ls] ) lm[rt] = lm[ls] + 1 + lm[rs] ;

        if ( rm[rs] == size[rs] ) rm[rt] = rm[rs] + 1 + rm[ls] ;

    }

}



int build ( int l , int r ) {

    if ( l > r ) return -1 ;

    int mid = ( l + r ) >> 1 ;

    new_node ( l , r ) ;

    int temp = tot - 1 ;

    son[0][temp] = build ( l , mid - 1 ) ;

    if ( son[0][temp] != -1 ) fa[son[0][temp]] = temp , size[temp] += size[son[0][temp]] ;

    son[1][temp] = build ( mid + 1 , r ) ;

    if ( son[1][temp] != -1 ) fa[son[1][temp]] = temp , size[temp] += size[son[1][temp]] ;

    return temp ;

}



// Single rotation lifting node rt above its parent y.
//   c == 1 : rt is y's left child  (right rotation; rt's right subtree moves to y's left)
//   c == 0 : rt is y's right child (left rotation;  rt's left subtree moves to y's right)
// Pending lazy tags on y and rt are pushed before the links change. Only y is
// re-aggregated here; rt's aggregate is refreshed by splay() once it stops moving.
void rot ( int rt , int c ) {
    int y = fa[rt] , z = fa[y] ;
    push_down ( y ) , push_down ( rt ) ;
    // The subtree on rt's inner side becomes y's child on the side rt vacates.
    son[!c][y] = son[c][rt] ;
    if ( son[c][rt] != -1 ) fa[son[c][rt]] = y ;
    // rt takes y's place under the grandparent z (if any).
    fa[rt] = z ;
    if ( z != -1 ) {
        if ( y == son[0][z] ) son[0][z] = rt ;
        else son[1][z] = rt ;
    }
    // y hangs below rt on the inner side.
    son[c][rt] = y , fa[y] = rt ;
    push_up ( y ) ;
}



// Rotates node rt upward until its parent is `to` (pass -1 to make rt the root).
// Uses the usual zig / zig-zig / zig-zag cases; rot(x, 1) means x is a left child.
// NOTE(review): push_down runs bottom-up (rt here, then parents inside rot).
// This appears to rely on col being an overwrite-style "assign" tag, so a later
// push from an ancestor simply replaces whatever was pushed earlier. Confirm
// before adding additive tags.
void splay ( int rt , int to ) {
    push_down ( rt ) ;
    while ( fa[rt] != to ) {
        // zig: the parent is directly below `to`, so one rotation finishes.
        if ( fa[fa[rt]] == to ) rot ( rt , rt == son[0][fa[rt]] ) ;
        else {
            int y = fa[rt] , z = fa[y] ;
            if ( rt == son[0][y] ) {
                // zig-zig (left-left): rotate the parent first, then rt.
                if ( y == son[0][z] ) rot ( y , 1 ) , rot ( rt , 1 ) ;
                // zig-zag (left-right): rotate rt twice.
                else rot ( rt , 1 ) , rot ( rt , 0 ) ;
            }
            else {
                // zig-zig (right-right).
                if ( y == son[1][z] ) rot ( y , 0 ) , rot ( rt , 0 ) ;
                // zig-zag (right-left).
                else rot ( rt , 0 ) , rot ( rt , 1 ) ;
            }
        }
    }
    // rot() only re-aggregates the node moved down; refresh rt now that it is in place.
    push_up ( rt ) ;
}



// Returns the node with 1-based in-order rank `key` in the subtree rooted at rt.
// Walks down iteratively using subtree sizes; key must be within the subtree.
int find ( int rt , int key ) {
    int node = rt ;
    for ( ;; ) {
        const int left = son[0][node] ;
        const int left_size = ( left == -1 ) ? 0 : size[left] ;
        if ( key == left_size + 1 ) return node ;
        if ( key <= left_size ) {
            node = left ;
        } else {
            key -= left_size + 1 ;
            node = son[1][node] ;
        }
    }
}



// Assigns state c (1 = free, 0 = occupied) to every unit in [l, r] and
// returns the new tree root.
// Unit u sits at in-order rank u + 1, so pos[l] is unit l - 1 and pos[r + 2] is
// unit r + 1. Splaying those two nodes isolates [l, r] as the right node's
// left subtree, which then gets the lazy tag.
// The incoming rt argument is not used; the root is taken from pos[].
int update ( int l , int r , int c , int rt ) {
    const int left_guard = pos[l] ;
    const int right_guard = pos[r + 2] ;

    splay ( left_guard , -1 ) ;
    splay ( right_guard , left_guard ) ;

    const int seg = son[0][right_guard] ;
    col[seg] = val[seg] = c ;
    lm[seg] = rm[seg] = mm[seg] = size[seg] * c ;

    // Re-aggregate bottom-up: the parent of seg, then the root.
    push_up ( right_guard ) ;
    push_up ( left_guard ) ;
    return left_guard ;
}



// Returns the first unit of the leftmost run of at least `key` free units in
// the subtree rooted at rt. rk is the number of in-order nodes preceding that
// subtree (0 for the whole tree); since unit u has rank u + 1, the returned
// value is the unit number itself.
// Callers must first check mm[rt] >= key; if no run is found, -1 is returned.
// Iterative, to avoid deep recursion on an unbalanced splay tree. A missing
// left child now counts as a 0-length run (the old code read rm[-1]).
int search ( int rt , int key , int rk ) {
    int node = rt ;
    int before = rk ;
    while ( node != -1 ) {
        push_down ( node ) ;
        const int ls = son[0][node] , rs = son[1][node] ;

        // Prefer a run lying entirely inside the left subtree.
        if ( ls != -1 && mm[ls] >= key ) {
            node = ls ;
            continue ;
        }

        const int left_size = ( ls == -1 ) ? 0 : size[ls] ;
        if ( val[node] ) {
            // Next candidate: a run crossing this node's own unit.
            const int run_left = ( ls == -1 ) ? 0 : rm[ls] ;
            const int run_right = ( rs == -1 ) ? 0 : lm[rs] ;
            if ( run_left + 1 + run_right >= key ) return before + left_size - run_left ;
        }

        // Otherwise the run must lie in the right subtree.
        before += left_size + 1 ;
        node = rs ;
    }
    return -1 ;
}



// Lower bound on vec by block end: returns the first index whose .second is
// >= key, or vec.size() if there is none. vec is kept sorted by address.
int bin ( int key ) {
    int lo = 0 ;
    int hi = ( int ) vec.size () - 1 ;
    // Invariant: entries before lo end below key; entries after hi end at or above key.
    while ( lo <= hi ) {
        const int mid = ( lo + hi ) / 2 ;
        const bool ends_at_or_after = vec[mid].second >= key ;
        if ( ends_at_or_after ) hi = mid - 1 ;
        else lo = mid + 1 ;
    }
    return lo ;
}



// Upper bound on vec by block end: returns the first index whose .second is
// > key (vec.size() if none). This is where a block ending at key is inserted.
int bin2 ( int key ) {
    int lo = 0 ;
    int hi = ( int ) vec.size () - 1 ;
    while ( lo <= hi ) {
        const int mid = ( lo + hi ) / 2 ;
        if ( vec[mid].second > key ) hi = mid - 1 ;
        else lo = mid + 1 ;
    }
    return lo ;
}



// Reads the next (optionally negative) decimal integer from stdin.
// Leading characters other than digits and '-' are skipped. The number ends at
// the first non-digit, which is consumed. Returns 0 if input ends before any
// number starts.
// getchar() results are kept in an int so that EOF is distinguishable. Any
// non-digit (including EOF and '\r') ends the number; the old loop stopped only
// on ' ' / '\n' and spun forever on a final number with no trailing newline.
int get_num()
{
    int c = getchar () ;
    while ( c != EOF && ( c < '0' || c > '9' ) && c != '-' ) c = getchar () ;
    if ( c == EOF ) return 0 ;

    int sign = 1 ;
    int num = 0 ;
    if ( c == '-' ) sign = -1 ;
    else num = c - '0' ;

    while ( ( c = getchar () ) >= '0' && c <= '9' )
        num = num * 10 + ( c - '0' ) ;

    return num * sign ;
}



// Reads test cases of the form "n m" followed by m commands:
//   Reset   -> free everything            ("Reset Now")
//   New x   -> take the leftmost run of x free units ("New at L" / "Reject New")
//   Free x  -> free the block containing unit x ("Free from L to R" / "Reject Free")
//   Get x   -> start of the x-th block from the left ("Get at L" / "Reject Get")
// A blank line follows each test case.
// Changes: "Get 0" or a negative index is rejected instead of reading vec[-1];
// signed/size comparisons are explicit; the command read is length-limited and
// checked.
int main () {
    char op[11] ;
    while ( scanf ( "%d%d" , &n , &m ) == 2 ) {
        tot = 0 ;  // reset the node pool for this test case
        int root = build ( 0 , n + 1 ) ;
        // Node identity per unit never changes (no inserts/deletes), so cache it once.
        for ( int i = 1 ; i <= n + 2 ; i ++ ) pos[i] = find ( root , i ) ;
        vec.clear () ;

        while ( m -- ) {
            if ( scanf ( "%10s" , op ) != 1 ) break ;

            if ( op[0] == 'R' ) {
                root = update ( 1 , n , 1 , root ) ;
                vec.clear () ;
                puts ( "Reset Now" ) ;
            }
            else if ( op[0] == 'N' ) {
                const int a = get_num () ;
                if ( mm[root] < a ) puts ( "Reject New" ) ;
                else {
                    const int l = search ( root , a , 0 ) ;
                    const int r = l + a - 1 ;
                    printf ( "New at %d\n" , l ) ;
                    root = update ( l , r , 0 , root ) ;
                    // Keep vec sorted by address.
                    vec.insert ( vec.begin () + bin2 ( r ) , make_pair ( l , r ) ) ;
                }
            }
            else if ( op[0] == 'F' ) {
                const int a = get_num () ;
                // First block ending at or after a; it contains a iff it also starts at or before a.
                const int k = bin ( a ) ;
                if ( k == ( int ) vec.size () || vec[k].first > a ) puts ( "Reject Free" ) ;
                else {
                    printf ( "Free from %d to %d\n" , vec[k].first , vec[k].second ) ;
                    root = update ( vec[k].first , vec[k].second , 1 , root ) ;
                    vec.erase ( vec.begin () + k ) ;
                }
            }
            else {
                const int a = get_num () ;
                if ( a < 1 || a > ( int ) vec.size () ) puts ( "Reject Get" ) ;
                else printf ( "Get at %d\n" , vec[a-1].first ) ;
            }
        }

        puts ( "" ) ;
    }
    return 0 ;
}

/*

10 2

N 9

N 1

*/


 

 

你可能感兴趣的:(memory)