[Treap] POJ 1442

树堆

通过随机rand选择的priority来平衡二叉树


The nodes of the treap are ordered so that the keys obey the binary-search-tree property and the priorities obey the min-heap order property:


 If v is a left child of u, then v:key < u:key.
 If v is a right child of u, then v:key > u:key.

 If v is a child of u, then v:priority > u:priority.


用Treap实现名次树

名次树支持两个操作:
Kth(x): 找出第k小(第k大)的元素。

Rank(x): 值x的名次,即比x小(大)的结点个数加 1 。

Treap 实现名次树 :  http://www.cnblogs.com/TreeDream/p/6730574.html


#include 

using namespace std;

const int N = 30005;
int A[ N ];
int u[ N ];

struct Node {
        Node *ch[ 2 ]; //两个孩子
        int pri;       //优先级
        int val;
        int size;

        Node ( int v )
            : val ( v ) {
                ch[ 0 ] = ch[ 1 ] = NULL;
                pri = rand ();
                size = 1;
        }

        bool operator< ( const Node &rhs ) const { return pri < rhs.pri; }

        int cmp ( int x ) const {
                if ( x == val )
                        return -1;
                return x < val ? 0 : 1; //小于去左子树,大于去右子树
        }

        // 记录当前节点有多少个孩子
        void maintain () {
                size = 1;
                if ( ch[ 0 ] != NULL )
                        size += ch[ 0 ]->size;
                if ( ch[ 1 ] != NULL )
                        size += ch[ 1 ]->size;
        }
};

//将左/右孩子旋转到上面,保持树的性质不变
void rotate ( Node *&rt, int d ) {
        Node *sn = rt->ch[ d ^ 1 ];
        rt->ch[ d ^ 1 ] = sn->ch[ d ];
        sn->ch[ d ] = rt;
        rt->maintain ();
        sn->maintain ();
        rt = sn;
}

void insert ( Node *&rt, int x ) {
        if ( rt == NULL )
                rt = new Node ( x );
        else {
                int d = ( x < rt->val ? 0 : 1 );
                insert ( rt->ch[ d ], x );

                if ( rt->ch[ d ]->pri > rt->pri )
                        rotate ( rt, d ^ 1 );
        }

        rt->maintain ();
}

// x是 要删除节点的 val
void remove ( Node *&rt, int x ) {
        int d = rt->cmp ( x );
        Node *u = rt;
        // rt->val 和 x 相等, 找到了节点,开始删除
        if ( d == -1 ) {
                //左右树非空
                if ( rt->ch[ 0 ] != NULL && rt->ch[ 1 ] != NULL ) {
                        int d2 = ( rt->ch[ 0 ]->pri > rt->ch[ 1 ]->pri ? 1 : 0 );
                        rotate ( rt, d2 );
                        remove ( rt->ch[ d2 ], x );
                } else {
                        //左子树空
                        if ( rt->ch[ 0 ] == NULL )
                                rt = rt->ch[ 1 ];
                        //右子树空
                        else
                                rt = rt->ch[ 0 ];
                        delete u;
                }

        } else {
                remove ( rt->ch[ d ], x );
        }

        if ( rt != NULL )
                rt->maintain ();
}

int find ( Node *rt, int x ) {
        while ( rt != NULL ) {
                int d = rt->cmp ( x );
                if ( d == -1 )
                        return 1;
                else
                        rt = rt->ch[ d ];
        }
        return 0;
}

// 第k小的元素 k == size
// 求第k大就把0改成1
int kth ( Node *rt, int k ) {
        if ( rt == NULL || k <= 0 || k > rt->size )
                return -1;
        int s = rt->ch[ 0 ] == NULL ? 0 : rt->ch[ 0 ]->size;
        if ( k == s + 1 )
                return rt->val;
        else if ( k <= s )
                return kth ( rt->ch[ 0 ], k );
        else
                return kth ( rt->ch[ 1 ], k - s - 1 );
}

// 比x小的节点个数+1 x == val
int Rank ( Node *rt, int x ) {
        int tmp;
        if ( rt->ch[ 0 ] == NULL )
                tmp = 0;
        else
                tmp = rt->ch[ 0 ]->size;
        if ( rt->val == x )
                return tmp + 1;
        if ( x < rt->val )
                return Rank ( rt->ch[ 0 ], x );
        else
                return tmp + 1 + Rank ( rt->ch[ 1 ], x );
}

int main () {
        int m, n;
        while ( ~scanf ( "%d%d", &m, &n ) ) {

                Node *rt = NULL;

                for ( int i = 1; i <= m; ++i )
                        scanf ( "%d", &A[ i ] );
                for ( int i = 1; i <= n; ++i )
                        scanf ( "%d", &u[ i ] );

                int cnt = 1;
                int k = 1;

                for ( int i = 1; i <= m; ++i ) {
                        insert ( rt, A[ i ] );
                        while ( u[ cnt ] == i ) {
                                printf ( "%d\n", kth ( rt, k++ ) );
                                ++cnt;
                        }
                }
        }
        return 0;
}



你可能感兴趣的:(ACM)