题目链接
emmm,只要做过付公主的背包,这题就还算清真了
首先还是对生成函数求逆,我们知道 l n ( f ( x ) ) = ∑ i = 1 n a i x ∗ [ x   % a i   = 0 ] ln(f(x))=\sum_{i=1}^{n}{\frac{a_i}{x}*[x\, \%a_i\,=0]} ln(f(x))=∑i=1nxai∗[x%ai=0]
假设我们已经求出了这个逆
感受一下上面那个式子,显然求逆以后第一个有值的位置对应的数字一定是被选的,这个应该很好理解。
然后这个数字i会对他的每个倍数j产生 i j \frac{i}{j} ji的贡献
把他从对应的位置减掉就可以了,然后显然下一个要选的肯定是他之后第一个有值的位置,如此搞下去可以调和级数求解。
最后把有数字的位置输出来就可以了
这个用贪心的思路去想一定是字典序最小的最优解。
然后每次都被卡常……
三模NTT要整整十秒
#include
#define gg 3
#define N 600030
using namespace std;
long long ans[N],f[3][N],g[3][N],mod1[]={998244353,469762049,1004535809};
long long inv[N],tmp1[N],tmp2[N],s[N],ln[N],de[N],in[N];
int r[N],n,m,p;
inline long long mul(long long a,long long b,long long mod)
{
long long res=a*b-(long long)((long double)a*b/mod+0.5)*mod;
return res<0?res+mod:res;
}
long long kasumi(long long a,long long b,long long mod)
{
long long ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void NTT(long long *a,int kd,int mod,int lim)
{
for(int i=0;i<lim;i++)
{
if(i<r[i]) swap(a[i],a[r[i]]);
}
for(int mid=1;mid<lim;mid<<=1)
{
long long wn=kasumi(gg,(mod-1)/(mid<<1),mod);
if(kd) wn=kasumi(wn,mod-2,mod);
for(int i=0;i<lim;i+=mid<<1)
{
long long w=1;
for(int j=0;j<mid;j++,w=wn*w%mod)
{
long long x=a[i+j];
long long y=a[i+j+mid]*w%mod;
a[i+j]=(x+y)%mod;
a[i+j+mid]=(x-y+mod)%mod;
}
}
}
if(kd)
{
int inv=kasumi(lim,mod-2,mod);
for(int i=0;i<lim;i++) a[i]=a[i]*inv%mod;
}
}
void mul1(long long *a,long long *b,int cnt)
{
int lim=1<<cnt;
for(int i=0;i<lim;i++)
{
f[0][i]=f[1][i]=f[2][i]=a[i];
g[0][i]=g[1][i]=g[2][i]=b[i];
ans[i]=0;
}
for(int i=0;i<lim;i++)
{
r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
}
for(int i=0;i<=2;i++)
{
NTT(f[i],0,mod1[i],lim);NTT(g[i],0,mod1[i],lim);
for(int j=0;j<lim;j++)
{
f[i][j]=f[i][j]*g[i][j]%mod1[i];
}
NTT(f[i],1,mod1[i],lim);
}
long long inv1=kasumi(mod1[0],mod1[1]-2,mod1[1]);
long long inv2=kasumi(mod1[1],mod1[0]-2,mod1[0]);
long long mul1=mod1[0]*mod1[1];
for(int i=0;i<lim;i++)
{
ans[i]+=mul(f[0][i]*inv2%mul1,mod1[1],mul1);
ans[i]+=mul(f[1][i]*inv1%mul1,mod1[0],mul1);
ans[i]%=mul1;
}
long long inv3=kasumi(mul1%mod1[2],mod1[2]-2,mod1[2]);
for(int i=0;i<lim;i++)
{
ans[i]=((f[2][i]-ans[i]%mod1[2]+mod1[2])%mod1[2]*inv3%mod1[2]*(mul1%p)%p+ans[i]%p)%p;
}
}
void der(const long long *a,int cnt)
{
int lim=1<<cnt;
for(int i=0;i<lim;i++)
{
de[i]=a[i+1]*(i+1)%p;
}
}
void inte(const long long *a,int cnt)
{
int lim=1<<cnt;
for(int i=lim-1;i>=1;i--)
{
in[i]=a[i-1]*kasumi(i,p-2,p)%p;
}
}
void get_inv(const long long *a,int len)
{
int cnt=0;
for(int i=0;i<len;i++) tmp1[i]=tmp2[i]=inv[i]=0;
inv[0]=kasumi(a[0],p-2,p);
int lim=1;
while(lim<len)
{
cnt++;
lim<<=1;
for(int i=0;i<lim>>1;i++)
{
tmp1[i]=inv[i];
tmp2[i]=a[i];
}
mul1(tmp1,tmp2,cnt);
for(int i=0;i<lim>>1;i++)
{
tmp2[i]=ans[i];
}
mul1(tmp1,tmp2,cnt);
for(int i=0;i<lim>>1;i++)
{
inv[i]=(inv[i]*2ll%p-ans[i]+p)%p;
}
}
}
void get_ln(const long long *a,int len)
{
int lim=1,cnt=0;
while(lim<len) lim<<=1,cnt++;
get_inv(a,lim);
der(a,cnt);
mul1(inv,de,cnt);
inte(ans,cnt);
}
int main()
{
scanf("%d%d",&n,&p);
s[0]=1;
for(int i=1;i<=n;i++)
{
scanf("%lld",&s[i]);
}
get_ln(s,(n+1)*2);
for(int i=1;i<=n;i++) in[i]=in[i]*i%p;
for(int i=1;i<=n;i++)
{
for(int j=i*2;j<=n;j+=i)
{
in[j]=(in[j]-in[i]+p)%p;
}
}
int ans=0;
for(int i=1;i<=n;i++)
{
if(in[i]) ans++;
}
printf("%d\n",ans);
for(int i=1;i<=n;i++) if(in[i]) printf("%d ",i);
}
换上不加黑科技的拆系数FFT就三秒左右了
#include
#define sz 32768
#define N 600030
using namespace std;
long long ans[N];
long long inv[N],tmp1[N],tmp2[N],s[N],ln[N],de[N],in[N];
int r[N],n,m,p;
long long kasumi(long long a,long long b,long long mod)
{
long long aa=1;
while(b)
{
if(b&1) aa=aa*a%mod;
a=a*a%mod;
b>>=1;
}
return aa;
}
const long double pi=std::acos(-1);
struct comp
{
long double r,i;
comp(){}
comp(long double a,long double b):r(a),i(b){}
}f[2][N],g[2][N],t1[N],t2[N],t3[N];
inline comp operator +(const comp a,const comp b) {return comp(a.r+b.r,a.i+b.i);}
inline comp operator -(const comp a,const comp b) {return comp(a.r-b.r,a.i-b.i);}
inline comp operator *(const comp a,const comp b) {return comp(a.r*b.r-a.i*b.i,a.r*b.i+b.r*a.i);}
void FFT(comp *a,int kd,int lim)
{
for(int i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(int mid=1;mid<lim;mid<<=1)
{
comp wn=comp(std::cos(pi/mid),kd*std::sin(pi/mid));
for(int i=0;i<lim;i+=(mid<<1))
{
comp w=comp(1.0,0.0);
for(int j=0;j<mid;j++,w=wn*w)
{
comp x=a[i+j];
comp y=a[i+j+mid]*w;
a[i+j]=x+y;
a[i+j+mid]=x-y;
}
}
}
if(kd==-1)
{
for(int i=0;i<lim;i++)
{
a[i].r/=lim;
}
}
}
void mul1(long long *a,long long *b,int cnt)
{
int lim=1<<cnt;
for(int i=0;i<lim;i++)
{
f[0][i].r=a[i]/sz;
f[0][i].i=0;
f[1][i].r=a[i]%sz;
f[1][i].i=0;
g[0][i].r=b[i]/sz;
g[0][i].i=0;
g[1][i].r=b[i]%sz;
g[1][i].i=0;
ans[i]=0;
}
for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
FFT(f[0],1,lim);FFT(f[1],1,lim);
FFT(g[0],1,lim);FFT(g[1],1,lim);
for(int i=0;i<lim;i++)
{
t1[i]=f[0][i]*g[0][i];
t2[i]=f[0][i]*g[1][i]+g[0][i]*f[1][i];
t3[i]=f[1][i]*g[1][i];
}
FFT(t1,-1,lim);FFT(t2,-1,lim);FFT(t3,-1,lim);
for(int i=0;i<lim;i++)
{
ans[i]=(((long long)(t1[i].r+0.5))%p*sz%p*sz%p+(((long long)(t2[i].r+0.5))%p*sz%p)+(long long)(t3[i].r+0.5)%p)%p;
}
}
void der(const long long *a,int cnt)
{
int lim=1<<cnt;
for(int i=0;i<lim;i++)
{
de[i]=a[i+1]*(i+1)%p;
}
}
void inte(const long long *a,int cnt)
{
int lim=1<<cnt;
for(int i=lim-1;i>=1;i--)
{
in[i]=a[i-1]*kasumi(i,p-2,p)%p;
}
}
void get_inv(const long long *a,int len)
{
int cnt=0;
for(int i=0;i<len;i++) tmp1[i]=tmp2[i]=inv[i]=0;
inv[0]=kasumi(a[0],p-2,p);
int lim=1;
while(lim<len)
{
cnt++;
lim<<=1;
for(int i=0;i<lim>>1;i++)
{
tmp1[i]=inv[i];
tmp2[i]=a[i];
}
mul1(tmp1,tmp2,cnt);
for(int i=0;i<lim>>1;i++)
{
tmp2[i]=ans[i];
}
mul1(tmp1,tmp2,cnt);
for(int i=0;i<lim>>1;i++)
{
inv[i]=(inv[i]*2ll%p-ans[i]+p)%p;
}
}
}
void get_ln(const long long *a,int len)
{
int lim=1,cnt=0;
while(lim<len) lim<<=1,cnt++;
get_inv(a,lim);
der(a,cnt);
mul1(inv,de,cnt);
inte(ans,cnt);
}
int main()
{
scanf("%d%d",&n,&p);
s[0]=1;
for(int i=1;i<=n;i++)
{
scanf("%lld",&s[i]);
}
get_ln(s,(n+1)*2);
for(int i=1;i<=n;i++) in[i]=in[i]*i%p;
for(int i=1;i<=n;i++)
{
for(int j=i*2;j<=n;j+=i)
{
in[j]=(in[j]-in[i]+p)%p;
}
}
int ans=0;
for(int i=1;i<=n;i++)
{
if(in[i]) ans++;
}
printf("%d\n",ans);
for(int i=1;i<=n;i++) if(in[i]) printf("%d ",i);
}