noi2005维护数列 splay

用splay来维护数列的一道很好的模版练习题

有些要注意的地方:

  • 不用指针来实现splay的话要注意内存回收,不然就会无限RE+TLE
  • 凡是有对序列进行修改的操作都要up一下
  • reverse的时候要注意交换左起最大连续和跟右起最大连续和
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<string>
#include<iomanip>
#include<vector>
#include<set>
#include<map>
#include<queue>

using namespace std;
typedef long long LL;
typedef unsigned long long ULL;

#define rep(i,k,n) for(int i=(k);i<=(n);i++)
#define rep0(i,n) for(int i=0;i<(n);i++)
#define red(i,k,n) for(int i=(k);i>=(n);i--)
#define sqr(x) ((x)*(x))
#define clr(x,y) memset((x),(y),sizeof(x))
#define pb push_back
#define mod 1000000007
#define L ch[x][0]
#define R ch[x][1]
#define KT (ch[ch[rt][1]][0])
const int inf=-2000;
const int maxn=510000;
int numa[maxn],numb[maxn];

struct SplayTree
{
    int sz[maxn];
    int ch[maxn][2];
    int pre[maxn];
    int rt,top;
    int pool[maxn],ptop;

    int flip[maxn];
    int same[maxn];
    int val[maxn];
    int sum[maxn];
    int maxs[maxn],maxl[maxn],maxr[maxn];

    inline void down(int x)
    {
        if(flip[x])
        {
            flip[L]^=1;
            flip[R]^=1;
            swap(L,R);
            swap(maxl[L],maxr[L]);
            swap(maxl[R],maxr[R]);
            flip[x]=0;
        }
        if(same[x]!=inf)
        {
            rep0(i,2)if(ch[x][i])
            {
                same[ch[x][i]]=same[x];
                val[ch[x][i]]=same[x];
                sum[ch[x][i]]=same[x]*sz[ch[x][i]];
                if(same[x]<=0)maxs[ch[x][i]]=maxl[ch[x][i]]=maxr[ch[x][i]]=same[x];
                else maxs[ch[x][i]]=maxl[ch[x][i]]=maxr[ch[x][i]]=sum[ch[x][i]];
            }
            same[x]=inf;
        }
    }

    inline void up(int x)
    {
        sz[x]=1+sz[L]+sz[R];
        sum[x]=sum[L]+val[x]+sum[R];
        maxl[x]=max(max(maxl[L],sum[L]+val[x]),sum[L]+val[x]+maxl[R]);
        maxr[x]=max(max(maxr[R],sum[R]+val[x]),sum[R]+val[x]+maxr[L]);
        maxs[x]=max(max(maxs[L],maxs[R]),max(maxr[L],maxl[R])+val[x]);
        maxs[x]=max(max(maxs[x],val[x]),maxr[L]+val[x]+maxl[R]);
    }

    inline void Rotate(int x,int f)
    {
        int y=pre[x];
        down(y);
        down(x);
        ch[y][!f]=ch[x][f];
        pre[ch[x][f]]=y;
        pre[x]=pre[y];
        if(pre[x])ch[pre[y]][ch[pre[y]][1]==y]=x;
        ch[x][f]=y;
        pre[y]=x;
        up(y);
    }
    void Splay(int x,int goal)
    {
        down(x);
        while(pre[x]!=goal)
        {
            down(pre[pre[x]]);down(pre[x]);down(x);
            if(pre[pre[x]]==goal)Rotate(x,ch[pre[x]][0]==x);
            else
            {
                int y=pre[x],z=pre[y];
                int f=(ch[z][0]==y);
                if(ch[y][f]==x)Rotate(x,!f),Rotate(x,f);
                else Rotate(y,f),Rotate(x,f);
            }
        }
        up(x);
        if(goal==0)rt=x;
    }
    inline void RTO(int k,int goal)
    {
        int x=rt;
        down(x);
        while(sz[L]+1!=k)
        {
            if(k<sz[L]+1)x=L;
            else
            {
                k-=(sz[L]+1);
                x=R;
            }
            down(x);
        }
        Splay(x,goal);
    }
    void Newnode(int &x,int c,int f)
    {
        if(ptop)
        {
            x=pool[ptop--];
        }
        else
        {
            x=++top;
        }
        flip[x]=0;same[x]=inf;
        L=R=0;pre[x]=f;
        sz[x]=1;
        val[x]=sum[x]=maxs[x]=maxl[x]=maxr[x]=c;
    }
    void build(int &x,int l,int r,int f,int num[])
    {
        if(l>r)return ;
        int m=l+r>>1;
        Newnode(x,num[m],f);
        build(L,l,m-1,x,num);
        build(R,m+1,r,x,num);
        pre[x]=f;
        up(x);
    }
    void init(int n)
    {
        ch[0][0]=ch[0][1]=pre[0]=sz[0]=0;
        rt=top=0;flip[0]=val[0]=0;
        ptop=0;
        same[0]=inf;
        sum[0]=0;
        maxs[0]=maxl[0]=maxr[0]=inf;
        Newnode(rt,inf,0);
        Newnode(ch[rt][1],inf,rt);
        sz[rt]=2;
        build(KT,1,n,ch[rt][1],numa);
        up(ch[rt][1]);up(rt);
    }
    void SAME(int pos,int len,int c)
    {
        if(len<=0)return;
        int a=pos,b=pos+len-1;
        RTO(a,0);
        RTO(b+2,rt);
        same[KT]=val[KT]=c;
        sum[KT]=c*sz[KT];
        if(c<=0)maxs[KT]=maxl[KT]=maxr[KT]=c;
        else maxs[KT]=maxl[KT]=maxr[KT]=sum[KT];
        up(ch[rt][1]);up(rt);
    }
    void REVERSE(int pos,int len)
    {
        if(len<=0)return;
        int a=pos,b=pos+len-1;
        RTO(a,0);
        RTO(b+2,rt);
        flip[KT]^=1;
        swap(maxl[KT],maxr[KT]);
        up(ch[rt][1]);up(rt);
    }
    void INSERT(int pos,int len,int c[])
    {
        if(len<=0)return;
        RTO(pos+1,0);
        RTO(pos+2,rt);
        build(KT,1,len,ch[rt][1],c);
        up(ch[rt][1]);up(rt);
    }
    void collect(int x)
    {
        if(x==0)return;
        pool[++ptop]=x;
        collect(L);collect(R);
    }
    void DELETE(int pos,int len)
    {
        if(len<=0)return;
        int a=pos,b=pos+len-1;
        RTO(a,0);
        RTO(b+2,rt);
        collect(KT);
        KT=0;
        up(ch[rt][1]);up(rt);
    }
    int GETSUM(int pos,int len)
    {
        if(len<=0)return 0;
        int a=pos,b=pos+len-1;
        RTO(a,0);
        RTO(b+2,rt);
        return sum[KT];
    }
    int MAXSUM()
    {
        if(sz[rt]==2)return 0;
        return maxs[rt];
    }
//    void debug(int x)
//    {
//        down(x);
//        if(L)debug(L);
//        printf("%d ",val[x]);
//        if(R)debug(R);
//    }
}spt;


int main()
{
    int n,m,tmp,pos,len;
    char str[20];
    scanf("%d%d",&n,&m);
    rep(i,1,n)scanf("%d",&numa[i]);
    spt.init(n);

    rep(i,1,m)
    {
        //spt.debug(spt.rt);
        scanf("%s",str);
        if(str[0]=='I')
        {
            scanf("%d%d",&pos,&len);
            rep(j,1,len)scanf("%d",&numb[j]);
            spt.INSERT(pos,len,numb);
        }
        else if(str[0]=='D')
        {
            scanf("%d%d",&pos,&len);
            spt.DELETE(pos,len);
        }
        else if(str[0]=='R')
        {
            scanf("%d%d",&pos,&len);
            spt.REVERSE(pos,len);
        }
        else if(str[0]=='G')
        {
            scanf("%d%d",&pos,&len);
            printf("%d\n",spt.GETSUM(pos,len));
        }
        else
        {
            if(str[2]=='X')
            {
                printf("%d\n",spt.MAXSUM());
            }
            else if(str[2]=='K')
            {
                scanf("%d%d%d",&pos,&len,&tmp);
                spt.SAME(pos,len,tmp);
            }
        }
    }
	return 0;
}


你可能感兴趣的:(noi2005维护数列 splay)