今天真是有纪念意义啊……
以前试着捉了N遍的cashier今天竟然AC了,本沙茶终于掌握了平衡树!!!
【Splay Tree及其实现】
<1>结点记录的信息:
一般情况下Splay Tree是用线性存储器(结构数组)来存储的,可以避免在Linux下的指针异常问题。
这样对于某个结点,至少要记录以下的域:值(又叫关键字)、左子结点的下标、右子结点的下标、父结点下标、子树大小(就是以这个结点为根的子树中结点的总数)以及左右标志(为一个bool值,表示该结点是其父结点的左子结点还是右子结点),所要记录的其它域根据题目要求而定。另外还有一个域:重复次数mul,就是整棵树中与这个结点值相同的结点总数(关于这个域的作用将在下一篇里面总结)。
另外,为了防止越界,将T[0]预留出来作为哨兵结点。在树中,根结点的p值和叶结点的c值均为0。这个T[0]的sz值必须是0,其余的域无意义。
<2>旋转操作和伸展操作:
右旋(ZIG):如果某个非根结点X是其父结点Y的左子结点,则可以通过右旋操作将X旋转到Y的位置,即:先将Y的左子结点设为X的右子结点,再将X的右子结点设为Y;
左旋(ZAG):如果某个非根结点X是其父结点Y的右子结点,则可以通过左旋操作将X旋转到Y的位置,即:先将Y的右子结点设为X的左子结点,再将X的左子结点设为Y;
ZIG和ZAG操作可以合并,称为rot,代码如下:
其中的sc是set child的缩写,sc(_p, _c, _d)表示将T[_p]的_d子结点(0:左;1:右。下同)置为_c,代码如下(_c也可以是0,表示删除T[_p]对应的子结点):
以前试着捉了N遍的cashier今天竟然AC了,本沙茶终于掌握了平衡树!!!
【Splay Tree及其实现】
<1>结点记录的信息:
一般情况下Splay Tree是用线性存储器(结构数组)来存储的,可以避免在Linux下的指针异常问题。
这样对于某个结点,至少要记录以下的域:值(又叫关键字)、左子结点的下标、右子结点的下标、父结点下标、子树大小(就是以这个结点为根的子树中结点的总数)以及左右标志(为一个bool值,表示该结点是其父结点的左子结点还是右子结点),所要记录的其它域根据题目要求而定。另外还有一个域:重复次数mul,就是整棵树中与这个结点值相同的结点总数(关于这个域的作用将在下一篇里面总结)。
struct
node {
int v, c[ 2 ], p, sz, mul;
bool d;
} T[MAXN];
以上v为值、c[0]和c[1]表示左右子结点下标、p表示父结点下标、sz表示子树大小、mul表示重复次数、d表示左右标志。
int v, c[ 2 ], p, sz, mul;
bool d;
} T[MAXN];
另外,为了防止越界,将T[0]预留出来作为哨兵结点。在树中,根结点的p值和叶结点的c值均为0。这个T[0]的sz值必须是0,其余的域无意义。
<2>旋转操作和伸展操作:
右旋(ZIG):如果某个非根结点X是其父结点Y的左子结点,则可以通过右旋操作将X旋转到Y的位置,即:先将Y的左子结点设为X的右子结点,再将X的右子结点设为Y;
左旋(ZAG):如果某个非根结点X是其父结点Y的右子结点,则可以通过左旋操作将X旋转到Y的位置,即:先将Y的右子结点设为X的左子结点,再将X的左子结点设为Y;
ZIG和ZAG操作可以合并,称为rot,代码如下:
void
rot(
int
x)
{
int y = T[x].p, d = T[x].d;
if (y == root) {root = x; T[root].p = 0 ;} else sc(T[y].p, x, T[y].d);
sc(y, T[x].c[ ! d], d); sc(x, y, ! d); upd(y);
}
{
int y = T[x].p, d = T[x].d;
if (y == root) {root = x; T[root].p = 0 ;} else sc(T[y].p, x, T[y].d);
sc(y, T[x].c[ ! d], d); sc(x, y, ! d); upd(y);
}
void
sc(
int
_p,
int
_c,
bool
_d)
{
T[_p].c[_d] = _c; T[_c].p = _p; T[_c].d = _d;
}
其中的upd是update的缩写,upd(x)表示当x的子结点改变时,更新x的一些可维护域(这里只有sz值,有的题目比如NOI2005 sequence里面有其它的可维护域):
{
T[_p].c[_d] = _c; T[_c].p = _p; T[_c].d = _d;
}
void
upd(
int
x)
{
T[x].sz = T[T[x].c[ 0 ]].sz + T[T[x].c[ 1 ]].sz + T[x].mul;
}
{
T[x].sz = T[T[x].c[ 0 ]].sz + T[T[x].c[ 1 ]].sz + T[x].mul;
}
然后就是Splay Tree的核心操作——伸展操作(Splay):
Splay(x, r)表示将x伸展到r的子结点处,若r=0,则表示伸展到根(因为根的父结点为T[0])。过程如下:
(1)设x的父结点为p。若p的父结点即是r,则rot(x);
(2)若p的父结点不是r且T[x].d=T[p].d,则先rot(p)再rot(x);
(3)若p的父结点不是r且T[x].d!=T[p].d,则两次rot(x);
(4)重复以上过程直到x的父结点为r;
void
splay(
int
x,
int
r)
{
int p; while ((p = T[x].p) != r) if (T[p].p == r) rot(x); else if (T[x].d == T[p].d) {rot(p); rot(x);} else {rot(x); rot(x);} upd(x);
}
这里有一个问题:为什么在旋转操作中只更新y不更新x,而在伸展操作的最后则要更新x?这个在JZP神犇的论文中有解释:因为在旋转过程中,x的子结点一直在改变,故过早地跟新x没有意义。
{
int p; while ((p = T[x].p) != r) if (T[p].p == r) rot(x); else if (T[x].d == T[p].d) {rot(p); rot(x);} else {rot(x); rot(x);} upd(x);
}
<3>查找操作:
下面开始进入正式操作了。
首先是查找。find(x)表示在树中找值为x的结点,若找到返回其下标,若找不到返回0。这个应该是很容易实现的。
int
find(
int
x)
{
int i = x, v0; while (i) {v0 = T[i].v; if (v0 == x) break ; else i = T[i].c[v0 > x];} return i;
}
<4>插入操作:
{
int i = x, v0; while (i) {v0 = T[i].v; if (v0 == x) break ; else i = T[i].c[v0 > x];} return i;
}
ins(_v)表示在树中插入一个值为_v的结点。由于树是否为空的问题以及mul的引入,插入操作有三种可能结果:
(1)树为空(根结点为0):此时将插入一个新的结点,值为_v,初始sz、mul值均为1,并将其作为根结点;
(2)树非空且值为_v的结点在树中不存在:此时将插入一个新的结点,值为_v,初始sz、mul值均为1;
(3)树非空且值为_v的结点在树中已存在:此时会将树中这个值为_v的结点的mul值加1;
void
ins(
int
_v)
{
if ( ! root) {T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = T[n].p = 0 ; T[n].sz = T[n].mul = 1 ; root = n; return ;}
int i = root, j;
while ( 1 ) {
T[i].sz ++ ;
if (T[i].v == _v) {T[i].mul ++ ; splay(i, 0 ); return ;}
j = T[i].c[_v > T[i].v];
if ( ! j) break ; else i = j;
}
T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = 0 ; T[n].sz = T[n].mul = 1 ; sc(i, n, _v > T[i].v); splay(n, 0 );
}
<5>删除操作:
{
if ( ! root) {T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = T[n].p = 0 ; T[n].sz = T[n].mul = 1 ; root = n; return ;}
int i = root, j;
while ( 1 ) {
T[i].sz ++ ;
if (T[i].v == _v) {T[i].mul ++ ; splay(i, 0 ); return ;}
j = T[i].c[_v > T[i].v];
if ( ! j) break ; else i = j;
}
T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = 0 ; T[n].sz = T[n].mul = 1 ; sc(i, n, _v > T[i].v); splay(n, 0 );
}
del(x)表示将下标为x的结点删除。
除了一般的二叉查找树的删除方法外,Splay Tree还有一种删除方式:先找到x的前趋P和x的后继S(具体操作见<6>),并将P伸展到根,S伸展到P的右子结点处,这样S的左子树中只有一个结点,就是x,然后再将S的左子结点置为0即可。需要注意的是几种特殊情况:
(1)x无前趋或无后继:此时将x伸展到根后,x只有一棵子树,直接将根结点设为x的那个子结点即可;
(2)x既无前趋也无后继:此时x就是树中的唯一一个结点,将根结点设为0即可;
(3)T[x].mul>1,直接将T[x].mul值减1(与插入类似);
void
del(
int
x)
{
if (T[x].mul > 1 ) T[x].mul -- ; else {
splay(x, 0 );
int y = succ(), y2 = pred();
if ( ! y) {root = T[x].c[ 0 ]; T[root].p = 0 ;} else if ( ! y2) {root = T[x].c[ 1 ]; T[root].p = 0 ;} else {
splay(y2, 0 ); splay(y, root);
T[x].p = 0 ; T[y].c[T[x].d] = 0 ; upd(y); upd(root);
}
}
}
<6>找前趋和后继:
{
if (T[x].mul > 1 ) T[x].mul -- ; else {
splay(x, 0 );
int y = succ(), y2 = pred();
if ( ! y) {root = T[x].c[ 0 ]; T[root].p = 0 ;} else if ( ! y2) {root = T[x].c[ 1 ]; T[root].p = 0 ;} else {
splay(y2, 0 ); splay(y, root);
T[x].p = 0 ; T[y].c[T[x].d] = 0 ; upd(y); upd(root);
}
}
}
pred(x)和succ(x):分别求出x的前趋和后继(x的前趋表示树中值小于x的最大的结点;x的后继表示树中值大于x的最小的结点),并返回它们的下标,若不存在返回0。
先将x伸展到根,然后x的左子结点的右链上的最后一个结点就是x的前趋,x的右子结点的左链上的最后一个结点就是x的后继。
int
pred()
{
int i = T[root].c[ 0 ], j;
if ( ! i) return 0 ;
while (j = T[i].c[ 1 ]) i = j;
return i;
}
int succ()
{
int i = T[root].c[ 1 ], j;
if ( ! i) return 0 ;
while (j = T[i].c[ 0 ]) i = j;
return i;
}
注意,这里的pred和succ是求根结点的前趋和后继,因此在调用pred或succ前必须保证结点x已经被伸展到了根的位置。{
int i = T[root].c[ 0 ], j;
if ( ! i) return 0 ;
while (j = T[i].c[ 1 ]) i = j;
return i;
}
int succ()
{
int i = T[root].c[ 1 ], j;
if ( ! i) return 0 ;
while (j = T[i].c[ 0 ]) i = j;
return i;
}
前趋和后继还有另一种定义方式:x的前趋表示树中值 不大于 x的最大的结点,x的后继表示树中值 不小于x的最小的结点,此时只需在pred和succ函数的开头分别加入if (T[root].mul > 1) return root;即可。
<7>找第K小以及找指定结点是第几小:
找第K小以及找指定结点是第几小的操作是平衡树的特有操作。Splay Tree找第K小的操作与其它平衡树相同。
int
Find_Kth(
int
K)
{
int i = root, s0, m0;
while ( 1 ) {
s0 = T[T[i].c[ 0 ]].sz; m0 = T[i].mul;
if (K <= s0) i = T[i].c[ 0 ]; else if (K <= s0 + m0) return T[i].v; else {K -= s0 + m0; i = T[i].c[ 1 ];}
}
}
而Splay Tree找指定结点是第几小的操作则是特有的:将这个结点伸展到根,则其(左子树大小+1)即为结果。
{
int i = root, s0, m0;
while ( 1 ) {
s0 = T[T[i].c[ 0 ]].sz; m0 = T[i].mul;
if (K <= s0) i = T[i].c[ 0 ]; else if (K <= s0 + m0) return T[i].v; else {K -= s0 + m0; i = T[i].c[ 1 ];}
}
}
int
rank(
int
x)
{
splay(x, 0 ); return T[T[x].c[ 0 ]].sz + 1 ;
}
{
splay(x, 0 ); return T[T[x].c[ 0 ]].sz + 1 ;
}
【例题】 NOI2004 cashier:
本题中涉及到一个进阶操作:删除值在某一区间内的所有结点。这一操作会在下一篇里总结。
#include
<
iostream
>
#include < stdio.h >
using namespace std;
#define re(i, n) for (int i=0; i<n; i++)
const int MAXN = 100001 , INF = ~ 0U >> 2 ;
struct node {
int v, c[ 2 ], p, sz, mul;
bool d;
} T[MAXN];
int n = 0 , root, res, tot = 0 ;
void upd( int x)
{
T[x].sz = T[T[x].c[ 0 ]].sz + T[T[x].c[ 1 ]].sz + T[x].mul;
}
void sc( int _p, int _c, bool _d)
{
T[_p].c[_d] = _c; T[_c].p = _p; T[_c].d = _d;
}
void rot( int x)
{
int y = T[x].p, d = T[x].d;
if (y == root) {root = x; T[root].p = 0 ;} else sc(T[y].p, x, T[y].d);
sc(y, T[x].c[ ! d], d); sc(x, y, ! d); upd(y); upd(x);
}
void splay( int x, int r)
{
int i = x, p, p0;
while ((p0 = T[i].p) != r) {
p = T[p0].p;
if (p == r) rot(i); else if (T[i].d == T[p0].d) {rot(p0); rot(i);} else {rot(i); rot(i);}
}
}
void ins( int _v)
{
if ( ! root) {T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = T[n].p = 0 ; T[n].sz = T[n].mul = 1 ; root = n; return ;}
int i = root, j;
while ( 1 ) {
T[i].sz ++ ;
if (T[i].v == _v) {T[i].mul ++ ; splay(i, 0 ); return ;}
j = T[i].c[_v > T[i].v];
if ( ! j) break ; else i = j;
}
T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = 0 ; T[n].sz = T[n].mul = 1 ; sc(i, n, _v > T[i].v); splay(n, 0 );
}
void del( int lmt)
{
if ( ! root) return ;
int i = root, _min = INF, b = 0 , v0;
while (i) {
v0 = T[i].v;
if (v0 == lmt) {b = i; break ;}
if (v0 < lmt) i = T[i].c[ 1 ]; else { if (v0 < _min) {_min = v0; b = i;} i = T[i].c[ 0 ];}
}
if ( ! b) {tot += T[root].sz; root = 0 ;} else {splay(b, 0 ); tot += T[T[root].c[ 0 ]].sz; T[T[root].c[ 0 ]].p = 0 ; T[root].c[ 0 ] = 0 ; upd(root);}
}
int Find_Kth( int K)
{
int i = root, s0, m0;
while ( 1 ) {
s0 = T[T[i].c[ 0 ]].sz; m0 = T[i].mul;
if (K <= s0) i = T[i].c[ 0 ]; else if (K <= s0 + m0) return T[i].v; else {K -= s0 + m0; i = T[i].c[ 1 ];}
}
}
int main()
{
freopen( " cashier.in " , " r " , stdin);
freopen( " cashier.out " , " w " , stdout);
int m, minv, delt = 0 , x;
char ch;
scanf( " %d%d%*c " , & m, & minv);
re(i, m) {
scanf( " %c%d%*c " , & ch, & x);
switch (ch) {
case ' I ' : { if (x >= minv) ins(x - delt); break ;}
case ' A ' : {delt += x; break ;}
case ' S ' : {delt -= x; del(minv - delt); break ;}
case ' F ' : if (T[root].sz - x + 1 <= 0 ) printf( " %d\n " , - 1 ); else printf( " %d\n " , Find_Kth(T[root].sz - x + 1 ) + delt);
}
}
printf( " %d\n " , tot);
fclose(stdin); fclose(stdout);
return 0 ;
}
#include < stdio.h >
using namespace std;
#define re(i, n) for (int i=0; i<n; i++)
const int MAXN = 100001 , INF = ~ 0U >> 2 ;
struct node {
int v, c[ 2 ], p, sz, mul;
bool d;
} T[MAXN];
int n = 0 , root, res, tot = 0 ;
void upd( int x)
{
T[x].sz = T[T[x].c[ 0 ]].sz + T[T[x].c[ 1 ]].sz + T[x].mul;
}
void sc( int _p, int _c, bool _d)
{
T[_p].c[_d] = _c; T[_c].p = _p; T[_c].d = _d;
}
void rot( int x)
{
int y = T[x].p, d = T[x].d;
if (y == root) {root = x; T[root].p = 0 ;} else sc(T[y].p, x, T[y].d);
sc(y, T[x].c[ ! d], d); sc(x, y, ! d); upd(y); upd(x);
}
void splay( int x, int r)
{
int i = x, p, p0;
while ((p0 = T[i].p) != r) {
p = T[p0].p;
if (p == r) rot(i); else if (T[i].d == T[p0].d) {rot(p0); rot(i);} else {rot(i); rot(i);}
}
}
void ins( int _v)
{
if ( ! root) {T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = T[n].p = 0 ; T[n].sz = T[n].mul = 1 ; root = n; return ;}
int i = root, j;
while ( 1 ) {
T[i].sz ++ ;
if (T[i].v == _v) {T[i].mul ++ ; splay(i, 0 ); return ;}
j = T[i].c[_v > T[i].v];
if ( ! j) break ; else i = j;
}
T[ ++ n].v = _v; T[n].c[ 0 ] = T[n].c[ 1 ] = 0 ; T[n].sz = T[n].mul = 1 ; sc(i, n, _v > T[i].v); splay(n, 0 );
}
void del( int lmt)
{
if ( ! root) return ;
int i = root, _min = INF, b = 0 , v0;
while (i) {
v0 = T[i].v;
if (v0 == lmt) {b = i; break ;}
if (v0 < lmt) i = T[i].c[ 1 ]; else { if (v0 < _min) {_min = v0; b = i;} i = T[i].c[ 0 ];}
}
if ( ! b) {tot += T[root].sz; root = 0 ;} else {splay(b, 0 ); tot += T[T[root].c[ 0 ]].sz; T[T[root].c[ 0 ]].p = 0 ; T[root].c[ 0 ] = 0 ; upd(root);}
}
int Find_Kth( int K)
{
int i = root, s0, m0;
while ( 1 ) {
s0 = T[T[i].c[ 0 ]].sz; m0 = T[i].mul;
if (K <= s0) i = T[i].c[ 0 ]; else if (K <= s0 + m0) return T[i].v; else {K -= s0 + m0; i = T[i].c[ 1 ];}
}
}
int main()
{
freopen( " cashier.in " , " r " , stdin);
freopen( " cashier.out " , " w " , stdout);
int m, minv, delt = 0 , x;
char ch;
scanf( " %d%d%*c " , & m, & minv);
re(i, m) {
scanf( " %c%d%*c " , & ch, & x);
switch (ch) {
case ' I ' : { if (x >= minv) ins(x - delt); break ;}
case ' A ' : {delt += x; break ;}
case ' S ' : {delt -= x; del(minv - delt); break ;}
case ' F ' : if (T[root].sz - x + 1 <= 0 ) printf( " %d\n " , - 1 ); else printf( " %d\n " , Find_Kth(T[root].sz - x + 1 ) + delt);
}
}
printf( " %d\n " , tot);
fclose(stdin); fclose(stdout);
return 0 ;
}
【相关论文】
(1) The Magical Splay
(2) 运用伸展树解决数列维护问题
【感谢】
CLJ 神犇
GYZ 神犇
Jollwish 神犇
以及网上提供cashier标程的
etc.