4 1024 512 256 256 4 1024 1024 1024 1024 5 1024 512 512 512 1 0
Case #1: 1 Case #2: 11 Case #3: 8HintIn the first case, we should choose all the numbers. In the second case, all the subsequences which contain more than one number are good.
题目大意
给定一个数的集合,问其中有多少子集满足:
两个相同的数字可以加在一起,重复若干次这样的操(可为0次)作得到2048。结果对998244353取模
解题思路
去除不为2的整数次幂的所有数字(设有fr个),将最后结果乘上2^fr(是否取这些数字不影响2048的操作)
枚举
dp[i][j]表示取到第i个2^i的数,其最大的和在j*2^i至(j+1)*2^i-1的方案数
答案就是可的总方案减去dp[12][0] (即结果落在0~2047的方案个数)
对于每个dp[i][j]
我们有枚举下一步状态时dp[i+1][k]增加的方案数为dp[i][j]*C(cnt[i],k-j/2)
其中,cnt[i]为2^i出现的次数,C为组合数,998244353是质数所以组合数采用逆元处理。
PS:
原题标答自带的输入优化跑了500ms……
读入优化之后……
G++反复提交通不过
C++1A……1390ms……
code:
#include <cstdio> #include <iostream> #include <algorithm> #include <ctime> #include <cctype> #include <cmath> #include <string> #include <cstring> #include <stack> #include <queue> #include <list> #include <vector> #include <map> #include <set> #define sqr(x) ((x)*(x)) #define LL long long #define INF 0x3f3f3f3f #define PI acos(-1.0) #define eps 1e-10 #define mod 998244353ll using namespace std; inline int ReadInt() { int flag=0; char ch=getchar(); int data=0; while (ch<'0'||ch>'9') { if (ch=='-') flag=1; ch=getchar(); } do { data=data*10+ch-'0'; ch=getchar(); }while (ch>='0'&&ch<='9'); return data; } void egcd(LL a,LL b,LL &x,LL &y) { if (b==0) { x=1,y=0; return ; } egcd(b,a%b,x,y); LL t=x; x=y;y=t-a/b*x; } int a[100005]; LL inv[100005]; LL cnt[15],fr; int base[3000]; LL dp[20][2050]; int lim[20]; LL mypow(LL x,LL y) { LL res=1; LL mul=x; while (y) { if (y&1) res=res*x%mod; x=x*x%mod; y>>=1; } return res; } int main() { lim[0]=2048; for (int i=1;i<=12;i++) lim[i]=2048>>(i-1); for (LL i=1;i<=100000;i++) { LL x,y; egcd(i,mod,x,y); x=(x+mod)%mod; inv[i]=x; } int n,ca=0; while (~scanf("%d",&n),n) { fr=n; memset(base,0,sizeof base); for (int i=1;i<=n;i++){ scanf("%d",&a[i]); base[a[i]]++; } memset(cnt,0,sizeof cnt); for (int i=1;i<=12;i++) { cnt[i]=base[1<<(i-1)]; fr-=cnt[i]; } memset(dp,0,sizeof dp); dp[0][0]=1; for (int i=1;i<=12;i++) { for (int j=0;j<=lim[i-1];j++) if (dp[i-1][j]) { LL cal=1; for (int k=j/2;k<lim[i];k++) { int pos=j/2; if ((k-pos)>cnt[i]) break; dp[i][k]+=dp[i-1][j]*cal%mod; dp[i][k]%=mod; cal=(cal*(LL)(cnt[i]-((k-pos))))%mod; cal=(cal*inv[(k-pos)+1])%mod; } } } LL ans; ans=(mypow(2ll,(LL)n-fr)-dp[12][0]+mod)%mod; ans=ans*mypow(2ll,fr)%mod; printf("Case #%d: %I64d\n",++ca,ans); } return 0; }