题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=3693
题目意思:
有n个数,(n<=50),m1,m2,m3,......,mn (mi<=2^31-1),给定k,求x1,x2,x3,...,xn,(xi<=mi)的种数,使得x1^x2^x3...^xn=k。
解题思路:
数位dp.
首先归纳出一个很重要的性质:若x1^x2^x3^..xn=k,假设xi很大,则总数为(m1+1)*(m2+1)*,..*(m[i-1]+1)*(m[i+1]+1)*...*(mn+1).
因为对于除了xi的n-1个数无论取什么值,都可以通过x1^x2^...^x(i-1)^(xi+1)^....^xn^k得到唯一的xi。
通过这个性质,可以对每一位1的的取值枚举。通过控制超过枚举位置的值更新来避免重复。
dp[i][j]表示前i个数中当前位含有j个1的总的种数。
代码解释的很详细。
PS:思维变换的数位dp.
代码:
#include<iostream> #include<cmath> #include<cstdio> #include<sstream> #include<cstdlib> #include<string> #include<cstring> #include<algorithm> #include<vector> #include<map> #include<set> #include<stack> #include<list> #include<queue> #include<ctime> #include<bitset> #define eps 1e-6 #define INF 0x3f3f3f3f #define PI acos(-1.0) #define ll __int64 #define LL long long #define lson l,m,(rt<<1) #define rson m+1,r,(rt<<1)|1 #pragma comment(linker, "/STACK:1024000000,1024000000") using namespace std; #define Maxn 55 #define m 30 #define M 1000000003 int sa[Maxn],n,k,num; ll sum[m+5]; ll dp[Maxn][Maxn],bina[m+5]; void init() { bina[0]=1; for(int i=1;i<=m;i++) bina[i]=bina[i-1]*2; } int Sum1(ll tmp) { int res=0; for(int i=1;i<=n;i++) if(sa[i]&tmp) res++; return res; } ll left(int a,int b) //0到第a号元素的低b位 { if(!b) return 1; //0 else return (sa[a]&(bina[b]-1))+1; //加上一个0 } ll Cal(int pos,int num,int s,int need) { if(s==0)//计算当前位满的情况 也就是num个1全部用上时 { if((num&1)!=need) //不满足 return 0; else { if(pos==0) //第一位 return 1; else //根据低位的情况,得出当前位 return sum[pos-1]; } } dp[0][0]=1; //dp[i][j]表示前i个数中当前位含有j个1的总的种数 for(int i=1;i<=n;i++) for(int j=0;j<num;j++) { dp[i][j]=0; if(bina[pos]&sa[i]) //当前位为1 { if(i==s) dp[i][j]=dp[i-1][j]%M; else if(i>s) //后面的 通过控制每次只考虑后面的,避免重复计数 dp[i][j]=(dp[i-1][j]*bina[pos])%M; //当前位放0,后面的低pos-1位可以任意放 } else // dp[i][j]=(dp[i-1][j]*left(i,pos))%M; //放0 if(bina[pos]&sa[i]) //放1 { if(i!=s) { if(j) dp[i][j]=(dp[i][j]+dp[i-1][j-1]*left(i,pos))%M; } } } ll res=0; for(int i=0;i<num;i++) if((i&1)==need) //统计1的个数 res=(res+dp[n][i])%M; return res; } int main() { init(); while(~scanf("%d%d",&n,&k)) { if(n+k==0) break; for(int i=1;i<=n;i++) scanf("%d",&sa[i]); memset(sum,0,sizeof(sum)); for(int pos=0;pos<=m;pos++) //从最低位开始 { ll tmp=bina[pos]; num=Sum1(tmp); int need; if(k&tmp) need=1; else need=0; for(int i=0;i<=n;i++) if(!i||(tmp&sa[i]))//枚举是1的情况 sum[pos]=(sum[pos]+Cal(pos,num,i,need))%M; } printf("%I64d\n",sum[m]); } return 0; }