题目很好可惜我不会做。。。
看了别人代码半天才懂。。。
看完感觉还是没学到啥。。。
于是写点东西记一下吧。。。
一堆数异或的问题,一个常用的思路是枚举每一位然后瞎搞
这里的做法是,先确定一个前缀,然后想办法求所有数对能得到多少个这个前缀
然后就可以根据求得的个数和剩下还需要求的个数的关系来计算答案跟确定下一个前缀
于是现在问题变成了有一个前缀,怎么确定这些数能搞出多少这个前缀
于是这里可以枚举所有数,对每个数求出,它异或这个前缀得出来的前缀在其他数中有多少。。。这样每对会被计算两遍,答案除个2就行了,值得一提的是不取模的情况下答案也不会爆longlong,那就可以求出最终答案再取模,当然也可以一边算一边模,就是要多写个快速幂求逆
那么显然这个子问题可以用trie解决,复杂度O(nlogn)
至此问题完美解决!总复杂度O(n(logn)^2)
(如果用静态分配内存记得算好要用多少空间。。。要卡得刚刚好。。。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<string>
#include<iomanip>
#include<vector>
#include<set>
#include<map>
#include<queue>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
#define rep(i,k,n) for(int i=(k);i<=(n);i++)
#define rep0(i,n) for(int i=0;i<(n);i++)
#define red(i,k,n) for(int i=(k);i>=(n);i--)
#define sqr(x) ((x)*(x))
#define clr(x,y) memset((x),(y),sizeof(x))
#define pb push_back
#define mod 1000000007
const int maxn=50003;
const int maxnode=maxn*17;
const int maxl=30;
int n,a[maxn];
LL m;
struct trie
{
int next[maxnode][2],sz[maxnode],c[maxnode][maxl+1][2];
int root,L;
int newnode()
{
for(int i=0;i<2;i++)next[L][i]=-1;
sz[L]=0;
clr(c[L],0);
return L++;
}
void init()
{
L=0;
root=newnode();
}
void insert(char str[])
{
int len=strlen(str);
int now=root;
for(int i=0;i<len;i++)
{
sz[now]++;
char s=str[i]-'0';
if(next[now][s]==-1)
next[now][s]=newnode();
now=next[now][s];
}
sz[now]++;
}
void dfs(int x,int k)
{
rep(i,0,1)
{
if(next[x][i]==-1)continue;
int y=next[x][i];
dfs(y,k-1);
rep(j,0,maxl)
{
c[x][j][0]+=c[y][j][0];
c[x][j][1]+=c[y][j][1];
}
c[x][k][i]+=sz[y];
}
}
int f(int x,int k,int dep,int low)
{
if(dep==low)return x;
int nxt=(k>>dep)&1;
if(next[x][nxt]==-1)return -1;
return f(next[x][nxt],k,dep-1,low);
}
LL g1(int pre,int k,LL &tmp)
{
LL ret=0;tmp=0;
rep(i,1,n)
{
int p=f(root,a[i]^pre,maxl,k-1);
if(p==-1)continue;
ret+=sz[p];
tmp+=(LL)pre*sz[p];
rep(j,0,k-1)tmp+=(LL)c[p][j][((a[i]>>j)&1) ^1]<<j;
}
return ret;
}
LL gao(LL k)
{
LL ret=0,cnt=0,tmp=0;
dfs(root,maxl);
int pre=0;
red(i,maxl,0)
{
pre|=1<<i;
cnt=g1(pre,i,tmp);
if(cnt<=k)
{
k-=cnt;
ret+=tmp;
pre^=1<<i;
}
}
cnt=g1(pre,0,tmp);
ret+=tmp/cnt*k;
return ret;
}
}tr;
inline void convert(int x,char s[])
{
for(int i=30;i>=0;i--)
{
s[i]='0'+x%2;
x>>=1;
}
if(x)cout<<"re"<<endl;
s[31]=0;
}
int main()
{
scanf("%d%I64d",&n,&m);
char s[40];
tr.init();
rep(i,1,n)
{
scanf("%d",&a[i]);
convert(a[i],s);
tr.insert(s);
}
cout<<(tr.gao(m<<1)>>1)%mod<<endl;
return 0;
}