bzoj 1558: [JSOI2009]等差数列 (线段树)

题目描述

传送门

题目大意:给定一个长度为N的数列,初始时第i个数为vi。
操作(1)A s t a b在序列的[s,t]区间上加上初值为a,步长为b的等差数列。即vi变为vi+a+b*(i-s)
操作(2)B s t询问当前序列的[s,t]区间最少能划分成几段,使得每一段都是等差数列。

题解

首先可以想到的是线段树维护的是差分后的值, vali=vi+1vi ,而不是原序列的值。
那么对于修改操作来说可以分成两部分
(1)单点修改, valx1+=a,valy=((yx)b+a)
(2)区间修改 val[x,y1]+=b

比较麻烦的是操作二,刚开始看的时候以为就是直接维护区间颜色段数。对于两个等差数列的首尾差是单独的颜色可以忽略不计。所以就需要各种分情况讨论。
大体上的思路是维护区间左右两端单独数的个数(没有连续的相同数值) ,以及除去两端单独数中间部分的答案。
随便写个数列,划分一下会发现,如果某端的单独数个数为 p ,那么最少可以划分成 (p+1)/2
中间的部分在合并答案的时候需要各种分类讨论,我把所有需要特判的特殊情况注释到了程序中
注释中的第一行是差分后的答案
注释中的第二行是符合差分的合法序列,和最小的划分方案。

代码

#include
#include
#include
#include
#include
#define N 100003
using namespace std;
struct data{
    int sum,ls,rs,tag,ll,rr,all;
}tr[N*4];
int n,m,a[N],val[N],pos[N];
data update(data l,data r)
{
    data now; now.tag=0;
    now.ls=l.ls; now.rs=r.rs;//ls,rs表示区间左右端点的数值 
    now.ll=l.ll; now.rr=r.rr;//ll,rr表示区间左右两端单独数的个数(没有连续的相同数值) 
    if (l.all) //all表示区间中是否都是单独数 
     if (l.rs==r.ls) now.ll--;
     else now.ll+=r.ll;
    if (r.all)
     if (l.rs==r.ls) now.rr--;
     else now.rr+=l.rr;
    now.sum=l.sum+r.sum;
    if (l.all&&r.all&&l.rs!=r.ls) now.all=1;
    else now.all=0;
    if (l.all) {
        if (l.rs==r.ls) {
            if (r.all) now.sum=1;// 1 2 3 | 3 2 1
            else if (r.ll) now.sum+=(r.ll-1)/2+1; // 1 2 3 5 4 | 4 2 3 3 2 
            // (1 2) (4 7) |(12 16 20) (22 25 28)| (30)
        }
    }  
    else if (l.rr) {
        if (l.rs==r.ls){
            if (r.all) now.sum+=(l.rr-1)/2+1;// 1 2 2 3|3 2 1 3
            // (1) |(2 4 6) (9 12)| (14 15) (18)
            else if (r.ll) now.sum+=(r.ll-1)/2+(l.rr-1)/2+1; // 1 2 2 3 | 3 2 2 1
            // (1) |(2 4 6) (9) (12 14 16)| (17)
            else now.sum+=(l.rr-1)/2; //1 2 2 3 |3 3 3 1
            // (1) (2 4 6) (9 12 15 18) (19)
        }
        else if (!r.all) now.sum+=(l.rr+r.ll)/2; // 1 2 2 3 | 4 5 5 7 
        //1 |(2 4 6) (9 13) (18 23)| 30
    }
    else {
        if (l.rs==r.ls) {
            if (!r.all) {
                if (r.ll) now.sum+=(r.ll-1)/2; // 1 2 2 | 2 3 3
                // (1) (2 4 6 8) (11 14)
                else now.sum--; //1 2 2| 2 2 1
                // (1) (2 4 6 8 10) (11)
            }
        }
        else {
            if (!r.all&&r.ll) now.sum+=r.ll/2; //1 2 2| 3 2 2
            // (1) (2 4 6) (9 11 13)
        }
    }
    return now;
}
void build(int now,int l,int r)
{
    if (l==r) {
        tr[now].sum=0; tr[now].ll=tr[now].rr=tr[now].all=1;
        tr[now].ls=tr[now].rs=a[l];
        pos[l]=now;
        return;
    }
    int mid=(l+r)/2;
    build(now<<1,l,mid);
    build(now<<1|1,mid+1,r);
    tr[now]=update(tr[now<<1],tr[now<<1|1]);
}
void change(int now,int val)
{
    tr[now].ls+=val; tr[now].rs+=val;
    tr[now].tag+=val;
}
void pushdown(int now)
{
    if (tr[now].tag){
        change(now<<1,tr[now].tag);
        change(now<<1|1,tr[now].tag);
        tr[now].tag=0;
    }
}
data qjsum(int now,int l,int r,int ll,int rr)
{
    if (ll<=l&&r<=rr) return tr[now];
    int mid=(l+r)/2;
    pushdown(now); data ans; bool pd=false;
    if (ll<=mid) ans=qjsum(now<<1,l,mid,ll,rr),pd=true;
    if (rr>mid) {
        if (!pd) ans=qjsum(now<<1|1,mid+1,r,ll,rr);
        else ans=update(ans,qjsum(now<<1|1,mid+1,r,ll,rr));
    }
    return ans;
}
void pointchange(int now,int l,int r,int x,int val)
{
    if (l==r) {
        tr[now].ls=tr[now].rs+=val;
        tr[now].sum=0; tr[now].ll=tr[now].rr=tr[now].all=1;
        return;
    }
    int mid=(l+r)/2;
    pushdown(now);
    if (x<=mid) pointchange(now<<1,l,mid,x,val);
    else pointchange(now<<1|1,mid+1,r,x,val);
    tr[now]=update(tr[now<<1],tr[now<<1|1]);
}
void qjadd(int now,int l,int r,int ll,int rr,int val)
{
    if (ll>rr) return;
    if (ll<=l&&r<=rr) {
        change(now,val);
        return;
    }
    int mid=(l+r)/2;
    pushdown(now);
    if (ll<=mid) qjadd(now<<1,l,mid,ll,rr,val);
    if (rr>mid) qjadd(now<<1|1,mid+1,r,ll,rr,val);
    tr[now]=update(tr[now<<1],tr[now<<1|1]);
}
int main()
{
    freopen("a.in","r",stdin);
    freopen("my.out","w",stdout);
    scanf("%d",&n);
    for (int i=1;i<=n;i++) scanf("%d",&val[i]);
    for (int i=1;i<=n-1;i++) a[i]=val[i+1]-val[i];
    n--; 
    if (n) build(1,1,n);
    scanf("%d",&m);
    for (int i=1;i<=m;i++) {
        char s[10]; int x,y; scanf("%s%d%d",s+1,&x,&y);
        if (s[1]=='A') {
            int a,b; scanf("%d%d",&a,&b);
            if (x-1>=1&&x-1<=n) pointchange(1,1,n,x-1,a); 
            if (y<=n) pointchange(1,1,n,y,-((y-x)*b+a));
            qjadd(1,1,n,x,y-1,b);
        }
        if (s[1]=='B') {
           data t; 
           if (x<=y-1) t=qjsum(1,1,n,x,y-1);
           else t.sum=1;
           if (x==y) printf("%d\n",1);
           else {
             if (t.sum==0) printf("%d\n",(y-x+2)/2);
             else printf("%d\n",t.sum+(t.ll+1)/2+(t.rr+1)/2);
           }    
        }
    }
}

你可能感兴趣的:(线段树)