HDU 4747 Mex 区间更新

题意:给一个长度为n的数组a[n],然后定义mex[l,r]为[l,r]这个区间内最小的非负整数,然后求sum(mex[l,r])(1<=l<=r<=n)

先求出mex[1,1]~mex[1,n]的值,然后枚举删掉a[i]后的变化
首先可以知道mex[1,1]~mex[1,n]为非递减的
如果删掉a[1],那么mex[2,2]~mex[2,n]的变化为,下一个a[1]出现前大于a[1]的都要变为a[1],又因为其是非递减的,所以可以找到第一个mex值大于a[i]的那个位置到下一个出现a[1]的位置之前的这段区间赋值为a[1],然后线段树求和就行了。

//author: CHC
//First Edit Time:  2015-07-20 21:09
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <set>
#include <vector>
#include <map>
#include <queue>
#include <set>
#include <algorithm>
#include <limits>
using namespace std;
typedef long long LL;
const int MAXN=200000+1000;
const int INF = numeric_limits<int>::max();
const LL LL_INF= numeric_limits<LL>::max();
#define lson L,mid,rt<<1
#define rson mid+1,R,rt<<1|1
struct Tree {
    LL sum;
    LL Ma;
    int flag;
    LL v;
}tr[MAXN<<2];
void pushup(int rt){
    tr[rt].sum=tr[rt<<1].sum+tr[rt<<1|1].sum;
    tr[rt].Ma=max(tr[rt<<1].Ma,tr[rt<<1|1].Ma);
}
void pushdown(int L,int R,int rt){
    if(L==R)return ;
    int mid=(L+R)>>1;
    if(tr[rt].flag){
        tr[rt<<1].sum=(L-mid+1)*tr[rt].v;
        tr[rt<<1|1].sum=(R-(mid+1)+1)*tr[rt].v;
        tr[rt<<1].v=tr[rt<<1|1].v=tr[rt].v;
        tr[rt<<1].Ma=tr[rt<<1|1].Ma=tr[rt].v;
        tr[rt<<1].flag=tr[rt<<1|1].flag=1;
        tr[rt].flag=0;
        tr[rt].v=0;
    }
}
LL A[MAXN];
void build(int L,int R,int rt){
    memset(&tr[rt],0,sizeof(Tree));
    if(L==R){
        tr[rt].sum=A[L];
        tr[rt].Ma=A[L];
        return ;
    }
    int mid=(L+R)>>1;
    build(lson);build(rson);
    pushup(rt);
}
void update(int L,int R,int rt,int l,int r,LL v){
    pushdown(L,R,rt);
    if(l<=L&&R<=r){
        tr[rt].sum=v*(R-L+1);
        tr[rt].v=v;
        tr[rt].Ma=v;
        tr[rt].flag=1;
        return ;
    }
    int mid=(L+R)>>1;
    if(l<=mid)update(lson,l,r,v);
    if(r>mid)update(rson,l,r,v);
    pushup(rt);
}
LL query(int L,int R,int rt,int l,int r,int flag){
    //printf("%d %d %d %d %d\n",L,R,rt,l,r);
    //system("pause");
    pushdown(L,R,rt);
    if(l<=L&&R<=r){
        //printf("%d %d %I64d\n",L,R,tr[rt].sum);
        if(flag==0)return tr[rt].Ma;
        else return tr[rt].sum;
    }
    int mid=(L+R)>>1;
    LL ans=0;
    if(l<=mid){
        if(flag==0)ans=max(ans,query(lson,l,r,flag));
        else ans+=query(lson,l,r,flag);
    }
    if(r>mid){
        if(flag==0)ans=max(ans,query(rson,l,r,flag));
        else ans+=query(rson,l,r,flag);
    }
    pushup(rt);
    return ans;
}
int resl,resr;
LL queryp(int L,int R,int rt,int l,int r,LL v){
    //printf("%d %d %I64d\n",L,R,tr[rt].Ma);
    pushdown(L,R,rt);
    int mid=(L+R)>>1;
    if(l<=L&&R<=r){
        if(tr[rt].Ma<=v)return -1;
        if(L==R)return L;
        if(resl!=-1&&(resr<L))return -1;
        resl=L;resr=R;
        if(tr[rt<<1].Ma>v){
            return queryp(lson,l,r,v);
        }
        else return queryp(rson,l,r,v);
    }
    int ansl=-1,ansr=-1;
    if(l<=mid)ansl=queryp(lson,l,r,v);
    if(r>mid)ansr=queryp(rson,l,r,v);
    if(ansl==-1)return ansr;
    if(ansr==-1)return ansl;
    return min(ansl,ansr);
}
LL B[MAXN];
int n;
int pre[MAXN],vis[MAXN];
struct node {
    LL val;
    int pos;
}cs[MAXN];
int cmp(const node &x,const node &y){
    if(x.val!=y.val)return x.val<y.val;
    return x.pos<y.pos;
}
int main()
{
    while(~scanf("%d",&n)){
        if(n==0)break;
        memset(vis,0,sizeof(vis));
        for(int i=1;i<=n;i++)scanf("%I64d",&B[i]);
        LL s=0;
        for(int i=1;i<=n;i++){
            if(B[i]<=n)vis[B[i]]=1;
            while(vis[s])++s;
            A[i]=s;
            //printf("%I64d ",A[i]);
        }
        //puts("");
        memset(pre,-1,sizeof(pre));
        for(int i=1;i<=n;i++)cs[i].val=B[i],cs[i].pos=i;
        sort(cs+1,cs+1+n,cmp);
        for(int i=1;i<n;i++){
            if(cs[i].val==cs[i+1].val)pre[cs[i].pos]=cs[i+1].pos;
        }
        //for(int i=1;i<=n;i++)printf("%d ",pre[i]);
        //puts("");
        build(1,n,1);
        LL ans=query(1,n,1,1,n,1);
        for(int i=2;i<=n;i++){
            resl=resr=-1;
            int pos1=queryp(1,n,1,i,n,B[i-1]);
            int pos2=pre[i-1];
            if(pos1!=-1){
                if(pos2==-1){
                    update(1,n,1,pos1,n,B[i-1]);
                }
                else {
                    update(1,n,1,pos1,pos2-1,B[i-1]);
                }
            }
            ans+=query(1,n,1,i,n,1);
        }
        printf("%I64d\n",ans);
    }
    return 0;
}

你可能感兴趣的:(线段树,区间求和,区间更新,区间最值)