小黑和小白在玩游戏。小黑有一个序列,每个元素都形如 2 a i 2^{a_i} 2ai,其中 a i a_i ai为整数。小白每次可以选择序列里连续的一段,然后计算这段区间内所有元素的总和,记为 s s s也就是将这段区间合并为一个数。为了让游戏更有难度,小黑要求小白合并时必须保证 s s s是 2 2 2的若干次幂。
然而,小白不擅长计算,因此她很难找到一个合法的区间。于是她向你求助,想知道对于给定的初始序列,有多少区间可以保证合并后产生的 s s s是 2 2 2的若干次幂。
如果一个区间只有 1 1 1个数,也被视为是合法的区间。
1 ≤ n ≤ 2 × 1 0 5 , 1 ≤ a i ≤ 1 0 9 1\leq n\leq 2\times 10^5,1\leq a_i\leq 10^9 1≤n≤2×105,1≤ai≤109
时间限制 4000 m s 4000ms 4000ms,空间限制 512 M B 512MB 512MB。
首先,求出每个元素的哈希值,然后做 C D Q CDQ CDQ分治。
对于每个区间 [ l , r ] [l,r] [l,r],分别求出左区间的后缀最大值 l m a x i lmax_i lmaxi与后缀和 l s u m i lsum_i lsumi以及右区间的前缀最大值 r m a x i rmax_i rmaxi与前缀和 r s u m rsum rsum。然后,分类讨论合法的区间的最大的数在左区间还是在右区间。设这个最大的数为 2 m x 2^{mx} 2mx,则这个合法区间的区间和只可能为 2 k ( m x ≤ k ≤ k + log n ) 2^k(mx\leq k\leq k+\log n) 2k(mx≤k≤k+logn)。用 i i i遍历左区间,用 j j j遍历右区间,则分类讨论情况如下:
最后,注意 l = = r l==r l==r时对答案的贡献为 1 1 1。
时间复杂度为 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
#include
using namespace std;
const int N=200000,P=19260817;
const long long mod1=998244353,mod2=1e9+7;
int n,cl=0,a[N+5],lmx[N+5],rmx[N+5],z[P+5];
int tot=0,l[N+5],r[P+5],hv[N+5],w1[N+5],w2[N+5];
long long ans=0,pw1[N+5],pw2[N+5],pb1[N+5],pb2[N+5],v1[N+5],v2[N+5];
long long ls1[N+5],rs1[N+5],ls2[N+5],rs2[N+5];
long long gt1(int i){return pb1[i/N]*pw1[i%N]%mod1;}
long long gt2(int i){return pb2[i/N]*pw2[i%N]%mod2;}
void init(){
pw1[0]=pw2[0]=1;
for(int i=1;i<=N;i++){
pw1[i]=pw1[i-1]*2%mod1;
pw2[i]=pw2[i-1]*2%mod2;
}
pb1[0]=pb2[0]=1;
for(int i=1;i<=N;i++){
pb1[i]=pb1[i-1]*pw1[N]%mod1;
pb2[i]=pb2[i-1]*pw2[N]%mod2;
}
}
void add(int x,int h1,int h2){
l[++tot]=r[x];
w1[tot]=h1;w2[tot]=h2;hv[tot]=1;
r[x]=tot;
}
void clr(){
++cl;tot=0;
}
void pl(int h1,int h2){
int u=h1%P;
if(z[u]<cl) z[u]=cl,r[u]=0;
for(int i=r[u];i;i=l[i]){
if(w1[i]==h1&&w2[i]==h2){
++hv[i];return;
}
}
add(u,h1,h2);
}
int find(int h1,int h2){
int u=h1%P;
if(z[u]<cl) z[u]=cl,r[u]=0;
for(int i=r[u];i;i=l[i]){
if(w1[i]==h1&&w2[i]==h2) return hv[i];
}
return 0;
}
void solve(int l,int r){
if(l==r){
++ans;return;
}
int mid=l+r>>1;
solve(l,mid);solve(mid+1,r);
lmx[mid]=a[mid];ls1[mid]=v1[mid];ls2[mid]=v2[mid];
for(int i=mid-1;i>=l;i--){
lmx[i]=max(lmx[i+1],a[i]);
ls1[i]=(ls1[i+1]+v1[i])%mod1;
ls2[i]=(ls2[i+1]+v2[i])%mod2;
}
rmx[mid+1]=a[mid+1];rs1[mid+1]=v1[mid+1];rs2[mid+1]=v2[mid+1];
for(int i=mid+2;i<=r;i++){
rmx[i]=max(rmx[i-1],a[i]);
rs1[i]=(rs1[i-1]+v1[i])%mod1;
rs2[i]=(rs2[i-1]+v2[i])%mod2;
}
clr();
for(int i=mid,j=mid+1;i>=l;i--){
while(j<=r&&lmx[i]>rmx[j]){
pl(rs1[j],rs2[j]);++j;
}
long long p1=gt1(lmx[i]),p2=gt2(lmx[i]);
for(int k=0;k<=17;k++){
ans+=find((p1-ls1[i]+mod1)%mod1,(p2-ls2[i]+mod2)%mod2);
p1=p1*2%mod1;p2=p2*2%mod2;
}
}
clr();
for(int i=mid,j=mid+1;j<=r;j++){
while(i>=l&&lmx[i]<=rmx[j]){
pl(ls1[i],ls2[i]);--i;
}
long long p1=gt1(rmx[j]),p2=gt2(rmx[j]);
for(int k=0;k<=17;k++){
ans+=find((p1-rs1[j]+mod1)%mod1,(p2-rs2[j]+mod2)%mod2);
p1=p1*2%mod1;p2=p2*2%mod2;
}
}
}
int main()
{
freopen("sequence.in","r",stdin);
freopen("sequence.out","w",stdout);
init();
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
v1[i]=gt1(a[i]);v2[i]=gt2(a[i]);
}
solve(1,n);
printf("%lld",ans);
return 0;
}