splay 伸展树 代码实现

Splay 概念文章: http://blog.csdn.net/naivebaby/article/details/1357734

叉姐 数组实现: https://github.com/ftiasch/mithril/blob/master/2012-10-24/I.cpp#L43

Vani 指针实现: https://github.com/Azure-Vani/acm-icpc/blob/master/spoj/SEQ2.cpp

hdu 1890 写法: http://blog.csdn.net/fp_hzq/article/details/8087431

HH splay写法: http://www.notonlysuccess.com/index.php/splay-tree/

 

POJ 3468（区间加、区间求和）：HH 写法

View Code
  1 /*

  2 http://acm.pku.edu.cn/JudgeOnline/problem?id=3468

  3 区间跟新,区间求和

  4 */

  5 #include <cstdio>

  6 #define keyTree (ch[ ch[root][1] ][0])

  7 const int maxn = 222222;

  8 struct SplayTree{

  9     int sz[maxn];

 10     int ch[maxn][2];

 11     int pre[maxn];

 12     int root , top1 , top2;

 13     int ss[maxn] , que[maxn];

 14  

 15     inline void Rotate(int x,int f) {

 16         int y = pre[x];

 17         push_down(y);

 18         push_down(x);

 19         ch[y][!f] = ch[x][f];

 20         pre[ ch[x][f] ] = y;

 21         pre[x] = pre[y];

 22         if(pre[x]) ch[ pre[y] ][ ch[pre[y]][1] == y ] = x;

 23         ch[x][f] = y;

 24         pre[y] = x;

 25         push_up(y);

 26     }

 27     inline void Splay(int x,int goal) {

 28         push_down(x);

 29         while(pre[x] != goal) {

 30             if(pre[pre[x]] == goal) {

 31                 Rotate(x , ch[pre[x]][0] == x);

 32             } else {

 33                 int y = pre[x] , z = pre[y];

 34                 int f = (ch[z][0] == y);

 35                 if(ch[y][f] == x) {

 36                     Rotate(x , !f) , Rotate(x , f);

 37                 } else {

 38                     Rotate(y , f) , Rotate(x , f);

 39                 }

 40             }

 41         }

 42         push_up(x);

 43         if(goal == 0) root = x;

 44     }

 45     inline void RotateTo(int k,int goal) {//把第k位的数转到goal下边

 46         int x = root;

 47         push_down(x);

 48         while(sz[ ch[x][0] ] != k) {

 49             if(k < sz[ ch[x][0] ]) {

 50                 x = ch[x][0];

 51             } else {

 52                 k -= (sz[ ch[x][0] ] + 1);

 53                 x = ch[x][1];

 54             }

 55             push_down(x);

 56         }

 57         Splay(x,goal);

 58     }

 59     inline void erase(int x) {//把以x为祖先结点删掉放进内存池,回收内存

 60         int father = pre[x];

 61         int head = 0 , tail = 0;

 62         for (que[tail++] = x ; head < tail ; head ++) {

 63             ss[top2 ++] = que[head];

 64             if(ch[ que[head] ][0]) que[tail++] = ch[ que[head] ][0];

 65             if(ch[ que[head] ][1]) que[tail++] = ch[ que[head] ][1];

 66         }

 67         ch[ father ][ ch[father][1] == x ] = 0;

 68         pushup(father);

 69     }

 70     //以上一般不修改//////////////////////////////////////////////////////////////////////////////

 71     void debug() {printf("%d\n",root);Treaval(root);}

 72     void Treaval(int x) {

 73         if(x) {

 74             Treaval(ch[x][0]);

 75             printf("结点%2d:左儿子 %2d 右儿子 %2d 父结点 %2d size = %2d ,val = %2d\n",x,ch[x][0],ch[x][1],pre[x],sz[x],val[x]);

 76             Treaval(ch[x][1]);

 77         }

 78     }

 79     //以上Debug

 80  

 81  

 82     //以下是题目的特定函数:

 83     inline void NewNode(int &x,int c) {

 84         if (top2) x = ss[--top2];//用栈手动压的内存池

 85         else x = ++top1;

 86         ch[x][0] = ch[x][1] = pre[x] = 0;

 87         sz[x] = 1;

 88  

 89         val[x] = sum[x] = c;/*这是题目特定函数*/

 90         add[x] = 0;

 91     }

 92  

 93     //把延迟标记推到孩子

 94     inline void push_down(int x) {/*这是题目特定函数*/

 95         if(add[x]) {

 96             val[x] += add[x];

 97             add[ ch[x][0] ] += add[x];

 98             add[ ch[x][1] ] += add[x];

 99             sum[ ch[x][0] ] += (long long)sz[ ch[x][0] ] * add[x];

100             sum[ ch[x][1] ] += (long long)sz[ ch[x][1] ] * add[x];

101             add[x] = 0;

102         }

103     }

104     //把孩子状态更新上来

105     inline void push_up(int x) {

106         sz[x] = 1 + sz[ ch[x][0] ] + sz[ ch[x][1] ];

107         /*这是题目特定函数*/

108         sum[x] = add[x] + val[x] + sum[ ch[x][0] ] + sum[ ch[x][1] ];

109     }

110  

111     /*初始化*/

112     inline void makeTree(int &x,int l,int r,int f) {

113         if(l > r) return ;

114         int m = (l + r)>>1;

115         NewNode(x , num[m]);        /*num[m]权值改成题目所需的*/

116         makeTree(ch[x][0] , l , m - 1 , x);

117         makeTree(ch[x][1] , m + 1 , r , x);

118         pre[x] = f;

119         push_up(x);

120     }

121     inline void init(int n) {/*这是题目特定函数*/

122         ch[0][0] = ch[0][1] = pre[0] = sz[0] = 0;

123         add[0] = sum[0] = 0;

124  

125         root = top1 = 0;

126         //为了方便处理边界,加两个边界顶点

127         NewNode(root , -1);

128         NewNode(ch[root][1] , -1);

129         pre[top1] = root;

130         sz[root] = 2;

131  

132  

133         for (int i = 0 ; i < n ; i ++) scanf("%d",&num[i]);

134         makeTree(keyTree , 0 , n-1 , ch[root][1]);

135         push_up(ch[root][1]);

136         push_up(root);

137     }

138     /*更新*/

139     inline void update( ) {/*这是题目特定函数*/

140         int l , r , c;

141         scanf("%d%d%d",&l,&r,&c);

142         RotateTo(l-1,0);

143         RotateTo(r+1,root);

144         add[ keyTree ] += c;

145         sum[ keyTree ] += (long long)c * sz[ keyTree ];

146     }

147     /*询问*/

148     inline void query() {/*这是题目特定函数*/

149         int l , r;

150         scanf("%d%d",&l,&r);

151         RotateTo(l-1 , 0);

152         RotateTo(r+1 , root);

153         printf("%lld\n",sum[keyTree]);

154     }

155  

156  

157     /*这是题目特定变量*/

158     int num[maxn];

159     int val[maxn];

160     int add[maxn];

161     long long sum[maxn];

162 }spt;

163  

164  

165 int main() {

166     int n , m;

167     scanf("%d%d",&n,&m);

168     spt.init(n);

169     while(m --) {

170         char op[2];

171         scanf("%s",op);

172         if(op[0] == 'Q') {

173             spt.query();

174         } else {

175             spt.update();

176         }

177     }

178     return 0;

179 }

 

叉姐 写法（数组实现）

View Code
  1 #include <cstdio>

  2 #include <cstring>

  3 #include <vector>

  4 #include <climits>

  5 #include <algorithm>

  6 using namespace std;

  7 

  8 const int N = 200000;

  9 const int M = 1 + (N << 1);

 10 const int EMPTY = M - 1;

 11 

 12 const int MOD = 99990001;

 13 

 14 int nodeCount, type[M], parent[M], children[M][2], id[M];

 15 

 16 int scale[M], delta[M], weight[M], size[M], minimum[M];

 17 

 18 void update(int x) {

 19     size[x] = size[children[x][0]] + 1 + size[children[x][1]];

 20     minimum[x] = min(min(minimum[children[x][0]], minimum[children[x][1]]), id[x]);

 21 }

 22 

 23 void modify(int x, int k, int b) {

 24     weight[x] = ((long long)k * weight[x] + b) % MOD;

 25     scale[x] = (long long)k * scale[x] % MOD;

 26     delta[x] = ((long long)k * delta[x] + b) % MOD;

 27 }

 28 

 29 void pushDown(int x) {

 30     for (int i = 0; i < 2; ++ i) {

 31         if (children[x][i] != EMPTY) {

 32             modify(children[x][i], scale[x], delta[x]);

 33         }

 34     }

 35     scale[x] = 1;

 36     delta[x] = 0;

 37 }

 38 

 39 void rotate(int x) {

 40     int t = type[x];

 41     int y = parent[x];

 42     int z = children[x][1 ^ t];

 43     type[x] = type[y];

 44     parent[x] = parent[y];

 45     if (type[x] != 2) {

 46         children[parent[x]][type[x]] = x;

 47     }

 48     type[y] = 1 ^ t;

 49     parent[y] = x;

 50     children[x][1 ^ t] = y;

 51     if (z != EMPTY) {

 52         type[z] = t;

 53         parent[z] = y;

 54     }

 55     children[y][t] = z;

 56     update(y);

 57 }

 58 

 59 void splay(int x) {

 60     if (x == EMPTY) {

 61         return;

 62     }

 63     vector <int> stack(1, x);

 64     for (int i = x; type[i] != 2; i = parent[i]) {

 65         stack.push_back(parent[i]);

 66     }

 67     while (!stack.empty()) {

 68         pushDown(stack.back());

 69         stack.pop_back();

 70     }

 71     while (type[x] != 2) {

 72         int y = parent[x];

 73         if (type[x] == type[y]) {

 74             rotate(y);

 75         } else {

 76             rotate(x);

 77         }

 78         if (type[x] == 2) {

 79             break;

 80         }

 81         rotate(x);

 82     }

 83     update(x);

 84 }

 85 

 86 int goLeft(int x) {

 87     while (children[x][0] != EMPTY) {

 88         x = children[x][0];

 89     }

 90     return x;

 91 }

 92 

 93 int join(int x, int y) {

 94     if (x == EMPTY || y == EMPTY) {

 95         return x != EMPTY ? x : y;

 96     }

 97     y = goLeft(y);

 98     splay(y);

 99     splay(x);

100     type[x] = 0;

101     parent[x] = y;

102     children[y][0] = x;

103     update(y);

104     return y;

105 }

106 

107 pair <int, int> split(int x) {

108     splay(x);

109     int a = children[x][0];

110     int b = children[x][1];

111     children[x][0] = children[x][1] = EMPTY;

112     if (a != EMPTY) {

113         type[a] = 2;

114         parent[a] = EMPTY;

115     }

116     if (b != EMPTY) {

117         type[b] = 2;

118         parent[b] = EMPTY;

119     }

120     return make_pair(a, b);

121 }

122 

123 int newNode(int init, int vid) {

124     int x = nodeCount ++;

125     type[x] = 2;

126     parent[x] = children[x][0] = children[x][1] = EMPTY;

127     id[x] = vid;

128     weight[x] = init;

129     scale[x] = 1;

130     delta[x] = 0;

131     update(x);

132     return x;

133 }

134 

135 int n;

136 int edgeCount, firstEdge[N], to[M], nextEdge[M], initWeight[N], position[M];

137 

138 int root;

139 

140 void addEdge(int u, int v) {

141     to[edgeCount] = v;

142     nextEdge[edgeCount] = firstEdge[u];

143     firstEdge[u] = edgeCount ++;

144 }

145 

146 void dfs(int p, int u) {

147     for (int iter = firstEdge[u]; iter != -1; iter = nextEdge[iter]) {

148         int v = to[iter];

149         if (v != p) {

150             position[iter] = nodeCount;

151             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));

152             dfs(u, v);

153             position[iter ^ 1] = nodeCount;

154             root = join(root, newNode(initWeight[iter >> 1], min(u, v)));

155         }

156     }

157 }

158 

159 int getRank(int x) { // 1-based

160     splay(x);

161     return size[children[x][0]] + 1;

162 }

163 

164 void print(int root) {

165     if (root != EMPTY) {

166         printf("[ ");

167         print(children[root][0]);

168         printf(" %d ", root);

169         print(children[root][1]);

170         printf(" ]");

171     }

172 }

173 

174 int main() {

175     size[EMPTY] = 0;

176     minimum[EMPTY] = INT_MAX;

177     parent[EMPTY] = 2;

178     scanf("%d", &n);

179     edgeCount = 0;

180     memset(firstEdge, -1, sizeof(firstEdge));

181     for (int i = 0; i < n - 1; ++ i) {

182         int a, b;

183         scanf("%d%d%d", &a, &b, initWeight + i);

184         a --;

185         b --;

186         addEdge(a, b);

187         addEdge(b, a);

188     }

189     nodeCount = 0;

190     root = EMPTY;

191     dfs(-1, 0);

192     for (int i = 0; i < n - 1; ++ i) {

193         int id;

194         scanf("%d", &id);

195         id --;

196 

197         int a = position[id << 1];

198         int b = position[(id << 1) ^ 1];

199         if (getRank(a) > getRank(b)) {

200             swap(a, b);

201         }

202         splay(a);

203 

204         int output = weight[a];

205         printf("%d\n", output);

206         fflush(stdout);

207 

208         pair <int, int> ret1 = split(a);

209         pair <int, int> ret2 = split(b);

210         int x = ret1.first;

211         int y = ret2.first;

212         int z = ret2.second;

213         x = join(z, x);

214         splay(x);

215         splay(y);

216         if (size[x] > size[y]) {

217             swap(x, y);

218         }

219         if (size[x] == size[y] && minimum[x] > minimum[y]) {

220             swap(x, y);

221         }

222         modify(x, output, 0);

223         modify(y, 1, output);

224     }

225     return 0;

226 }

 

spoj SEQ2

Vani 写法（指针实现）

View Code
  1 #include <cstdio>

  2 #include <cctype>

  3 #include <algorithm>

  4 #include <cstring>

  5 

  6 using namespace std;

  7 

  8 namespace Solve {

  9     const int MAXN = 500010;

 10     const int inf = 500000000;

 11 

 12     char BUF[50000000], *pos = BUF;

 13     inline int ScanInt(void) {

 14         int r = 0, d = 0;

 15         while (!isdigit(*pos) && *pos != '-') pos++;

 16         if (*pos != '-') r = *pos - 48; else d = 1; pos++;

 17         while ( isdigit(*pos)) r = r * 10 + *pos++ - 48;

 18         return d ? -r : r;

 19     }

 20     inline void ScanStr(char *st) {

 21         int l = 0;

 22         while (!(isupper(*pos) || *pos == '-')) pos++;

 23         st[l++] = *pos++;

 24         while (isupper(*pos) || *pos == '-') st[l++] = *pos++; st[l] = 0;

 25     }

 26 

 27     struct Node {

 28         Node *ch[2], *p;

 29         int v, lmax, rmax, m, same, rev, sum, size;

 30         inline bool dir(void) {return this == p->ch[1];}

 31         inline void SetC(Node *x, bool d) {ch[d] = x, x->p = this;}

 32         inline void Update(void) {

 33             Node *L = ch[0], *R = ch[1];

 34             size = L->size + R->size + 1;

 35             m = max(L->m, R->m);

 36             m = max(m, L->rmax + v + R->lmax);

 37             lmax = max(L->lmax, L->sum + v + R->lmax);

 38             rmax = max(R->rmax, R->sum + v + L->rmax);

 39             sum = L->sum + R->sum + v;

 40         }

 41         inline void Rev(void) {

 42             if (v == -inf) return;

 43             rev ^= 1;

 44             swap(ch[0], ch[1]);

 45             swap(lmax, rmax);

 46         }

 47         inline void Same(int u) {

 48             if (v == -inf) return;

 49             same = u;

 50             sum = u * size;

 51             if (sum > 0) lmax = rmax = m = sum; else lmax = 0, rmax = 0, m = u;

 52             v = u;

 53         }

 54         inline void Down(void) {

 55             if (rev) {

 56                 ch[0]->Rev(), ch[1]->Rev();

 57                 rev = 0;

 58             }

 59             if (same != -inf) {

 60                 ch[0]->Same(same), ch[1]->Same(same);

 61                 same = -inf;

 62             }

 63         }

 64     } Tnull, *null = &Tnull;

 65 

 66     class Splay {public:

 67         Node *root;

 68         inline void rotate(Node *x) {

 69             Node *p = x->p; bool d = x->dir();

 70             p->Down(); x->Down();

 71             p->p->SetC(x, p->dir());

 72             p->SetC(x->ch[!d], d);

 73             x->SetC(p, !d);

 74             p->Update();

 75         }

 76         inline void splay(Node *x, Node *G) {

 77             if (G == null) root = x;

 78             while (x->p != G) {

 79                 if (x->p->p == G) {rotate(x); break;}

 80                 else {if (x->dir() == x->p->dir()) rotate(x->p), rotate(x); else rotate(x), rotate(x);}

 81             }

 82             x->Update();

 83         }

 84         inline Node *Select(int k) {

 85             Node *t = root;

 86             while (t->Down(), t->ch[0]->size + 1 != k) {

 87                 if (k > t->ch[0]->size + 1) k -= t->ch[0]->size + 1, t = t->ch[1];

 88                 else t = t->ch[0];

 89             }

 90             splay(t, null);

 91             return t;

 92         }

 93         inline Node *getInterval(int l, int r) {

 94             Node *L = Select(l), *R = Select(r + 2);

 95             splay(L, null); splay(R, L);

 96             L->Down(); R->Down();

 97             return R;

 98         }

 99         inline void Insert(int pos, Node *x) {

100             Node *now = getInterval(pos + 1, pos);

101             now->SetC(x, 0);

102             now->Update(); root->Update();

103         }

104         inline void Delete(int l, int r) {

105             Node *now = getInterval(l, r);

106             now->ch[0] = null;

107             now->Update(); root->Update();

108         }

109         inline void Make(int l, int r, int c) {

110             Node *now = getInterval(l, r);

111             now->ch[0]->Same(c);

112             now->Update(); root->Update();

113         }

114         inline void Reverse(int l, int r) {

115             Node *now = getInterval(l, r);

116             now->ch[0]->Rev();

117             now->Update(); root->Update();

118         }

119         inline int Sum(int l, int r) {

120             Node *now = getInterval(l, r);

121             root->Down(); now->Down();

122             return now->ch[0]->sum;

123         }

124         inline int maxSum(int l, int r) {

125             Node *now = getInterval(l, r);

126             root->Down(); now->Down();

127             return now->ch[0]->m;

128         }

129         inline Node* Renew(int c) {

130             Node *ret = new Node;

131             ret->ch[0] = ret->ch[1] = ret->p = null; ret->size = 1;

132             ret->Same(c); ret->same = -inf;

133             return ret;

134         }

135         inline Node* Build(int l, int r, int *a) {

136             if (l > r) return null;

137             int mid = (l + r) >> 1;

138             Node *ret = Renew(a[mid]);

139             ret->ch[0] = Build(l, mid - 1, a);

140             ret->ch[1] = Build(mid + 1, r, a);

141             ret->ch[0]->p = ret->ch[1]->p = ret;

142             ret->Update();

143             return ret;

144         }

145         inline void P(Node *t) {

146             if (t == null) return;

147             t->Down(); t->Update();

148             P(t->ch[0]);

149             printf("%d ", t->v);

150             P(t->ch[1]);

151         }

152     }T;

153 

154 

155     int a[MAXN]; char ch[10];

156 

157     inline void solve(void) {

158         fread(BUF, 1, 50000000, stdin);

159         null->same = null->m = null->v = -inf;

160         int kase = ScanInt();

161         while (kase--) {

162             int n = ScanInt(), m = ScanInt();

163             for (int i = 1; i <= n; i++) a[i] = ScanInt();

164             T.root = T.Build(0, n + 1, a);

165             for (int i = 1; i <= m; i++) {

166                 ScanStr(ch);

167                 if (strcmp(ch, "INSERT") == 0) {

168                     int pos = ScanInt(), t = ScanInt();

169                     for (int j = 1; j <= t; j++) a[j] = ScanInt();

170                     Node *tmp = T.Build(1, t, a);

171                     T.Insert(pos, tmp);

172                 }

173                 if (strcmp(ch, "DELETE") == 0) {

174                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;

175                     T.Delete(l, r);

176                 }

177                 if (strcmp(ch, "MAKE-SAME") == 0) {

178                     int l = ScanInt(), r = ScanInt(), c = ScanInt(); r = l + r - 1;

179                     T.Make(l, r, c);

180                 }

181                 if (strcmp(ch, "REVERSE") == 0) {

182                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;

183                     T.Reverse(l, r);

184                 }

185                 if (strcmp(ch, "GET-SUM") == 0) {

186                     int l = ScanInt(), r = ScanInt(); r = l + r - 1;

187                     int ret = T.Sum(l, r);

188                     printf("%d\n", ret);

189                 }

190                 if (strcmp(ch, "MAX-SUM") == 0) {

191                     int ret = T.maxSum(1, T.root->size - 2);

192                     printf("%d\n", ret);

193                 }

194             }

195         }

196     }

197 }

198 

199 int main(void) {

200     freopen("in", "r", stdin);

201     Solve::solve();

202     return 0;

203 }

 

 

你可能感兴趣的:(play)