【HDU 4747 Mex】线段数

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747

 

题意:有一组序列a[i](1<=i<=N), 让你求所有的mex(l,r), mex(l,r)表示区间[l,r]中最小的未在序列中出现的非负整数。

 

思路:冥思苦想半天无想法,白做了那么多线段树。 很明显的维护区间问题,容易想到线段树,比较难想到操作。 枚举一个序列的所mex(1,i),mex(2,i)……可以发现序列mex(x,i)是一个单调递增序列,我们需要求得就是所有以x开头的序列和,mex(x,i)(x<=i<=n)。这点确定了就好办了,记录每个位置的数后面最早重复出现的位置next[x],如果无则为设n+1。那么我们就可以发现,当第x个数所对应的序列 mex(x,i)(x<=i<=n)所对应的序列求完之后,删去此位置的数,位置x+1~next[x]-1序列中mex值大于a[x]的都改为a[x],因为a[x]没有了,下一个a[x]还未出现,所以可以证明这样做是正确的。从1到n扫一遍亦求出了所有的mex()。

基本上所有的操作都可以用到线段树。开始没有想到一点的是如何找序列中刚好大于a[x]的位置,并且此位置到next[x]-1赋值为a[x],怎么都没想到log(n)的操作,其实这里依然可以用到线段树,因为序列是单调递增的,另开一个区间维护序列mavv[u]表示区间中最大的mex值,随着询问以及其他操作成段更新即可。

 

  1 #include <iostream>

  2 #include <cstdio>

  3 #include <cmath>

  4 #include <map>

  5 #include <algorithm>

  6 #include <cstring>

  7 #include <sstream>

  8 using namespace std;

  9 

 10 #define lz 2*u,l,mid

 11 #define rz 2*u+1,mid+1,r

 12 typedef long long lld;

 13 const int maxn=222222;

 14 int a[maxn], b[maxn], next[maxn];

 15 lld sum[4*maxn], mavv[4*maxn], flag[4*maxn];

 16 map<int,int>mp;

 17 

 18 void push_up(int u, int l, int r)

 19 {

 20     sum[u]=sum[2*u]+sum[2*u+1];

 21     mavv[u]=mavv[2*u+1];

 22 }

 23 

 24 void push_down(int u, int l, int r)

 25 {

 26     int mid=(l+r)>>1;

 27     if(flag[u]!=-1)

 28     {

 29         flag[2*u]=flag[2*u+1]=flag[u];

 30         mavv[2*u]=mavv[2*u+1]=flag[u];

 31         sum[2*u]=(lld)(mid-l+1)*flag[u];

 32         sum[2*u+1]=(lld)(r-mid)*flag[u];

 33         flag[u]=-1;

 34     }

 35 }

 36 

 37 void build(int u, int l, int r)

 38 {

 39     flag[u]=-1;

 40     int mid=(l+r)>>1;

 41     if(l==r)

 42     {

 43         sum[u]=mavv[u]=b[l];

 44         return ;

 45     }

 46     build(lz);

 47     build(rz);

 48     push_up(u,l,r);

 49 }

 50 

 51 void Update(int u, int l, int r, int tl, int tr, int val)

 52 {

 53     if(tl>tr) return ;

 54     if(tl<=l&&r<=tr)

 55     {

 56         mavv[u]=val;

 57         sum[u]=(lld)val*(r-l+1);

 58         flag[u]=val;

 59         return ;

 60     }

 61     push_down(u,l,r);

 62     int mid=(l+r)>>1;

 63     if(tr<=mid) Update(lz,tl,tr,val);

 64     else if(tl>mid) Update(rz,tl,tr,val);

 65     else

 66     {

 67         Update(lz,tl,mid,val);

 68         Update(rz,mid+1,tr,val);

 69     }

 70     push_up(u,l,r);

 71 }

 72 

 73 int find(int u, int l, int r, int tmp)

 74 {

 75     if(l==r) return l;

 76     push_down(u,l,r);

 77     int mid=(l+r)>>1;

 78     if(mavv[2*u]>tmp) return find(lz,tmp);

 79     else return find(rz,tmp);

 80 }

 81 

 82 int main()

 83 {

 84     int n;

 85     while(cin >> n,n)

 86     {

 87         for(int i=1; i<=n; i++) scanf("%d",a+i);

 88         mp.clear();

 89         for(int i=n; i>=1; i--)

 90         {

 91             if(mp[ a[i] ]) next[i]=mp[ a[i] ];

 92             else next[i]=n+1;

 93             mp[ a[i] ]=i;

 94         }

 95         mp.clear();

 96         int x=0;

 97         for(int i=1; i<=n; i++)

 98         {

 99             mp[ a[i] ]=1;

100             while(mp[x]) ++x;

101             b[i]=x;

102         }

103         build(1,1,n);

104         lld ans=0;

105         for(int i=1; i<=n; i++)

106         {

107             ans+=sum[1];

108             if(mavv[1]>a[i])

109             {

110                 int id=find(1,1,n,a[i]);

111                 Update(1,1,n,max(id,i+1),next[i]-1,a[i]);

112             }

113             Update(1,1,n,i,i,0);

114         }

115         cout << ans <<endl;

116     }

117 }
View Code

 

 

你可能感兴趣的:(HDU)