洛谷 P3784 [SDOI2017]遗忘的集合(任意模数NTT+多项式求ln)

题目链接
emmm,只要做过付公主的背包,这题就还算清真了
首先还是对生成函数求逆,我们知道 l n ( f ( x ) ) = ∑ i = 1 n a i x ∗ [ x   % a i   = 0 ] ln(f(x))=\sum_{i=1}^{n}{\frac{a_i}{x}*[x\, \%a_i\,=0]} ln(f(x))=i=1nxai[x%ai=0]
假设我们已经求出了这个逆
感受一下上面那个式子,显然求逆以后第一个有值的位置对应的数字一定是被选的,这个应该很好理解。
然后这个数字i会对他的每个倍数j产生 i j \frac{i}{j} ji的贡献
把他从对应的位置减掉就可以了,然后显然下一个要选的肯定是他之后第一个有值的位置,如此搞下去可以调和级数求解。
最后把有数字的位置输出来就可以了
这个用贪心的思路去想一定是字典序最小的最优解。

然后每次都被卡常……
三模NTT要整整十秒

#include
#define gg 3
#define N 600030
using namespace std;

long long ans[N],f[3][N],g[3][N],mod1[]={998244353,469762049,1004535809};
long long inv[N],tmp1[N],tmp2[N],s[N],ln[N],de[N],in[N]; 
int r[N],n,m,p;

inline long long mul(long long a,long long b,long long mod)
{
    long long res=a*b-(long long)((long double)a*b/mod+0.5)*mod;
    return res<0?res+mod:res;
}

long long kasumi(long long a,long long b,long long mod)
{
    long long ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}

void NTT(long long *a,int kd,int mod,int lim)
{
    for(int i=0;i<lim;i++)
    {
        if(i<r[i]) swap(a[i],a[r[i]]);
    }
    for(int mid=1;mid<lim;mid<<=1)
    {
        long long wn=kasumi(gg,(mod-1)/(mid<<1),mod);
        if(kd) wn=kasumi(wn,mod-2,mod);
        for(int i=0;i<lim;i+=mid<<1)
        {
            long long w=1;
            for(int j=0;j<mid;j++,w=wn*w%mod)
            {
                long long x=a[i+j];
                long long y=a[i+j+mid]*w%mod;
                a[i+j]=(x+y)%mod;
                a[i+j+mid]=(x-y+mod)%mod;
            }
        }
    }
    if(kd)
    {
        int inv=kasumi(lim,mod-2,mod);
        for(int i=0;i<lim;i++) a[i]=a[i]*inv%mod;
    }
}

void mul1(long long *a,long long *b,int cnt)
{
    int lim=1<<cnt;
    for(int i=0;i<lim;i++)
    {
        f[0][i]=f[1][i]=f[2][i]=a[i];
        g[0][i]=g[1][i]=g[2][i]=b[i];
        ans[i]=0;
    }
    for(int i=0;i<lim;i++)
    {
        r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
    }
    for(int i=0;i<=2;i++)
    {
        NTT(f[i],0,mod1[i],lim);NTT(g[i],0,mod1[i],lim);
        for(int j=0;j<lim;j++)
        {
            f[i][j]=f[i][j]*g[i][j]%mod1[i];
        }
        NTT(f[i],1,mod1[i],lim);
    }
    long long inv1=kasumi(mod1[0],mod1[1]-2,mod1[1]);
    long long inv2=kasumi(mod1[1],mod1[0]-2,mod1[0]);
    long long mul1=mod1[0]*mod1[1];
    for(int i=0;i<lim;i++)
    {
        ans[i]+=mul(f[0][i]*inv2%mul1,mod1[1],mul1);
        ans[i]+=mul(f[1][i]*inv1%mul1,mod1[0],mul1);
        ans[i]%=mul1;
    }
    long long inv3=kasumi(mul1%mod1[2],mod1[2]-2,mod1[2]);
    for(int i=0;i<lim;i++)
    {
        ans[i]=((f[2][i]-ans[i]%mod1[2]+mod1[2])%mod1[2]*inv3%mod1[2]*(mul1%p)%p+ans[i]%p)%p;
    }
}

void der(const long long *a,int cnt)
{
    int lim=1<<cnt;
    for(int i=0;i<lim;i++)
    {
        de[i]=a[i+1]*(i+1)%p;
    } 
}

void inte(const long long *a,int cnt)
{
    int lim=1<<cnt;
    for(int i=lim-1;i>=1;i--)
    {
        in[i]=a[i-1]*kasumi(i,p-2,p)%p;
    }
}

void get_inv(const long long *a,int len)
{

    int cnt=0;
    for(int i=0;i<len;i++) tmp1[i]=tmp2[i]=inv[i]=0;
    inv[0]=kasumi(a[0],p-2,p);
    int lim=1;
    while(lim<len)
    {
        cnt++;
        lim<<=1;
        for(int i=0;i<lim>>1;i++)
        {
            tmp1[i]=inv[i];
            tmp2[i]=a[i];
        }
        mul1(tmp1,tmp2,cnt);
        for(int i=0;i<lim>>1;i++)
        {
            tmp2[i]=ans[i];
        }
        mul1(tmp1,tmp2,cnt);
        for(int i=0;i<lim>>1;i++)
        {
            inv[i]=(inv[i]*2ll%p-ans[i]+p)%p;
        }
    }
}

void get_ln(const long long *a,int len)
{
    int lim=1,cnt=0;
    while(lim<len) lim<<=1,cnt++;
    get_inv(a,lim);
    der(a,cnt);
    mul1(inv,de,cnt);
    inte(ans,cnt);
}

int main()
{
    scanf("%d%d",&n,&p);
    s[0]=1;
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&s[i]);
    }
    get_ln(s,(n+1)*2);
    for(int i=1;i<=n;i++) in[i]=in[i]*i%p;
    for(int i=1;i<=n;i++)
    {
        for(int j=i*2;j<=n;j+=i)
        {
            in[j]=(in[j]-in[i]+p)%p;
        }
    }
    int ans=0;
    for(int i=1;i<=n;i++)
    {
        if(in[i]) ans++;
    }
    printf("%d\n",ans);
    for(int i=1;i<=n;i++) if(in[i]) printf("%d ",i);
}

换上不加黑科技的拆系数FFT就三秒左右了

#include
#define sz 32768
#define N 600030
using namespace std;

long long ans[N];
long long inv[N],tmp1[N],tmp2[N],s[N],ln[N],de[N],in[N]; 
int r[N],n,m,p;

long long kasumi(long long a,long long b,long long mod)
{
    long long aa=1;
    while(b)
    {
        if(b&1) aa=aa*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return aa;
}

const long double pi=std::acos(-1);

struct comp
{
    long double r,i;
    comp(){}
    comp(long double a,long double b):r(a),i(b){}
}f[2][N],g[2][N],t1[N],t2[N],t3[N];

inline comp operator +(const comp a,const comp b) {return comp(a.r+b.r,a.i+b.i);}

inline comp operator -(const comp a,const comp b) {return comp(a.r-b.r,a.i-b.i);}

inline comp operator *(const comp a,const comp b) {return comp(a.r*b.r-a.i*b.i,a.r*b.i+b.r*a.i);}

void FFT(comp *a,int kd,int lim)
{
    for(int i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
    for(int mid=1;mid<lim;mid<<=1)
    {
        comp wn=comp(std::cos(pi/mid),kd*std::sin(pi/mid));
        for(int i=0;i<lim;i+=(mid<<1))
        {
            comp w=comp(1.0,0.0);
            for(int j=0;j<mid;j++,w=wn*w)
            {
                comp x=a[i+j];
                comp y=a[i+j+mid]*w;
                a[i+j]=x+y;
                a[i+j+mid]=x-y;
            }
        }
    }
    if(kd==-1)
    {
        for(int i=0;i<lim;i++)
        {
            a[i].r/=lim;
        }
    }
}


void mul1(long long *a,long long *b,int cnt)
{
    int lim=1<<cnt;
    for(int i=0;i<lim;i++)
    {
        f[0][i].r=a[i]/sz;
        f[0][i].i=0;
        f[1][i].r=a[i]%sz;
        f[1][i].i=0;
        g[0][i].r=b[i]/sz;
        g[0][i].i=0;
        g[1][i].r=b[i]%sz;
        g[1][i].i=0;
        ans[i]=0;
    }
    for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(cnt-1));
    FFT(f[0],1,lim);FFT(f[1],1,lim);
    FFT(g[0],1,lim);FFT(g[1],1,lim);
    for(int i=0;i<lim;i++)
    {
        t1[i]=f[0][i]*g[0][i];
        t2[i]=f[0][i]*g[1][i]+g[0][i]*f[1][i];
        t3[i]=f[1][i]*g[1][i];
    }
    FFT(t1,-1,lim);FFT(t2,-1,lim);FFT(t3,-1,lim);
    for(int i=0;i<lim;i++)
    {
        ans[i]=(((long long)(t1[i].r+0.5))%p*sz%p*sz%p+(((long long)(t2[i].r+0.5))%p*sz%p)+(long long)(t3[i].r+0.5)%p)%p;
    }
}

void der(const long long *a,int cnt)
{
    int lim=1<<cnt;
    for(int i=0;i<lim;i++)
    {
        de[i]=a[i+1]*(i+1)%p;
    } 
}

void inte(const long long *a,int cnt)
{
    int lim=1<<cnt;
    for(int i=lim-1;i>=1;i--)
    {
        in[i]=a[i-1]*kasumi(i,p-2,p)%p;
    }
}

void get_inv(const long long *a,int len)
{

    int cnt=0;
    for(int i=0;i<len;i++) tmp1[i]=tmp2[i]=inv[i]=0;
    inv[0]=kasumi(a[0],p-2,p);
    int lim=1;
    while(lim<len)
    {
        cnt++;
        lim<<=1;
        for(int i=0;i<lim>>1;i++)
        {
            tmp1[i]=inv[i];
            tmp2[i]=a[i];
        }
        mul1(tmp1,tmp2,cnt);
        for(int i=0;i<lim>>1;i++)
        {
            tmp2[i]=ans[i];
        }
        mul1(tmp1,tmp2,cnt);
        for(int i=0;i<lim>>1;i++)
        {
            inv[i]=(inv[i]*2ll%p-ans[i]+p)%p;
        }
    }
}

void get_ln(const long long *a,int len)
{
    int lim=1,cnt=0;
    while(lim<len) lim<<=1,cnt++;
    get_inv(a,lim);
    der(a,cnt);
    mul1(inv,de,cnt);
    inte(ans,cnt);
}

int main()
{
    scanf("%d%d",&n,&p);
    s[0]=1;
    for(int i=1;i<=n;i++)
    {
        scanf("%lld",&s[i]);
    }
    get_ln(s,(n+1)*2);
    for(int i=1;i<=n;i++) in[i]=in[i]*i%p;
    for(int i=1;i<=n;i++)
    {
        for(int j=i*2;j<=n;j+=i)
        {
            in[j]=(in[j]-in[i]+p)%p;
        }
    }
    int ans=0;
    for(int i=1;i<=n;i++)
    {
        if(in[i]) ans++;
    }
    printf("%d\n",ans);
    for(int i=1;i<=n;i++) if(in[i]) printf("%d ",i);
}

你可能感兴趣的:(洛谷,FFT/NTT,多项式全家桶,任意模数NTT)