【洛谷】P4705 玩游戏-生成函数

传送门:luoguP4705


题解

t t t次价值的期望: 1 n m ∑ i = 1 n ∑ j = 1 m ∑ ( a i + b j ) t \dfrac{1}{nm}\sum\limits_{i=1}^n\sum\limits_{j=1}^m\sum(a_i+b_j)^t nm1i=1nj=1m(ai+bj)t

二项式定理展开一下:
t ! n m ∑ k = 0 t 1 k ! ∑ i = 1 n a i k 1 ( t − k ) ! ∑ j = 1 m b j t − k \dfrac{t!}{nm}\sum\limits_{k=0}^t\dfrac {1}{k!}\sum\limits_{i=1}^na_i^k\dfrac{1}{(t-k)!}\sum\limits_{j=1}^mb_j^{t-k} nmt!k=0tk!1i=1naik(tk)!1j=1mbjtk

所以需要构造的生成函数 F ( x ) , G ( x ) F(x),G(x) F(x),G(x)的第 i i i项系数分别为 ∑ j = 1 n a j i , ∑ j = 1 m b j i \sum\limits_{j=1}^na_j^i,\sum\limits_{j=1}^mb_j^i j=1naji,j=1mbji

单独考虑 a j a_j aj F F F的贡献: 1 + a i x + a i 2 x + . . . 1+a_ix+a_i^2x+... 1+aix+ai2x+...,生成函数为 1 1 − a j x \dfrac1{1-a_jx} 1ajx1
所以 F ( x ) = ∑ i = 1 n 1 1 − a i x F(x)=\sum\limits_{i=1}^n\dfrac1{1-a_ix} F(x)=i=1n1aix1

1 − a i x 1-a_ix 1aix在分母不好看,积分转成 l n ln ln,得到 F ( x ) = ∑ i = 1 n ln ⁡ ′ ( 1 − a i x ) F(x)=\sum\limits_{i=1}^n\ln'(1-a_ix) F(x)=i=1nln(1aix)

还是不好看。。。再把求导转到整个函数外面:

F ′ ( x ) = ( ln ⁡ ( ∏ i = 1 n ( 1 − a i x ) ) ) ′ F'(x)=(\ln(\prod\limits_{i=1}^n(1-a_ix)))' F(x)=(ln(i=1n(1aix)))

F = − x F ′ + n F=-xF'+n F=xF+n

G G G同理,然后套多项式模板即可。


代码

#include
#define mem(f) memset(f,0,sizeof(f))
using namespace std;
typedef long long ll;
const int N=1e5+10,M=2e6+10,mod=998244353,gen=3;

int n,m,bs,ivg,a[M],b[M],f[M],g[M];
int rv[M],s[20][M],frac[N],nv[N];

char cp,OS[100];
inline void rd(int &x)
{
    cp=getchar();x=0;
    for(;!isdigit(cp);cp=getchar());
    for(;isdigit(cp);cp=getchar()) x=(x<<3)+(x<<1)+(cp^48);
}

inline void ot(int x)
{
	int re=0;OS[0]='\n';
	for(;(!re)||x;x/=10) OS[++re]='0'+x%10;
	for(;~re;--re) putchar(OS[re]);
}

inline int ad(int x,int y){x+=y;return x>=mod?x-mod:x;}
inline int dc(int x,int y){x-=y;return x<0?x+mod:x;}

inline int fp(int x,int y)
{
    int re=1;
    for(;y;y>>=1,x=(ll)x*x%mod)
      if(y&1) re=(ll)re*x%mod;
    return re;
}

inline void ntt(int *e,int pr,int n)
{
    int i,j,k,ix,iy,ori,pd,g=pr?gen:ivg;
    for(i=1;i<n;++i) if(i<rv[i]) swap(e[i],e[rv[i]]);
    for(i=1;i<n;i<<=1){
        ori=fp(g,(mod-1)/(i<<1));
        for(j=0;j<n;j+=(i<<1))
            for(pd=1,k=0;k<i;++k,pd=(ll)pd*ori%mod){
                ix=e[j+k];iy=(ll)pd*e[j+k+i]%mod;
                e[j+k]=ad(ix,iy);e[j+k+i]=dc(ix,iy);
            }
    }
    if(pr) return;
    g=fp(n,mod-2);for(i=0;i<n;++i) e[i]=(ll)e[i]*g%mod;
}

inline void init(int n,int &len)
{
	int i,L=0;n<<=1;
	for(len=1;len<n;len<<=1) L++;
	for(i=1;i<len;++i) rv[i]=(rv[i>>1]>>1)|((i&1)<<(L-1));
}

inline void mul(int *f,int *g,int len)
{
	ntt(f,1,len);ntt(g,1,len);
	for(int i=0;i<len;++i) f[i]=(ll)f[i]*g[i]%mod;
	ntt(f,0,len);
}

inline void cal(int dep,int l,int r)
{
    if(l>r) return;
    if(l==r) {s[dep][0]=1;s[dep][1]=(mod-a[l])%mod;return;}
    int mid=(l+r)>>1,i,len,lim=mid-l+1;cal(dep+1,l,mid);
    for(i=0;i<=lim;++i) s[dep][i]=s[dep+1][i];cal(dep+1,mid+1,r);
	init(r-l+2,len);
	fill(s[dep]+lim+1,s[dep]+len,0);
    fill(s[dep+1]+r-mid+1,s[dep+1]+len,0);
    mul(s[dep],s[dep+1],len);
}

inline void gtder(int n,int *f,int *g)
{for(int i=1;i<n;++i) f[i-1]=(ll)g[i]*i%mod;f[n-1]=0;}

void gtinv(int n,int *f,int *g)
{
    if(n==1) {f[0]=fp(g[0],mod-2);return;}
    gtinv((n+1)>>1,f,g);int i,j,len;static int cont[M];
    for(i=0;i<n;++i) cont[i]=g[i];init(n,len);
	fill(cont+n,cont+len,0);ntt(f,1,len);ntt(cont,1,len);
	for(i=0;i<len;++i) f[i]=(ll)f[i]*dc(2,(ll)cont[i]*f[i]%mod)%mod;
	ntt(f,0,len);fill(f+n,f+len,0);
}

inline void gtln(int n,int *f,int *g)
{
    int i,j,len=1,L=0;static int der[M],nv[M];
    for(;len<n+n;len<<=1) L++;fill(der,der+len,0);fill(nv,nv+len,0);
    gtder(n,der,g);gtinv(n,nv,g);
    for(i=1;i<len;++i) rv[i]=(rv[i>>1]>>1)|((i&1)<<(L-1));
    mul(der,nv,len);
	for(i=0;i<n;++i) f[i]=der[i];fill(f+n,f+len,0);
}

int main(){
    int i,j,len,rev;
    rd(n);rd(m);ivg=fp(gen,mod-2);rev=fp((ll)n*m%mod,mod-2);
    for(i=1;i<=n;++i) rd(f[i]);for(i=1;i<=m;++i) rd(g[i]);
    
    frac[0]=frac[1]=nv[0]=nv[1]=1;
    for(rd(bs),i=2;i<=bs;++i)
      frac[i]=(ll)frac[i-1]*i%mod,nv[i]=(ll)(mod-mod/i)*nv[mod%i]%mod;
    for(i=2;i<=bs;++i) nv[i]=(ll)nv[i-1]*nv[i]%mod;

    len=max(bs,max(n,m))+1;
    memcpy(a,f,(n+2)<<2);cal(0,1,n);f[0]=n;
    gtln(len,b+1,s[0]);for(i=bs;i;--i) f[i]=(ll)nv[i]*(mod-b[i])%mod;
    
	mem(s[0]);mem(b);
    memcpy(a,g,(m+2)<<2);cal(0,1,m);g[0]=m;
    gtln(len,b+1,s[0]);for(i=bs;i;--i) g[i]=(ll)nv[i]*(mod-b[i])%mod;
    
    init(bs+1,len);
	fill(f+bs+1,f+len,0);fill(g+bs+1,g+len,0);////记得清空,不然只有66pts 
	mul(f,g,len);
    for(i=1;i<=bs;++i) ot((ll)frac[i]*f[i]%mod*(ll)rev%mod);
    return 0;
}

你可能感兴趣的:(---多项式---,FFT,NTT,---组合数学---)