今年寒假时封装了一个支持查询rank的treap。
然后发现这样无法支持指针的O(1)加减。事实上通过维护指向前继和后继的指针可以实现迭代器的O(1)加减。
今天就又写了一个treap模板,封装性自我感觉良好,有自己的迭代器,而且速度还行,在洛谷的普通平衡树一题中是第16页,总共2700份左右的AC代码。
同时为了测试迭代器和begin指针,还放到快排和堆的模板题里测试,发现我的treap常数是快排3倍不止(还有I/O的硬指标)(我96ms一个点,别人32ms一个点)
在堆的模板题里好像虐stl的priority_queue
然后在本地发现我的multiset插 6∗105 个int只要0.75秒,而我的treap插 6∗105 个int却要1秒。连不开O2的stl的multiset都跑不过,人生还有什么希望?
突然感觉可能常数跟lct差不多。
我觉得之所以这么慢,根本原因是我太菜了,直接原因是我的程序不记father的,这样旋转省了常数,但是导致我维护的辅助链表对于程序其他操作的效率没有帮助,本来记了父亲,删除迭代器就可以 O(1) 实现,最多再 O(logn) 自底向上维护一下size(这不必递归,省常数)。因为普通BST是通过取前后继来替代旋转到叶子,而我的辅助链表可以O(1)查一个迭代器的前后继。
而删除键值,甚至可以非递归找到其所在的迭代器,再用前述方法删除,无需递归。
其实插入也可以迭代实现?
然而并不想写,6K代码写完,已经累觉不爱
#include
namespace GenHelper
{
unsigned z1,z2,z3,z4,b;
unsigned rand_()
{
b=((z1<<6)^z1)>>13;
z1=((z1&4294967294U)<<18)^b;
b=((z2<<2)^z2)>>27;
z2=((z2&4294967288U)<<2)^b;
b=((z3<<13)^z3)>>21;
z3=((z3&4294967280U)<<7)^b;
b=((z4<<3)^z4)>>12;
z4=((z4&4294967168U)<<13)^b;
return (z1^z2^z3^z4);
}
}
void srand(unsigned x)
{using namespace GenHelper;
z1=x; z2=(~x)^0x233333333U; z3=x^0x1234598766U; z4=(~x)+51;}
int rand()
{
using namespace GenHelper;
int a=rand_()&32767;
int b=rand_()&32767;
return a*32768+b;
}
template<typename T> class treap{
private:
struct node{
node*l,*r,*a[2];
int p,size,w;
T v;
node(T _v):v(_v){l=r=a[0]=a[1]=NULL,p=rand(),w=size=1;}
void maintain(){size=w+(l?l->size:0)+(r?r->size:0);}
};
node*head,*mi;
void lturn(node* &x){
node*t=x->r;
x->r=t->l;
t->l=x;
x->maintain();
t->maintain();
x=t;
}
void rturn(node* &x){
node*t=x->l;
x->l=t->r;
t->r=x;
x->maintain();
t->maintain();
x=t;
}
void ins(node* &o,int y,node*fa,int v){
if(o==NULL){
o=new node(y);
o->a[v^1]=fa;
o->a[v]=fa->a[v];
if(fa->a[v])fa->a[v]->a[v^1]=o;
fa->a[v]=o;
}else if(y>o->v){
ins(o->r,y,o,1);
if(o->r->p>o->p)lturn(o);
}else if(yv){
ins(o->l,y,o,0);
if(o->l->p>o->p)rturn(o);
}else ++o->w;
o->maintain();
}
void del(node* &x,int y){
if(x==NULL)return;
if(y>x->v)del(x->r,y);
else if(yv)del(x->l,y);
else{
if(x->w>1){
--x->size;
--x->w;
return;
}
if(x->l==NULL){
node*z=x;
if(x->a[0])x->a[0]->a[1]=x->a[1];
if(x->a[1])x->a[1]->a[0]=x->a[0];
x=x->r;
delete z;
return;
}
if(x->r==NULL){
node*z=x;
if(x->a[0])x->a[0]->a[1]=x->a[1];
if(x->a[1])x->a[1]->a[0]=x->a[0];
x=x->l;
delete z;
return;
}
if(x->l->p>x->r->p){
rturn(x);
del(x->r,y);
}else{
lturn(x);
del(x->l,y);
}
}
if(x!=NULL)--x->size;
}
public:
struct iterator{
node* x;
iterator(node*_x=NULL):x(_x){}
bool operator!=(const iterator&rhs)const{return x!=rhs.x;}
bool operator==(const iterator&rhs)const{return x==rhs.x;}
T operator*(){return x->v;}
iterator operator++(){return x=x->a[1];}
iterator operator++(int){register node*t=x;x=x->a[1];return t;}
iterator operator--(){return x=x->a[0];}
iterator operator--(int){register node*t=x;x=x->a[0];return t;}
};
treap(){srand(19260817);}
inline void insert(T x){
if(head==NULL)mi=head=new node(x);
else ins(head,x,NULL,0),mi=mi->a[0]?mi->a[0]:mi;
}
inline void erase(T x){
mi=mi->v==x && mi->w==1?mi->a[1]:mi;del(head,x);
}
inline void erase(iterator x){
erase(*x);
}
inline T rank(T y){
register node*x=head;
register int ans=0,s;
while(x!=NULL){
s=x->l?x->l->size:0;
if(y==x->v)return ans+s+1;
if(y>x->v){
ans+=s+x->w;
x=x->r;
}else x=x->l;
}
return ans+1;
}
inline T kth(T y){
register node*x=head;
register int u,v;
for(;;){
u=x->l?x->l->size:0,v=x->w;
if(u=y)return x->v;
if(y>u+v){
x=x->r;
y-=u+v;
}else x=x->l;
}
}
inline iterator prec(T y){
node*x=head,*t=NULL;
while(x!=NULL){
if(x->v>=y)x=x->l;
else{
t=x;
x=x->r;
}
}
return t;
}
inline iterator succ(T y){
node*x=head,*t=NULL;
while(x)
if(x->v<=y)x=x->r;
else{
t=x;
x=x->l;
}
return t;
}
inline iterator find(T v){
register node*x=head;
for(;x && x->v!=v;x=vv?x->l:x->r);
return x;
}
inline int count(T v){
register iterator t=find(v);return t.x?t.x->w:0;
}
inline iterator begin(){return mi;}
inline iterator end(){return NULL;}
};
treap<int> t;
int main(){
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
int x,y;
scanf("%d%d",&x,&y);
if(x==1)t.insert(y);
if(x==2)t.erase(y);
if(x==3)printf("%d\n",t.rank(y));
if(x==4)printf("%d\n",t.kth(y));
if(x==5)printf("%d\n",*t.prec(y));
if(x==6)printf("%d\n",*t.succ(y));
}
return 0;
}
upd:1.1版本
今天把老版本里极不优美的递归实现改成了迭代实现,好像快了一些?(224ms->196ms,都没有I/O优化)
除了快一些,好像代码也短了一些,实测40万次插入+查rank,本机900ms左右,作为对比的__gnu_pbds::rb_tree_tag
(我不跟stl的set比,因为stl的set不能查rank),同样操作要1560ms左右,真开心。
附上当前版本的几个要点(或问题)
1.需要重载==,<,<=,>,>=。本来可以不重载,但重载可以给写模板的我省事
2.本模板使用游程编码,重复键值会导致计数器++,而不是使节点数增加
3.本模板的rk返回的是小于给定键值的元素个数再加1
4.本模板的erase即使是键值,也只是使计数器–,因为目前还没有彻底删除的需求
5.此版本目前不支持–s.end()
好像手写的treap比红黑树慢很多?看来理论分析出来的常数在实践中还是有意义的。
下面就是1.1版本的代码
#include
#include
#include
templateT> class treap{
private:
struct node{
node*ch[2],*a[2],*fa;
unsigned int p;
int size,w;
T v;
node(T _v):v(_v){
static unsigned int seed=19260817;
ch[0]=ch[1]=a[0]=a[1]=fa=NULL;
p=seed^=seed>>13,seed^=seed<<21,seed^=seed>>17,seed^=seed<<24;
w=size=1;
}
void maintain(){size=w+(ch[0]?ch[0]->size:0)+(ch[1]?ch[1]->size:0);}
inline int lr(){return fa->ch[1]==this;}
};
node*rt,*mi;
inline void rotate(node*x){
node*y=x->fa,*z=y->fa;
if(z)z->ch[y->lr()]=x;
int o=x->lr();
x->fa=z,y->fa=x;
y->ch[o]=x->ch[!o];
if(x->ch[!o])x->ch[!o]->fa=y;
x->ch[!o]=y;y->maintain();x->maintain();
}
public:
struct iterator{
node* x;
iterator(node*_x=NULL):x(_x){}
bool operator!=(const iterator&rhs)const{return x!=rhs.x;}
bool operator==(const iterator&rhs)const{return x==rhs.x;}
T operator*(){return x->v;}
iterator operator++(){return x=x->a[1];}
iterator operator++(int){node*t=x;x=x->a[1];return t;}
iterator operator--(){return x=x->a[0];}
iterator operator--(int){node*t=x;x=x->a[0];return t;}
};
inline iterator find(T x){
for(node*i=rt;i!=NULL;i=xv?i->ch[0]:i->ch[1])if(i->v==x)return i;
return NULL;
}
inline void insert(T x){
if(!rt){
rt=mi=new node(x);
return;
}
node*i=rt,*j=NULL;int o;
while(i!=NULL){
j=i;++i->size;
if(x==i->v){++i->w;return;}
i=xv?i->ch[o=0]:i->ch[o=1];
}
i=new node(x);i->fa=j;
j->ch[o]=i;
if(j->a[o])j->a[o]->a[!o]=i;
i->a[o]=j->a[o];
i->a[!o]=j;
j->a[o]=i;
if(mi==NULL || xv)mi=i;
while(i->fa!=NULL && i->fa->pp)
rotate(i);
if(i->fa==NULL)rt=i;
}
inline void erase(node*x){
if(x->w>1){
--x->w;
for(;x!=NULL;x=x->fa)
--x->size;
return;
}
if(x==rt && rt->size==1){delete rt;mi=rt=NULL;return;}
if(x->a[0])x->a[0]->a[1]=x->a[1];
if(x->a[1])x->a[1]->a[0]=x->a[0];
if(x==mi)mi=x->a[1];
if(!x->ch[0]){
if(x->fa)x->fa->ch[x->lr()]=x->ch[1];
if(x->ch[1])x->ch[1]->fa=x->fa;
for(node*y=x->fa;y;y=y->fa)--y->size;delete x;
}else{
node*y=x->a[0],*u=y->ch[0];
if(u!=NULL)u->fa=y->fa,y->fa->ch[y->lr()]=u;else y->fa->ch[y->lr()]=NULL;
x->v=y->v,x->w=y->w,x->a[0]=y->a[0],x->a[1]=y->a[1];
if(y->a[0])y->a[0]->a[1]=x;
if(y->a[1])y->a[1]->a[0]=x;
node*z=y->fa;
for(;z!=x;z=z->fa)z->size-=y->w;
for(;z;z=z->fa)--z->size;
delete y;
}
}
inline void erase(T x){
erase(find(x).x);
}
inline void erase(iterator x){
erase(x.x);
}
inline int rank(T y){
node*x=rt;
int ans=0,s;
while(x){
s=x->ch[0]?x->ch[0]->size:0;
if(y==x->v)return ans+s+1;
if(y>x->v){
ans+=s+x->w;
x=x->ch[1];
}else x=x->ch[0];
}
return ans+1;
}
inline T kth(T y){
node*x=rt;
int u,v;
for(;;){
u=x->ch[0]?x->ch[0]->size:0,v=x->w;
if(u=y)return x->v;
if(y>u+v){
x=x->ch[1];
y-=u+v;
}else x=x->ch[0];
}
}
inline iterator prec(T y){
node*x=rt,*t=NULL;
while(x)
if(x->vch[1];
else x=x->ch[0];
return t;
}
inline iterator succ(T y){
node*x=rt,*t=NULL;
while(x)
if(yv)t=x,x=x->ch[0];
else x=x->ch[1];
return t;
}
inline int count(T v){
iterator t=find(v);return t.x!=NULL?t.x->w:0;
}
inline iterator begin(){return mi;}
inline iterator end(){return NULL;}
inline bool empty(){return rt==NULL;}
inline int size(){return rt?rt->size:0;}
};
treap t;
inline void read(int&x){
char c=getchar();int f=1;
for(;!isdigit(c);c=getchar())f=c=='-'?-1:f;
for(x=0;isdigit(c);c=getchar())x=x*10+c-48;x*=f;
}
int main(){
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
int x,y;
scanf("%d%d",&x,&y);
if(x==1)t.insert(y);
if(x==2)t.erase(y);
if(x==3)printf("%d\n",t.rank(y));
if(x==4)printf("%d\n",t.kth(y));
if(x==5)printf("%d\n",*t.prec(y));
if(x==6)printf("%d\n",*t.succ(y));
}
return 0;
}