参照陈竞潇的模板写的BZOJ 3188:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define for1(i,a,b) for(int i=(a);i<=(b);++i)
using namespace std;
typedef long long ll;
const int N=1E5+100;

// n: initial sequence length, m: number of operations.
// init_val: initial sequence, 1-indexed. It must NOT be named `data`: with
// `using namespace std;` under C++17 an unqualified `data` is ambiguous with
// std::data as soon as a standard header declares it (libstdc++'s <algorithm>
// does), which is a hard compile error.
int n,m,init_val[N];

// Splay node keyed implicitly by in-order position.
//   d        : value stored at this node
//   sum      : sum of all values in this subtree
//   size     : number of nodes in this subtree
//   vset/set : pending "assign every value in the subtree to `set`" tag
//   add[0..1]: pending "add an arithmetic progression" tag; the k-th node
//              (0-indexed, in order) of the subtree still has to receive
//              add[0] + k*add[1]
struct node{
    node();
    node *ch[2],*fa;
    ll d,sum,set,add[2];
    int size;
    short vset;
    short pl(){return this==fa->ch[1];}  // 1 if this is its parent's right child
    void count();
    void push();
    void mark(ll,ll,short);
}*null;

node::node(){
    ch[0]=ch[1]=fa=null;
    d=sum=set=add[0]=add[1]=0;  // d and set used to be left uninitialized
    size=0;
    vset=0;
}

// Apply a tag to this entire subtree (no-op on the sentinel).
//   t==0: assign `val` to every element (dd is ignored)
//   t!=0: add val, val+dd, val+2*dd, ... to the elements in order
void node::mark(ll val,ll dd,short t){
    if(this==null) return;
    if(!t){
        set=val;
        sum=size*set;
        d=set;
        vset=1;
        add[0]=add[1]=0;               // an assignment overrides pending additions
    }else{
        add[0]+=val;
        add[1]+=dd;
        sum+=val*size;
        sum+=dd*size*(size-1)/2;       // dd*(0+1+...+(size-1))
        d+=val+dd*(ch[0]->size);       // this node sits at index ch[0]->size
    }
}

// Push pending tags to the children: assignment first, then addition.
void node::push(){
    if(this==null) return;
    if(vset){
        ch[0]->mark(set,0,0);
        ch[1]->mark(set,0,0);
        vset=0;
        set=0;
    }
    if(add[0]||add[1]){
        ch[0]->mark(add[0],add[1],1);
        // the right subtree starts after the left subtree and this node
        ch[1]->mark(add[0]+add[1]*(ch[0]->size+1),add[1],1);
        add[0]=add[1]=0;
    }
}

// Recompute size and sum from the children.
void node::count(){
    size=ch[0]->size+ch[1]->size+1;
    sum=ch[0]->sum+ch[1]->sum+d;
}

namespace Splay{
    node *ROOT;

    // Build a balanced tree over init_val[l..r].
    node *build(int l=1,int r=n){
        if(l>r) return null;
        int mid=(l+r)>>1;
        node *ro=new node;
        ro->d=init_val[mid];
        ro->ch[0]=build(l,mid-1);
        ro->ch[1]=build(mid+1,r);
        ro->ch[0]->fa=ro;
        ro->ch[1]->fa=ro;
        ro->count();
        return ro;
    }

    // Create the shared sentinel `null`, then build the initial tree.
    void Build(){
        null=new node;
        *null=node();
        ROOT=build();
    }

    // Rotate k above its parent; tags are pushed parent-first before relinking.
    void rotate(node *k){
        node *r=k->fa;
        if(k==null||r==null) return;
        r->push();
        k->push();
        int x=k->pl()^1;
        r->ch[x^1]=k->ch[x];
        r->ch[x^1]->fa=r;
        if(r->fa!=null) r->fa->ch[r->pl()]=k;
        else ROOT=k;
        k->fa=r->fa;
        r->fa=k;
        k->ch[x]=r;
        r->count();
        k->count();
    }

    // Splay r until its parent is tar (tar==null: make r the root).
    void splay(node *r,node *tar=null){
        for(;r->fa!=tar;rotate(r))
            if(r->fa->fa!=tar) rotate(r->pl()==r->fa->pl()?r->fa:r);
        r->push();
    }

    // Insert val so that it becomes the x-th element (1-indexed).
    void insert(int x,int val){
        node *r=ROOT;
        if(ROOT==null){
            ROOT=new node;
            ROOT->d=val;
            ROOT->count();
            return;
        }
        while(1){
            r->push();
            int c;
            if(r->ch[0]->size+1>=x) c=0;
            else c=1,x-=r->ch[0]->size+1;
            if(r->ch[c]==null){
                r->ch[c]=new node;
                r->ch[c]->fa=r;
                r->ch[c]->d=val;
                r->ch[c]->count();      // size 1, sum val, before it is splayed
                splay(r->ch[c]);
                return;
            }else r=r->ch[c];
        }
    }

    // k-th node in order (1-indexed), or null when k is out of range
    // (k==0 also yields null).
    node *kth(int k){
        node *r=ROOT;
        while(r!=null){
            r->push();
            if(r->ch[0]->size>=k) r=r->ch[0];
            else if(r->ch[0]->size+1>=k) return r;
            else k-=r->ch[0]->size+1,r=r->ch[1];
        }
        return null;
    }

    // Return a subtree containing exactly positions l..r.
    node *pack(int l,int r){
        node *ln=kth(l-1),*rn=kth(r+1);
        if((ln==null)&&(rn==null)) return ROOT;
        else if(ln==null){
            splay(rn);
            return rn->ch[0];
        }else if(rn==null){
            splay(ln);
            return ln->ch[1];
        }else{
            splay(ln);
            splay(rn,ROOT);
            return rn->ch[0];
        }
    }
}

int main(){
    scanf("%d%d",&n,&m);
    for1(i,1,n) scanf("%d",&init_val[i]);
    Splay::Build();
    int j,a,b,c;
    for1(i,1,m){
        scanf("%d",&j);
        node *r;
        switch(j){
            case 1:  // assign c to positions a..b
                scanf("%d%d%d",&a,&b,&c);
                r=Splay::pack(a,b);
                r->mark(c,0,0);
                Splay::splay(r);     // recount every ancestor of the range
                break;
            case 2:  // add c, 2c, 3c, ... to positions a..b
                scanf("%d%d%d",&a,&b,&c);
                r=Splay::pack(a,b);
                r->mark(c,c,1);
                Splay::splay(r);
                break;
            case 3:  // insert b so it becomes the a-th element
                scanf("%d%d",&a,&b);
                Splay::insert(a,b);
                break;
            case 4:  // print the sum of positions a..b
                scanf("%d%d",&a,&b);
                r=Splay::pack(a,b);
                printf("%lld\n",r->sum);
                break;
        }
    }
    return 0;
}
同样参照陈竞潇的模板，写了指针版的 BZOJ 3224（普通平衡树）：
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define for1(i,a,b) for(int i=(a);i<=(b);++i)
using namespace std;
typedef long long ll;

// Splay-tree node keyed by value `d`.
//   size: multiplicity of d (equal keys share one node)
//   sum : total multiplicity of the subtree rooted here
struct node{
    node();
    node *ch[2],*fa;
    int d,size,sum;
    short pl(){return this==fa->ch[1];}  // 1 if this is its parent's right child
    void count(){sum=ch[0]->sum+ch[1]->sum+size;}
}*null;
node::node(){ch[0]=ch[1]=fa=null;d=size=sum=0;}

// Read a (possibly negative) decimal integer from stdin.
int getint(){
    char c;
    int fh=1;
    while(!isdigit(c=getchar())) if(c=='-') fh=-1;
    int a=c-'0';
    while(isdigit(c=getchar())) a=a*10+c-'0';
    return a*fh;
}

namespace Splay{
    node *ROOT;

    // Create the shared sentinel `null` and start with an empty tree.
    void Build(){
        null=new node;
        *null=node();
        ROOT=null;
    }

    // Rotate k above its parent, keeping subtree counts consistent.
    // Note: writes null->fa whenever the moved child is the sentinel.
    void rotate(node *k){
        node *r=k->fa;
        if(k==null||r==null) return;
        int x=k->pl()^1;
        r->ch[x^1]=k->ch[x];
        r->ch[x^1]->fa=r;
        if(r->fa!=null) r->fa->ch[r->pl()]=k;
        else ROOT=k;
        k->fa=r->fa;
        r->fa=k;
        k->ch[x]=r;
        r->count();
        k->count();
    }

    // Splay r until its parent is tar (tar==null: make r the root).
    // Never call with r==null: null->fa may be stale and the loop would not end.
    void splay(node *r,node *tar=null){
        for(;r->fa!=tar;rotate(r))
            if(r->fa->fa!=tar) rotate(r->pl()==r->fa->pl()?r->fa:r);
    }

    // Recount r and all of its ancestors (currently unused).
    void updata(node *r){
        while(r!=null){
            r->count();
            r=r->fa;
        }
    }

    // Insert one copy of x; the touched node is splayed to the root.
    void insert(int x){
        node *r=ROOT;
        if(ROOT==null){
            ROOT=new node;
            ROOT->d=x;
            ROOT->size=1;
            ROOT->sum=1;
            return;
        }
        while(1){
            int c;
            if(x<r->d) c=0;
            else if(x>r->d) c=1;
            else {r->size++;r->sum++;splay(r);return;}
            if(r->ch[c]==null){
                r->ch[c]=new node;
                r->ch[c]->fa=r;
                r->ch[c]->d=x;
                r->ch[c]->size=1;
                r->ch[c]->sum=1;
                splay(r->ch[c]);
                return;
            }else r=r->ch[c];
        }
    }

    // Node holding the k-th smallest value (1-indexed, counting duplicates),
    // or null if k is out of range. The found node is splayed to the root so
    // repeated queries keep the amortized O(log n) bound.
    node *kth(int k){
        node *r=ROOT;
        while(r!=null){
            if(r->ch[0]->sum>=k) r=r->ch[0];
            else if(r->ch[0]->sum+r->size>=k){splay(r);return r;}
            else k-=r->ch[0]->sum+r->size,r=r->ch[1];
        }
        return null;
    }

    // Print the rank of k, i.e. (number of stored values < k) + 1, whether or
    // not k is present. Returns k's node or null; the last visited node is
    // splayed to the root.
    node *ques(int k){
        node *r=ROOT,*last=null;
        int ans=0;
        while(r!=null){
            last=r;
            if(k<r->d) r=r->ch[0];
            else if(k>r->d) ans+=r->ch[0]->sum+r->size,r=r->ch[1];
            else{
                printf("%d\n",ans+r->ch[0]->sum+1);
                splay(r);
                return r;
            }
        }
        printf("%d\n",ans+1);
        if(last!=null) splay(last);
        return null;
    }

    // Node holding value k, or null (after splaying the last visited node).
    node *ques2(int k){
        node *r=ROOT,*last=null;
        while(r!=null){
            last=r;
            if(k<r->d) r=r->ch[0];
            else if(k>r->d) r=r->ch[1];
            else return r;
        }
        if(last!=null) splay(last);
        return null;
    }

    node *rightdown(node *r){
        while(r->ch[1]!=null) r=r->ch[1];
        return r;
    }
    node *leftdown(node *r){
        while(r->ch[0]!=null) r=r->ch[0];
        return r;
    }

    // Remove one copy of r's value. A null argument (value not present) is a
    // no-op; previously it wiped the tree / deleted the sentinel or looped.
    void deleter(node *r){
        if(r==null) return;
        splay(r);
        if(r->size>1){
            r->size--;
            r->sum--;
            return;
        }
        if((r->ch[0]==null)&&(r->ch[1]==null)){
            ROOT=null;
        }else if(r->ch[0]==null){
            r->ch[1]->fa=null;
            ROOT=r->ch[1];
        }else if(r->ch[1]==null){
            r->ch[0]->fa=null;
            ROOT=r->ch[0];
        }else{
            // the largest value of the left subtree becomes the new root;
            // it has no right child, so the right subtree hangs there
            splay(rightdown(r->ch[0]),ROOT);
            r->ch[0]->ch[1]=r->ch[1];
            r->ch[1]->fa=r->ch[0];
            r->ch[0]->fa=null;
            r->ch[0]->count();
            ROOT=r->ch[0];
        }
        delete r;
    }

    // Largest value < x in r's subtree, or -1E7-10 if none (iterative; the
    // last visited node is splayed).
    int predd(node *r,int x){
        int best=-10000010;
        node *last=null;
        while(r!=null){
            last=r;
            if(x<=r->d) r=r->ch[0];
            else best=r->d,r=r->ch[1];
        }
        if(last!=null) splay(last);
        return best;
    }

    // Smallest value > x in r's subtree, or 1E7+10 if none (iterative; the
    // last visited node is splayed).
    int pross(node *r,int x){
        int best=10000010;
        node *last=null;
        while(r!=null){
            last=r;
            if(r->d<=x) r=r->ch[1];
            else best=r->d,r=r->ch[0];
        }
        if(last!=null) splay(last);
        return best;
    }

    int predds(int x){return predd(ROOT,x);}
    int prosss(int x){return pross(ROOT,x);}
}

int main(){
    int n,x,num;
    n=getint();
    Splay::Build();
    while(n>0){
        n--;
        x=getint();
        node *r;
        switch(x){
            case 1: num=getint(); Splay::insert(num); break;
            case 2: num=getint(); r=Splay::ques2(num); Splay::deleter(r); break;
            case 3: num=getint(); r=Splay::ques(num); break;
            case 4: num=getint(); r=Splay::kth(num); printf("%d\n",r->d); break;
            case 5: num=getint(); printf("%d\n",Splay::predds(num)); break;
            case 6: num=getint(); printf("%d\n",Splay::prosss(num)); break;
        }
    }
    return 0;
}
同样参照陈竞潇的模板写的 BZOJ 3224 普通平衡树，数组版，带有垃圾回收（被删除节点的下标会被回收复用）。效率和指针版差不多，但代码更短——也许是因为我压了行吧。
#include<cstdio>
#include<algorithm>
#define read(x) x=getint()
using namespace std;

// Read a (possibly negative) decimal integer from stdin.
inline int getint(){
    char c=getchar();
    int k=1,r=0;
    for(;c<'0'||c>'9';c=getchar()) if(c=='-') k=-1;
    for(;c>='0'&&c<='9';c=getchar()) r=r*10+c-'0';
    return k*r;
}

// Array-based splay tree; index 0 is the empty sentinel.
//   d: key, size: multiplicity of d, sum: total multiplicity of the subtree.
struct node{
    int fa,ch[2],d,size,sum;
}T[100003];
// chi[1..top]: freed indices waiting for reuse; cnt: next never-used index.
int chi[100003],top=0,cnt=1,ROOT=0;

// 1 if X is its parent's right child.
inline bool pl(int X){return T[T[X].fa].ch[1]==X;}

// Allocate a cleared node, reusing a freed index when one is available.
inline void newnode(int &X){
    if(top){X=chi[top];top--;}
    else X=cnt++;
    T[X].fa=T[X].ch[0]=T[X].ch[1]=T[X].d=T[X].size=T[X].sum=0;
}

inline void count(int k){T[k].sum=T[T[k].ch[0]].sum+T[T[k].ch[1]].sum+T[k].size;}

// Rotate k above its parent (writes T[0].fa when the moved child is empty).
inline void rotate(int k){
    int r=T[k].fa;
    if(k==0||r==0) return;
    int x=pl(k)^1;
    T[r].ch[x^1]=T[k].ch[x];
    T[T[r].ch[x^1]].fa=r;
    if(T[r].fa!=0) T[T[r].fa].ch[pl(r)]=k;
    else ROOT=k;
    T[k].fa=T[r].fa;
    T[r].fa=k;
    T[k].ch[x]=r;
    count(r);
    count(k);
}

// Splay k until its parent is tar (tar==0: make k the root). Never call with
// k==0: T[0].fa may be stale and the loop would not terminate.
inline void splay(int k,int tar=0){
    for(;T[k].fa!=tar;rotate(k))
        if(T[T[k].fa].fa!=tar) rotate(pl(k)==pl(T[k].fa)?T[k].fa:k);
}

// Insert one copy of x; the touched node is splayed to the root.
inline void insect(int x){
    int k=ROOT;
    if(ROOT==0){
        newnode(ROOT);
        T[ROOT].d=x;
        T[ROOT].size=T[ROOT].sum=1;
        return;
    }
    while(1){
        int c;
        if(x<T[k].d) c=0;
        else if(x>T[k].d) c=1;
        else {T[k].size++;T[k].sum++;splay(k);return;}
        if(T[k].ch[c]==0){
            newnode(T[k].ch[c]);
            T[T[k].ch[c]].fa=k;
            T[T[k].ch[c]].d=x;
            T[T[k].ch[c]].size=1;
            T[T[k].ch[c]].sum=1;
            splay(T[k].ch[c]);
            return;
        }else k=T[k].ch[c];
    }
}

// Return index x to the free stack.
inline void del(int x){top++;chi[top]=x;}

inline int rightdown(int x){
    while(T[x].ch[1]!=0) x=T[x].ch[1];
    return x;
}

// Remove one copy of x; does nothing if x is absent.
inline void deletr(int x){
    int k=ROOT;
    while(k){
        if(x<T[k].d) k=T[k].ch[0];
        else if(x>T[k].d) k=T[k].ch[1];
        else break;
    }
    if(k==0) return;
    splay(k);
    if(T[k].size>1){
        T[k].size--;
        T[k].sum--;
        return;
    }
    if((T[k].ch[0]==0)&&(T[k].ch[1]==0)){
        ROOT=0;
    }else if(T[k].ch[0]==0){
        T[T[k].ch[1]].fa=0;
        ROOT=T[k].ch[1];
    }else if(T[k].ch[1]==0){
        T[T[k].ch[0]].fa=0;
        ROOT=T[k].ch[0];
    }else{
        // the largest key of the left subtree becomes the root; it has no
        // right child, so the old right subtree is attached there
        splay(rightdown(T[k].ch[0]),ROOT);
        T[T[k].ch[0]].ch[1]=T[k].ch[1];
        T[T[k].ch[1]].fa=T[k].ch[0];
        T[T[k].ch[0]].fa=0;
        count(T[k].ch[0]);
        ROOT=T[k].ch[0];
    }
    del(k);
}

// Print the rank of x: (number of stored values < x) + 1, whether or not x
// is present. The last visited node is splayed to keep amortized bounds.
inline void query(int x){
    int k=ROOT,s=0,last=0;
    while(k){
        last=k;
        if(x<T[k].d) k=T[k].ch[0];
        else if(x>T[k].d) s+=T[T[k].ch[0]].sum+T[k].size,k=T[k].ch[1];
        else break;
    }
    if(k) s+=T[T[k].ch[0]].sum;
    printf("%d\n",s+1);
    if(last) splay(last);
}

// Print the x-th smallest value (1-indexed, counting duplicates) and splay
// its node. Prints T[0].d (0) if x is out of range.
inline void queryk(int x){
    int k=ROOT;
    while(k){
        int ls=T[T[k].ch[0]].sum;
        if(ls>=x) k=T[k].ch[0];
        else if(ls+T[k].size>=x) break;
        else x-=ls+T[k].size,k=T[k].ch[1];
    }
    printf("%d\n",T[k].d);
    if(k) splay(k);
}

// Largest value < x below k, or -1E7-10 if none (iterative; splays the last
// visited node).
inline int pre(int k,int x){
    int best=-10000010,last=0;
    while(k){
        last=k;
        if(x<=T[k].d) k=T[k].ch[0];
        else best=T[k].d,k=T[k].ch[1];
    }
    if(last) splay(last);
    return best;
}

// Smallest value > x below k, or 1E7+10 if none (iterative; splays the last
// visited node).
inline int pro(int k,int x){
    int best=10000010,last=0;
    while(k){
        last=k;
        if(T[k].d<=x) k=T[k].ch[1];
        else best=T[k].d,k=T[k].ch[0];
    }
    if(last) splay(last);
    return best;
}

int main(){
    int t,opt,x;
    read(t);
    while(t--){
        read(opt);
        read(x);
        switch(opt){
            case 1: insect(x); break;
            case 2: deletr(x); break;
            case 3: query(x); break;
            case 4: queryk(x); break;
            case 5: printf("%d\n",pre(ROOT,x)); break;
            case 6: printf("%d\n",pro(ROOT,x)); break;
        }
    }
    return 0;
}
补上指针版的垃圾回收写法（用静态节点池 pool 代替 new/delete）：
#include<cctype>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define for1(i,a,b) for(int i=(a);i<=(b);++i)
using namespace std;
typedef long long ll;

// Splay-tree node keyed by value `d`.
//   size: multiplicity of d, sum: total multiplicity of the subtree.
struct node{
    node *ch[2],*fa;
    int d,size,sum;
    short pl(){return this==fa->ch[1];}  // 1 if this is its parent's right child
    void count(){sum=ch[0]->sum+ch[1]->sum+size;}
}*null;

// Read a (possibly negative) decimal integer from stdin.
int getint(){
    char c;
    int fh=1;
    while(!isdigit(c=getchar())) if(c=='-') fh=-1;
    int a=c-'0';
    while(isdigit(c=getchar())) a=a*10+c-'0';
    return a*fh;
}

namespace Splay{
    // Static node pool plus a free list, so deleted nodes are actually reused
    // (previously removed nodes were simply abandoned inside the pool).
    node *ROOT,pool[100003];
    int tot=0;
    node *freed[100003];
    int freetop=0;

    // Take a cleared node, preferring one released by deleter().
    node *newnode(){
        node *t=freetop>0?freed[--freetop]:&pool[tot++];
        t->ch[0]=t->ch[1]=t->fa=null;
        t->d=t->size=t->sum=0;
        return t;
    }

    // Give a node that is no longer in the tree back to the free list.
    void recycle(node *t){freed[freetop++]=t;}

    // Create the self-referencing sentinel `null`; start with an empty tree.
    void Build(){
        null=newnode();
        null->ch[0]=null->ch[1]=null->fa=null;
        ROOT=null;
    }

    // Rotate k above its parent (writes null->fa when the moved child is null).
    void rotate(node *k){
        node *r=k->fa;
        if(k==null||r==null) return;
        int x=k->pl()^1;
        r->ch[x^1]=k->ch[x];
        r->ch[x^1]->fa=r;
        if(r->fa!=null) r->fa->ch[r->pl()]=k;
        else ROOT=k;
        k->fa=r->fa;
        r->fa=k;
        k->ch[x]=r;
        r->count();
        k->count();
    }

    // Splay r until its parent is tar (tar==null: make r the root).
    // Never call with r==null: null->fa may be stale and the loop would not end.
    void splay(node *r,node *tar=null){
        for(;r->fa!=tar;rotate(r))
            if(r->fa->fa!=tar) rotate(r->pl()==r->fa->pl()?r->fa:r);
    }

    // Recount r and all of its ancestors (currently unused).
    void updata(node *r){
        while(r!=null){
            r->count();
            r=r->fa;
        }
    }

    // Insert one copy of x; the touched node is splayed to the root.
    void insert(int x){
        node *r=ROOT;
        if(ROOT==null){
            ROOT=newnode();
            ROOT->d=x;
            ROOT->size=1;
            ROOT->sum=1;
            return;
        }
        while(1){
            int c;
            if(x<r->d) c=0;
            else if(x>r->d) c=1;
            else {r->size++;r->sum++;splay(r);return;}
            if(r->ch[c]==null){
                r->ch[c]=newnode();
                r->ch[c]->fa=r;
                r->ch[c]->d=x;
                r->ch[c]->size=1;
                r->ch[c]->sum=1;
                splay(r->ch[c]);
                return;
            }else r=r->ch[c];
        }
    }

    // Node with the k-th smallest value (1-indexed, counting duplicates), or
    // null if out of range; the found node is splayed to the root.
    node *kth(int k){
        node *r=ROOT;
        while(r!=null){
            if(r->ch[0]->sum>=k) r=r->ch[0];
            else if(r->ch[0]->sum+r->size>=k){splay(r);return r;}
            else k-=r->ch[0]->sum+r->size,r=r->ch[1];
        }
        return null;
    }

    // Print the rank of k ((number of values < k) + 1) even if k is absent;
    // returns k's node or null and splays the last visited node.
    node *ques(int k){
        node *r=ROOT,*last=null;
        int ans=0;
        while(r!=null){
            last=r;
            if(k<r->d) r=r->ch[0];
            else if(k>r->d) ans+=r->ch[0]->sum+r->size,r=r->ch[1];
            else{
                printf("%d\n",ans+r->ch[0]->sum+1);
                splay(r);
                return r;
            }
        }
        printf("%d\n",ans+1);
        if(last!=null) splay(last);
        return null;
    }

    // Node holding value k, or null (after splaying the last visited node).
    node *ques2(int k){
        node *r=ROOT,*last=null;
        while(r!=null){
            last=r;
            if(k<r->d) r=r->ch[0];
            else if(k>r->d) r=r->ch[1];
            else return r;
        }
        if(last!=null) splay(last);
        return null;
    }

    node *rightdown(node *r){
        while(r->ch[1]!=null) r=r->ch[1];
        return r;
    }
    node *leftdown(node *r){
        while(r->ch[0]!=null) r=r->ch[0];
        return r;
    }

    // Remove one copy of r's value and recycle the node if it disappears.
    // A null argument (value not present) is a no-op; previously splay(null)
    // could loop forever on a stale null->fa.
    void deleter(node *r){
        if(r==null) return;
        splay(r);
        if(r->size>1){
            r->size--;
            r->sum--;
            return;
        }
        if((r->ch[0]==null)&&(r->ch[1]==null)){
            ROOT=null;
        }else if(r->ch[0]==null){
            r->ch[1]->fa=null;
            ROOT=r->ch[1];
        }else if(r->ch[1]==null){
            r->ch[0]->fa=null;
            ROOT=r->ch[0];
        }else{
            // the largest value of the left subtree becomes the root; it has
            // no right child, so the old right subtree is attached there
            splay(rightdown(r->ch[0]),ROOT);
            r->ch[0]->ch[1]=r->ch[1];
            r->ch[1]->fa=r->ch[0];
            r->ch[0]->fa=null;
            r->ch[0]->count();
            ROOT=r->ch[0];
        }
        recycle(r);
    }

    // Largest value < x in r's subtree, or -1E7-10 if none (iterative; the
    // last visited node is splayed).
    int predd(node *r,int x){
        int best=-10000010;
        node *last=null;
        while(r!=null){
            last=r;
            if(x<=r->d) r=r->ch[0];
            else best=r->d,r=r->ch[1];
        }
        if(last!=null) splay(last);
        return best;
    }

    // Smallest value > x in r's subtree, or 1E7+10 if none (iterative; the
    // last visited node is splayed).
    int pross(node *r,int x){
        int best=10000010;
        node *last=null;
        while(r!=null){
            last=r;
            if(r->d<=x) r=r->ch[1];
            else best=r->d,r=r->ch[0];
        }
        if(last!=null) splay(last);
        return best;
    }

    int predds(int x){return predd(ROOT,x);}
    int prosss(int x){return pross(ROOT,x);}
}

int main(){
    int n,x,num;
    n=getint();
    Splay::Build();
    while(n>0){
        n--;
        x=getint();
        node *r;
        switch(x){
            case 1: num=getint(); Splay::insert(num); break;
            case 2: num=getint(); r=Splay::ques2(num); Splay::deleter(r); break;
            case 3: num=getint(); r=Splay::ques(num); break;
            case 4: num=getint(); r=Splay::kth(num); printf("%d\n",r->d); break;
            case 5: num=getint(); printf("%d\n",Splay::predds(num)); break;
            case 6: num=getint(); printf("%d\n",Splay::prosss(num)); break;
        }
    }
    return 0;
}
差不多了