题目大意:给你n个互不相同的数,保证其不大于m。要求你再确定一串数p,满足:1、p个数可以选出一些集合(相同的数可重复选),其和可以组成所有给定的n个数。2、从p中选择任意数集(相同的数可重复选),若其和不大于m,那么这个和必在给定的n个数中。3、在满足1和2的情况下p最小。
n,m<=10^6 , 8s
(若不是tag上友情提示是fft,蒟蒻怎么也不会想到正解)
不难发现p数一定是n中的数(废话),那么n中的任意一个数的任意倍(不大于m)一定也在这n个数中——否则no solution。
这样,我们构造一个多项式:k0*x0+k1*x1+k2*x2+...+km*xm,如果i在这n个数中,那么ki=1,否则ki=0,特别的k0=1
求这个多项式的平方,我们得到的多项式有什么意义呢?——在p个数中选出两个数求和,ki表示和为i的方案数*2。
选出两个数是这样,我们是不是要求三次方、四次方乃至更多的呢?
不要!其实第一步我们保证了n中任意一个数t,若kt<=m,那么kt在这n个数中。那么也就是说对于n中的两个数a,b,若pa+qb<=m,那么pa+qb前的系数必定大于1*2
最后我们只要看平方得到的系数ki,若i未在原来的n个数中出现且ki>1*2,no solution。
若i在原来的n个数中出现且ki>1*2,最后的p个数中不需要此数,否则一定需要。
求平方就用fft吧。。。。
(好题点赞!)
#include <cmath> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int Maxn=4000005; int n,m,N,M,i,j,k,t,tmp,dig[25],rev[Maxn],ans; bool vis[Maxn],v[Maxn]; struct CP { double x,y; CP operator +(const CP &a)const{ return (CP){x+a.x,y+a.y}; } CP operator -(const CP &a)const{ return (CP){x-a.x,y-a.y}; } CP operator *(const CP &a)const{ return (CP){x*a.x-y*a.y,x*a.y+y*a.x}; } } a[Maxn]; void FFT(CP a[],int flag){ for (i=0;i<N;i++) if (i<rev[i]) swap(a[i],a[rev[i]]); for (i=2;i<=N;i<<=1){ CP wn = (CP) { cos(2*M_PI/i), flag*sin(2*M_PI/i) }; for (j=0;j<N;j+=i){ CP w = (CP) {1,0}; for (k=j;k<j+i/2;k++){ CP x=a[k], y=a[k+i/2]*w; a[k]=x+y; a[k+i/2]=x-y; w=w*wn; } } } if (flag==1) return; for (i=0;i<N;i++) a[i].x/=N; } bool Judge(){ for (N=2, M=1;N<(m<<1);N<<=1, M++); for (i=0;i<N;i++){ int len=0; for (j=i;j>0;j>>=1) dig[len++]=(j&1); for (j=0;j<M;j++) rev[i] = ((rev[i]<<1)|dig[j]) ; } FFT(a,1); for (i=0;i<N;i++) a[i]=a[i]*a[i]; FFT(a,-1); for (i=1;i<=m;i++) if (!vis[i]){ tmp = int (a[i].x+0.5); if (tmp!=0) return 0; } return 1; } int main(){ //freopen("286E.in","r",stdin); //freopen("286E.out","w",stdout); scanf("%d%d",&n,&m); for (i=1;i<=n;i++){ scanf("%d",&t); vis[t]=1; a[t].x=1; } a[0].x=1; bool flag=1; for (i=1;i<=m;i++) if (!v[i] && vis[i]){ for (j=i;j<=m;j+=i) if (!vis[j]) {flag=0;break;} else v[j]=1; if (!flag) break; } if (!flag || !Judge()){ printf("NO\n"); return 0; } printf("YES\n"); for (i=1;i<=m;i++) if (vis[i]){ tmp = int (a[i].x+0.5); if (tmp<=2) ans++; } printf("%d\n",ans); for (i=1;i<=m;i++) if (vis[i]){ tmp = int (a[i].x+0.5); if (tmp<=2) ans--, printf("%d%c",i,ans>0?' ':'\n'); } return 0; }