codeforces 755G 多项式

题意:n个球,分成一些组,一个组里可以有1个球或相邻的两个球。一个球只能在一个组里或不在任何组里。求组数为1,2,…m时的方案数。
正常的递推式: f[i][j] 表示i个球分成j组的方案数。
f[i][j]=f[i1][j]+f[i1][j1]+f[i2][j1]

那么 f[n] 的生成函数满足:
fn(x)=fn1(x)+xfn1(x)+xfn2(x)
=(x+1)fn1(x)+xfn2(x)

然后设 fi=C1T1(x)i+C2T2(x)i

其中 T0(x),T1(x) 为方程 T2(x)=(x+1)T(x)+x 的两个根

解得: T1(x)=1+x+(x2+6x+1)2 , T2(x)=1+x(x2+6x+1)2

f0=1,f1=1+x 代入解出 C1,C2

然后通分:
fn(x)=(1+x+(x2+6x+1)2)n+1(1+x(x2+6x+1)2)n+1(x2+6x+1)

然后 (1+x(x2+6x+1)2)n+1 最低项次数大于n可以直接舍掉。
然后就是多项式操作了。

#include 
using namespace std;
#define N (1<<16)+10
#define ll long long
#define mod 998244353
const int inv2=499122177;
int n,m,len;
int tmp[N],sq[N],inv_sq[N],rt1[N],ln1[N],a[N],ans[N],inv[N];
int test1[N],test2[N],test3[N];
int qpow(int x,int y)
{
    int ret=1;
    while(y)
    {
        if(y&1)ret=(ll)ret*x%mod;
        x=(ll)x*x%mod;y>>=1;
    }
    return ret;
}
void NTT(int *a,int len,int type)
{
    for(int i=0,t=0;i<len;i++)
    {
        if(ifor(int j=len>>1;(t^=j)>=1);
    }
    for(int i=2;i<=len;i<<=1)
    {
        int wn=qpow(3,(mod-1)/i);
        for(int j=0;j<len;j+=i)
        {
            int w=1,t;
            for(int k=0;k>1;k++,w=(ll)w*wn%mod)
            {
                t=(ll)a[j+k+(i>>1)]*w%mod;
                a[j+k+(i>>1)]=(a[j+k]-t+mod)%mod;
                a[j+k]=(a[j+k]+t)%mod;
            }
        }
    }
    if(type==-1)
    {
        for(int i=1;i<len>>1;i++)swap(a[i],a[len-i]);
        int t=qpow(len,mod-2);
        for(int i=0;i<len;i++)a[i]=(ll)a[i]*t%mod;
    }
}
////////////////
void test_root(int *a,int len)
{
    memset(test1,0,sizeof(test1));
    for(int i=0;i<len;i++)
        test1[i]=a[i];
    NTT(test1,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test1[i]=(ll)test1[i]*test1[i]%mod;
    NTT(test1,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test1[i]);
    puts("");
}
void test_inv(int *a,int *b,int len)
{
    memset(test1,0,sizeof(test1));
    memset(test2,0,sizeof(test2));
    for(int i=0;i<len;i++)
        test1[i]=a[i],test2[i]=b[i];
    NTT(test1,len<<1,1);
    NTT(test2,len<<1,1);
    for(int i=0;i<len<<1;i++)
        test3[i]=(ll)test1[i]*test2[i]%mod;
    NTT(test3,len<<1,-1);
    for(int i=0;i<len;i++)
        printf("#%d ",test3[i]);
    puts("");
}
////////////////
void get_inv(int *a,int *b,int len)
{
    static int tmp[N];
    if(len==1)
    {
        b[0]=qpow(a[0],mod-2);
        return;
    }
    get_inv(a,b,len>>1);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*(2-(ll)b[i]*tmp[i]%mod+mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_root(int *a,int *b,int len)
{
    static int invb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_root(a,b,len>>1);
    for(int i=0;i<len<<1;i++)invb[i]=0;
    get_inv(b,invb,len);
    for(int i=0;i<len;i++)tmp[i]=a[i];
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(tmp,len<<1,1);
    NTT(b,len<<1,1);
    NTT(invb,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)inv2*(b[i]+(ll)tmp[i]*invb[i]%mod)%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_ln(int *a,int *b,int len)
{
    static int inva[N],a1[N];
    for(int i=0;i<len<<1;i++)inva[i]=0;
    get_inv(a,inva,len);
    for(int i=0;i<len;i++)a1[i]=(ll)(i+1)*a[i+1]%mod;
    for(int i=len;i<len<<1;i++)a1[i]=0;
    NTT(a1,len<<1,1);
    NTT(inva,len<<1,1);
    for(int i=0;i<len<<1;i++)a1[i]=(ll)a1[i]*inva[i]%mod;
    NTT(a1,len<<1,-1);
    b[0]=0;
    for(int i=1;i<len;i++)
        b[i]=(ll)a1[i-1]*inv[i]%mod;
    for(int i=len;i<len<<1;i++)b[i]=0;
}
void get_exp(int *a,int *b,int len)
{
    static int lnb[N],tmp[N];
    if(len==1){b[0]=1;return;}
    get_exp(a,b,len>>1);
    for(int i=0;i<len<<1;i++)lnb[i]=0;
    get_ln(b,lnb,len);
    for(int i=0;i<len;i++)tmp[i]=(a[i]-lnb[i]+mod)%mod;
    tmp[0]++;
    for(int i=len;i<len<<1;i++)tmp[i]=0;
    NTT(b,len<<1,1);
    NTT(tmp,len<<1,1);
    for(int i=0;i<len<<1;i++)
        b[i]=(ll)b[i]*tmp[i]%mod;
    NTT(b,len<<1,-1);
    for(int i=len;i<len<<1;i++)b[i]=0;
}
int main()
{
    //freopen("tt.in","r",stdin);
    scanf("%d%d",&n,&m);
    for(len=1;len<=m;len<<=1);
    for(int i=1;i<len;i++)inv[i]=qpow(i,mod-2);
    tmp[0]=1;tmp[1]=6;tmp[2]=1;
    get_root(tmp,sq,len);
    get_inv(sq,inv_sq,len);
    rt1[0]=rt1[1]=1;
    for(int i=0;i<len;i++)
        rt1[i]=(ll)(rt1[i]+sq[i])%mod*inv2%mod;
    get_ln(rt1,ln1,len);
    for(int i=0;i<len;i++)
        ln1[i]=(ll)ln1[i]*(n+1)%mod;
    get_exp(ln1,a,len);
    NTT(inv_sq,len<<1,1);
    NTT(a,len<<1,1);
    for(int i=0;i<len<<1;i++)
        ans[i]=(ll)a[i]*inv_sq[i]%mod;
    NTT(ans,len<<1,-1);
    for(int i=1;i<=m;i++)
        printf("%d ",i>n ? 0:ans[i]);
    return 0;
}

你可能感兴趣的:(多项式)