BZOJ 1500[NOI2005] 维修数列

Description

BZOJ 1500[NOI2005] 维修数列_第1张图片

解题思路:
输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。
第2行包含N个数字,描述初始时的数列。
以下M行,每行一条命令,格式参见问题描述中的表格。
任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。
插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。
Output
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
本题可用Splay来维护这个数列,而每次插入操作都将插入的数列建一棵Splay然后并入原树即可。而查询区间以及各种修改操作则依赖于伸展操作。查询区间[l,r]时,将节点l-1转至树根,将节点r+1转至根节点的右子节点,然后root->ch[1]->ch[0]就是要操作的区间,将标记打在该节点即可,在旋转和查询的时候要注意标记的下传和节点的更新。
特殊的,要查询区间的最大连续子序列和:该操作需要对每个节点维护几个值,该节点及其子树中的最大前缀和,最大后缀和,以及最大子序列和;则在更新的时候该节点的最大子序列和=max(左子树的最大后缀+该节点的值,右子树的最大前缀+该节点的值,左子树的最大子序列,右子树的最大子序列,左子树最大后缀+节点权值+右子树最大前缀,该节点权值);这样进行更新,则可以保证该节点维护值的正确性,最大前缀与后缀的维护与其相似。注意在更新前要先将该节点及其子节点的标记下传;
还有一个问题,那就是怎样保证这样做的正确性,我们在修改区间值和翻转时只是将标记打在了子树的根节点上,那么该节点的信息还是正确的么?我们在建树时是递归进行的,显然每个点的信息都是最新的,并且是从底向上进行更新,正确性显然可以保证;在我们打翻转标记的时候,该子树的最大子序列显然是不变的(子序列翻转后相邻的数字是不变的),最大前缀与后缀只是交换了位置而已,最大前缀变成了后缀,后缀变成了前缀,将两个值交换即可;在修改值时,将标记打在节点上,当需要用到该节点或其子树的信息时,将标记下传,同时更新节点的信息,如果修改后的值为正数,则很显然max_sum=max_pre=max_sur=该子树节点数x修改后的值,若为负数,则等于该值。所以,只要将标记打在要修改的子树,在调用时更新即可;还有,在每次操作完成后,将该节点到根的路径更新,保证信息的正确性;
在操作时还有一个问题:如果l==1或r==n怎么办;显然节点l-1与r+1是不存在的。于是我们另设两个节点,分别加在原树的前面与后面,设成一个极小的负数,这样它就不会对其他节点的信息有影响,因为当查询区间时,只有在区间内的值才有效,很明显这两个新节点是无论如何都不会在区间内的;;而唯一查询全部节点的操作时最大子序列,由于新节点的值为极小值,所以显然不会将它们选中。(注意极小值不要太小,不然更新时相加可能会爆int!)
嗯。。就是这样。
还有注意加个特判,,,如果tot==0直接输出0或直接return,,数据有毒。。

代码如下:

#include<iostream>
#include<cstdio>
#include<climits>
#include<queue>
#define N 500010
#define INF 1500000
using namespace std;
struct node { 
    int size,num,sum,pre,sur,max_sum,to; bool sa,re; node *ch[2],*fa;
    void clean(int s) { fa=ch[0]=ch[1]=NULL;size=1;re=sa=0;sum=pre=sur=max_sum=num=s; }
}*head,*null,o[N];
queue<node*> q;int a[500001],tot,top;
int in() {
    int s=0,v=0;char c;
    while((c=getchar())<'0'||c>'9') if(c=='-') v=1;s=c-'0';
    while((c=getchar())>='0'&&c<='9') s=s*10+c-'0';
    return v?-s:s;
}
int check(node *now) { return now->fa->ch[0]==now?0:1; }
void push(node *now) {
    if(now==NULL) return;
    push(now->ch[0]);push(now->ch[1]); q.push(now);now->clean(0);
}
node *out(node *fa,int num) {
    node *x;if(!q.empty()) x=q.front(),q.pop(); else x=o+tot++;x->clean(num); x->fa=fa; return x;
}
void mark_down(node *now) {
    if(now==NULL) return;
    if(now->re) {
      swap(now->ch[0],now->ch[1]);now->re=0;swap(now->pre,now->sur);
      if(now->ch[0]!=NULL) now->ch[0]->re^=1;
      if(now->ch[1]!=NULL) now->ch[1]->re^=1;
    }
    if(now->sa) {
      now->sa=0;
      if(now->ch[0]!=NULL) now->ch[0]->sa=1,now->ch[0]->to=now->to;
      if(now->ch[1]!=NULL) now->ch[1]->sa=1,now->ch[1]->to=now->to;
      now->num=now->to; now->sum=now->size*now->to;
      now->max_sum=now->sur=now->pre=max(now->to,now->to*now->size);
    }
}
void update(node *now) { 
    if(now==NULL) return;
    mark_down(now);
    if(now->ch[0]!=NULL) mark_down(now->ch[0]);
    if(now->ch[1]!=NULL) mark_down(now->ch[1]);
    now->sum=now->num;
    if(now->ch[0]!=NULL) now->sum+=now->ch[0]->sum;
    if(now->ch[1]!=NULL) now->sum+=now->ch[1]->sum;
    now->size=1;
    if(now->ch[0]!=NULL) now->size+=now->ch[0]->size;
    if(now->ch[1]!=NULL) now->size+=now->ch[1]->size;
    now->pre=now->sur=now->max_sum=-INF;
    if(now->ch[0]!=NULL) {
      now->pre=max(now->ch[0]->pre,now->ch[0]->sum+now->num);
      if(now->ch[1]!=NULL) now->pre=max(now->pre,now->ch[0]->sum+now->num+now->ch[1]->pre);
    }
    else {
      now->pre=max(now->pre,now->num);
      if(now->ch[1]!=NULL) now->pre=max(now->pre,now->ch[1]->pre+now->num);
    }
    if(now->ch[1]!=NULL) {
      now->sur=max(now->ch[1]->sur,now->ch[1]->sum+now->num);
      if(now->ch[0]!=NULL) now->sur=max(now->sur,now->ch[1]->sum+now->num+now->ch[0]->sur);
    }
    else {
      now->sur=max(now->sur,now->num);
      if(now->ch[0]!=NULL) now->sur=max(now->sur,now->num+now->ch[0]->sur);
    }
    now->max_sum=max(now->max_sum,now->num);
    if(now->ch[0]!=NULL&&now->ch[1]!=NULL) {
      now->max_sum=max(max(now->ch[0]->max_sum,now->ch[1]->max_sum),now->max_sum);
      now->max_sum=max(now->max_sum,now->ch[1]->pre+now->num+now->ch[0]->sur);
    }
    if(now->ch[0]!=NULL) now->max_sum=max(max(now->max_sum,now->ch[0]->max_sum),now->ch[0]->sur+now->num);
    if(now->ch[1]!=NULL) now->max_sum=max(max(now->max_sum,now->ch[1]->max_sum),now->ch[1]->pre+now->num);
}
void rorate(node *now) {
    node *fa=now->fa;int d=check(now),c=check(fa);
    now->fa=fa->fa;fa->fa->ch[c]=now; fa->ch[d]=NULL;
    if(now->ch[d^1]!=NULL) now->ch[d^1]->fa=fa,fa->ch[d]=now->ch[d^1];
    now->ch[d^1]=fa;fa->fa=now;
    update(fa);update(now); return;
}
void splay(node *now,node *fa) {
    for(;now->fa!=fa;) 
      if(now->fa->fa==fa) rorate(now); 
      else {
        node *x=now->fa;
        if(check(now)==check(x)) rorate(x),rorate(now);
        else rorate(now),rorate(now);
      } return;
}
node *build(int l,int r,node *fa,int d) {
    if(l>r) return NULL;
    int mid=l+r>>1;
    node *now=out(fa,a[mid]);
    now->ch[0]=build(l,mid-1,now,0);
    now->ch[1]=build(mid+1,r,now,1);
    update(now); return now;
}
node *find(int k) {
    node *now=head;mark_down(now); int ti=1;
    if(now->ch[0]!=NULL) ti=now->ch[0]->size+1;
    for(;ti!=k;) {
      if(k>ti) now=now->ch[1],k-=ti;
      else now=now->ch[0];mark_down(now);
      ti=1;if(now->ch[0]!=NULL) ti=now->ch[0]->size+1;
    } return now;
}
void insert() {
    int point=in(),n=in();
    for(int i=1;i<=n;i++) a[i]=in(); node *root=build(1,n,null,0);
    splay(head=find(point+1),null),splay(find(point+2),head);
    head->ch[1]->ch[0]=root;root->fa=head->ch[1];
    update(head->ch[1]);update(head); return;
}
void del() {
    int point=in(),tot=in();
    splay(head=find(point),null);splay(find(point+tot+1),head);
    push(head->ch[1]->ch[0]);head->ch[1]->ch[0]=NULL;
    update(head->ch[1]);update(head); return;
}
void make_same() {
    int point=in(),tot=in(),c=in();
    splay(head=find(point),null);splay(find(point+tot+1),head);
    head->ch[1]->ch[0]->sa=1;head->ch[1]->ch[0]->to=c;
    update(head->ch[1]);update(head);return;
}
void max_sum() { printf("%d\n",head->max_sum);return; }
void reverse() {
    int point=in(),tot=in();
    splay(head=find(point),null);splay(find(point+tot+1),head);
    head->ch[1]->ch[0]->re^=1;update(head->ch[1]);update(head);
}
void get_sum() {
    int point=in(),tot=in();if(tot==0) { printf("0\n");return; }
    splay(head=find(point),null);splay(find(point+tot+1),head);
    update(head->ch[1]->ch[0]);update(head->ch[1]);update(head);
    printf("%d\n",head->ch[1]->ch[0]->sum);
}
int main() {
    null=out(NULL,-INF);null->ch[0]=null->ch[1]=null->fa=null;
    null->size=null->num=null->sum=0;
    int n=in(),m=in();for(int i=1;i<=n;i++) a[i]=in();
    a[0]=a[n+1]=-INF;head=build(0,n+1,null,0);
    while(m--) {
      char ord[10];
      scanf("%s",ord);
      if(ord[0]=='I') insert();
      else if(ord[0]=='D') del();
      else if(ord[2]=='K') make_same();
      else if(ord[2]=='X') max_sum();
      else if(ord[0]=='R') reverse();
      else get_sum();
    }
}

你可能感兴趣的:(splay,区间操作)