给出一个长度为 n n n的序列 A A A,再给出一个整数 x x x,如果一个子序列满足以下的条件,则它是一个符合条件的子序列:
求符合条件的子序列的个数,模 998244353 998244353 998244353。
两个子序列不同,当且仅当它们取自于原序列中的位置中有至少一个位置不同。
1 ≤ n ≤ 3 × 1 0 5 , 1 ≤ A i ≤ 2 60 , 1 ≤ x ≤ 2 60 1\leq n\leq 3\times 10^5,1\leq A_i\leq 2^{60},1\leq x\leq 2^{60} 1≤n≤3×105,1≤Ai≤260,1≤x≤260
令 x x x的最高位为第 d d d位。如果有两个数,它们在第 d d d位以上的部分不同,那么显然它们异或之后一定有一位比第 d d d位高且为 1 1 1。也就是说,如果有两个数,它们在第 d d d位以上的部分不同,那么它们不会有冲突。
我们把序列 A A A中具有相同的第 d d d位以上的部分的数放在一起,那么整个序列就被分成若干个同前缀序列。我在其中一个同前缀序列中选符合条件的数,并不影响我在另一个同前缀序列中选的数。所以我们可以先把每个同前缀序列中符合条件的子序列个数求出,然后求积,即可得出答案。
那么怎么求一个同前缀序列中符合条件的子序列个数呢?
因为分在一个同前缀序列中的数的第 d d d位以上的部分相同,所以任意两个数的异或和的最高位一定不高于第 d d d位。假如当前的子序列有两个数,那么这两个数的第 d d d位一定满足一个是 0 0 0。一个是 1 1 1,否则异或和的第 d d d位为 0 0 0,也就小于 x x x。如果我们要再加一个数,那么无论这个数的第 d d d位是 0 0 0还是 1 1 1,它与先前的数异或之后一定会有一种情况使得第 d d d位为 0 0 0。所以,可以证明,在一个同前缀序列中符合条件的子序列的数的个数最多只有两个。
数的个数为 0 0 0或 1 1 1的子序列的个数很好求,分别为 1 1 1和该同前缀序列的大小。若子序列的数的个数为 2 2 2,那么对于每一个数,我们需要找到另一个数,使得这两个数的异或和大于等于 x x x。
我们可以用字典树来维护。假设当前枚举的数为 v v v,那么在字典树中:
求出每一个同前缀序列的答案求出后,将每个同前缀序列的答案求积,即为最终的答案。
不要忘了最后减去序列为空的情况。
时间复杂度为 O ( 60 × n ) O(60\times n) O(60×n)。
#include
using namespace std;
int n,v1=0,tot,siz[15000005],ch[15000005][2];
long long x,tx,s=1,ans=1,now,mi[65],a[300005],v[300005];
long long mod=998244353;
vector<long long>w[300005];
void pt(long long v){
int q=1,vq;
for(int i=60;i>=0;i--){
vq=(v>>i)&1;
if(!ch[q][vq]) ch[q][vq]=++tot;
q=ch[q][vq];
++siz[q];
}
}
void find(long long v){
int q=1,vq;
for(int i=60;i>=0;i--){
vq=(v>>i)&1;
if((x>>i)&1){
if(!ch[q][vq^1]) return;
q=ch[q][vq^1];
}
else{
if(ch[q][vq^1]) now=(now+siz[ch[q][vq^1]])%mod;
if(!ch[q][vq]) return;
q=ch[q][vq];
}
}
now=(now+siz[q])%mod;
}
void cl(int x){
if(ch[x][0]) cl(ch[x][0]);
if(ch[x][1]) cl(ch[x][1]);
ch[x][0]=ch[x][1]=0;
siz[x]=0;
}
int main()
{
scanf("%d%lld",&n,&x);
mi[0]=1;
for(int i=1;i<=60;i++) mi[i]=mi[i-1]*2;
tx=x;
while(x){
x>>=1;s<<=1;
}
s=mi[60]-s;
x=tx;
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
sort(a+1,a+n+1);
for(int i=1;i<=n;i++){
if(v1==0||(a[i]&s)!=v[v1]){
v[++v1]=(a[i]&s);
}
w[v1].push_back((a[i]|s)^s);
}
for(int i=1;i<=v1;i++){
int l=w[i].size();
now=l+1;
tot=1;
for(int j=0;j<l;j++){
find(w[i][j]);
pt(w[i][j]);
}
ans=ans*now%mod;
cl(1);
}
ans=(ans+mod-1)%mod;
printf("%lld",ans);
return 0;
}