#include <bits/stdc++.h>
using namespace std;

// Interactive splay-tree demo (explicit zig/zag version).
// Builds a BST holding 0..n-1, prints it, then for every integer read from
// stdin splays that value to the root (if present) and prints the tree again.

// A binary-search-tree node.
struct node {
    int val;
    node *l, *r;
};

// Allocates a leaf node holding x.
node *newnode(int x) {
    node *ret = new node;
    ret->l = ret->r = nullptr;
    ret->val = x;
    return ret;
}

// zig: right rotation. Lifts o's left child into o's position.
// Precondition: o and o->l are non-null.
void zig(node *&o) {
    node *k = o->l;
    o->l = k->r;
    k->r = o;
    o = k;
}

// zag: left rotation. Lifts o's right child into o's position.
// Precondition: o and o->r are non-null.
void zag(node *&o) {
    node *k = o->r;
    o->r = k->l;
    k->l = o;
    o = k;
}

// Inserts x below rot using plain BST insertion (no rebalancing).
// Values equal to a node's value go to its right subtree.
// rot is passed by value, so an empty tree cannot be grown here; a null rot
// is ignored instead of being dereferenced.
void inser(int x, node *rot) {
    if (rot == nullptr) return;
    for (;;) {
        node *&next = (rot->val > x) ? rot->l : rot->r;
        if (next == nullptr) {
            next = newnode(x);
            return;
        }
        rot = next;
    }
}

// Recursive worker for fnd(): searches for x in the subtree rooted at rot and
// performs the double rotations on the way back up.
//
// Returns false (tree untouched, d unspecified) when x is absent.
// Returns true when x was found, and sets d to where x now sits:
//    0  -> x is at rot
//   -1  -> x is rot's left child  (one single rotation still pending)
//    1  -> x is rot's right child (one single rotation still pending)
bool fndx(node *&rot, int x, int &d) {
    if (rot == nullptr) return false;
    if (rot->val == x) {
        d = 0;
        return true;
    }

    int k = 0;
    if (rot->val > x) {
        // x can only be in the left subtree
        if (!fndx(rot->l, x, k)) return false;
        if (k == 0) {
            d = -1;
            return true;
        }
        if (k == -1) {
            // zig-zig: x is the left child of the left child
            zig(rot);
            zig(rot);
        } else {
            // zag-zig: x is the right child of the left child
            zag(rot->l);
            zig(rot);
        }
    } else {
        // x can only be in the right subtree
        if (!fndx(rot->r, x, k)) return false;
        if (k == 0) {
            d = 1;
            return true;
        }
        if (k == 1) {
            // zag-zag: x is the right child of the right child
            zag(rot);
            zag(rot);
        } else {
            // zig-zag: x is the left child of the right child
            zig(rot->r);
            zag(rot);
        }
    }
    d = 0;
    return true;
}

// Splays x to the root of the tree. Returns true if x was found (root now
// holds x) and false if x is absent (tree left unchanged).
bool fnd(node *&root, int x) {
    int k = 0;
    if (!fndx(root, x, k)) return false;
    // x may still be one level below the root: finish with a single rotation.
    if (k == 1) {
        zag(root);
    } else if (k == -1) {
        zig(root);
    }
    return true;
}

// Prints the subtree rooted at s in in-order (left subtree first), one value
// per line, indented by depth d. k tells how s hangs off its parent:
// -1 = left child ("/"), 1 = right child ("\"), 0 = root ("-").
void out(node *s, int d, int k) {
    if (s == nullptr) return;
    if (s->l) out(s->l, d + 1, -1);
    for (int i = 0; i < d; i++) printf(" ");
    if (k != 0) {
        printf(" ");
        if (k == 1) {
            printf("\\");
        } else {
            printf("/");
        }
    } else {
        printf("-");
    }
    printf("%d\n", s->val);
    if (s->r) out(s->r, d + 1, 1);
}

// Releases every node of the subtree rooted at s.
void freetree(node *s) {
    if (s == nullptr) return;
    freetree(s->l);
    freetree(s->r);
    delete s;
}

int n = 30;

int main() {
    node *root = newnode(0);
    for (int i = 1; i < n; i++) {
        inser(i, root);
    }
    out(root, 0, 0);

    int x;
    // Compare against 1 rather than testing ~scanf(): on non-numeric input
    // scanf returns 0 without consuming anything, and ~0 is non-zero, which
    // would spin forever.
    while (scanf("%d", &x) == 1) {
        printf("after find %d", x);
        puts("----------------");
        fnd(root, x);
        out(root, 0, 0);
    }

    freetree(root);
    return 0;
}
=============update==============
给出一些学习的建议吧
刚开始学的时候,建议在纸上模拟几遍左旋和右旋的过程,最好能做到一想到 zig 或 zag 就可以在脑海里模拟出这个过程。
然后看双旋的时候也是
双旋练到一想到 x 在哪个方向,就能在纸上模拟出对应的旋转过程,就可以了
然后就可以自己独立写一遍了(这个时候会感觉自己写得十分复杂……就像我上面那个代码一样 QAQ)
再接着可以压缩自己的代码(提示:根据双旋的对称性)
以及下面是我优化过的代码
#include <bits/stdc++.h>
using namespace std;

// Interactive splay-tree demo (compact version: both rotation directions are
// handled by one function indexed by child direction).
// Builds a BST holding 0..n-1 and prints it, then reads integers from stdin:
//   x >= n      : splay 0, 1, ..., n-1 in turn, printing after each
//   x <  0      : splay n-1, ..., 1, 0 in turn, printing after each
//   otherwise   : splay x and print the tree

// A binary-search-tree node. ch[0] is the left child, ch[1] the right child.
struct node {
    int val;
    node *ch[2];
};

// Allocates a leaf node holding x.
node *newnode(int x) {
    node *ret = new node;
    ret->ch[0] = ret->ch[1] = nullptr;
    ret->val = x;
    return ret;
}

// Inserts x below rot using plain BST insertion (no rebalancing).
// Values equal to a node's value go to its right subtree.
// rot is passed by value, so an empty tree cannot be grown here; a null rot
// is ignored instead of being dereferenced.
void inser(node *rot, int x) {
    if (rot == nullptr) return;
    for (;;) {
        int d = (rot->val > x) ? 0 : 1;
        if (rot->ch[d] == nullptr) {
            rot->ch[d] = newnode(x);
            return;
        }
        rot = rot->ch[d];
    }
}

// Prints the subtree rooted at s in in-order (left subtree first), one value
// per line, indented by depth d. k tells how s hangs off its parent:
// -1 = left child ("/"), 1 = right child ("\"), 0 = root ("-").
void out(node *s, int d, int k) {
    if (s == nullptr) return;
    if (s->ch[0]) out(s->ch[0], d + 1, -1);
    for (int i = 0; i < d; i++) printf(" ");
    if (k != 0) {
        printf(" ");
        if (k == 1) {
            printf("\\");
        } else {
            printf("/");
        }
    } else {
        printf("-");
    }
    printf("%d\n", s->val);
    if (s->ch[1]) out(s->ch[1], d + 1, 1);
}

// Prints the whole tree rooted at s between two separator lines.
void print(node *s) {
    puts("-----");
    out(s, 0, 0);
    puts("=====");
}

node *root = newnode(0);

// Single rotation that lifts o's child in direction d into o's position
// (d == 0 rotates right, d == 1 rotates left).
// Precondition: o and o->ch[d] are non-null.
void zg(node *&o, int d) {
    node *k = o->ch[d];
    o->ch[d] = k->ch[1 ^ d];
    k->ch[1 ^ d] = o;
    o = k;
}

// Recursive worker for fnd(): searches for x in the subtree rooted at rot and
// performs the double rotations on the way back up.
//
// Return codes:
//   -1 : x is not in this subtree (tree untouched)
//    2 : x is now at rot
//    0 : x is rot's left child  (one single rotation still pending)
//    1 : x is rot's right child (one single rotation still pending)
int splay(int x, node *&rot) {
    if (!rot) return -1;
    if (rot->val == x) return 2;

    int d = rot->val > x ? 0 : 1;
    int k = splay(x, rot->ch[d]);
    if (k == 2) return d;
    if (k == -1) return -1;

    if (k == d) {
        // same direction twice: rotate the grandparent first, then the parent
        zg(rot, k);
        zg(rot, k);
    } else {
        // opposite directions: rotate the parent first, then the grandparent
        zg(rot->ch[d], k);
        zg(rot, d);
    }
    return 2;
}

// Splays x to the root of the tree. Returns true if x was found (root now
// holds x) and false if x is absent (tree left unchanged).
bool fnd(node *&root, int x) {
    int k = splay(x, root);
    if (k == -1) return false;
    if (k != 2) {
        // x is still one level below the root: finish with a single rotation.
        zg(root, k);
    }
    return true;
}

// Releases every node of the subtree rooted at s.
void freetree(node *s) {
    if (s == nullptr) return;
    freetree(s->ch[0]);
    freetree(s->ch[1]);
    delete s;
}

int n = 16;

int main() {
    for (int i = 1; i < n; i++) {
        inser(root, i);
    }
    print(root);

    int x;
    // Compare against 1 rather than testing ~scanf(): on non-numeric input
    // scanf returns 0 without consuming anything, and ~0 is non-zero, which
    // would spin forever.
    while (scanf("%d", &x) == 1) {
        if (x >= n) {
            // sweep upwards through every stored value
            for (int i = 0; i < n; i++) {
                fnd(root, i);
                print(root);
            }
            continue;
        }
        if (x < 0) {
            // sweep downwards through every stored value
            for (int i = n - 1; i >= 0; i--) {
                fnd(root, i);
                print(root);
            }
            continue;
        }
        printf("after find %d", x);
        fnd(root, x);
        print(root);
    }

    freetree(root);
    root = nullptr;
    return 0;
}