loj#2541. 「PKUWC 2018」猎人杀【容斥+概率dp+生成函数+分治FFT】

传送门

解题思路:

思路巧妙……

原题中每轮概率都在变化,一脸不可做,但注意到对问题的转化:
我们杀人后将其打上标记,但还是可以以他为目标重复选,直到选到一个未打标记的人。
这和原问题等价,而且这样每轮选中每人的概率都不变。

考虑容斥,枚举强制在1号后面死的人,即1号至少在这些人前面,令 A=wi A = ∑ w i S S 为枚举到的人的 wi w i 之和, t t 为人数,则

ans=(1)ti=0(1S+w1A)iw1A a n s = ( − 1 ) t ∑ i = 0 ∞ ( 1 − S + w 1 A ) i w 1 A

T=i=0(1S+w1A)i T = ∑ i = 0 ∞ ( 1 − S + w 1 A ) i
(1S+w1A)T=i=1(1S+w1A)i ( 1 − S + w 1 A ) T = ∑ i = 1 ∞ ( 1 − S + w 1 A ) i
S+w1AT=(1S+w1A)0=1 S + w 1 A T = ( 1 − S + w 1 A ) 0 = 1
T=AS+w1 T = A S + w 1

所以
ans=(1)tw1S+w1 a n s = ( − 1 ) t w 1 S + w 1

考虑直接算出每个 S S 的容斥系数和 fS f S ,当做生成函数算,即是
F(x)=i=2n(1xwi) F ( x ) = ∏ i = 2 n ( 1 − x w i )

分治FFT即可。

q=wi q = ∑ w i ,那么复杂度为 O(qlog2q) O ( q l o g 2 q )

#include
#define ll long long
#define pb push_back
using namespace std;
int getint()
{
    ll i=0,f=1;char c;
    for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
    if(c=='-')c=getchar(),f=-1;
    for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
    return i*f;
}
const int N=100005,mod=998244353,g=3;
int n,tot,a[N],pos[N<<3];
ll w[65],invw[65],f1[N<<3],f2[N<<3];
ll Pow(ll x,int y)
{
    ll res=1;
    for(;y;y>>=1,x=x*x%mod)
        if(y&1)res=res*x%mod;
    return res;
}
void NTT(ll *f,int len,int on)
{
    for(int i=1;i1)?pos[i>>1]>>1|(len>>1):pos[i>>1]>>1;
    for(int i=1;iif(ifor(int i=1,num=1;i1,num++)
    {
        ll wi=on==1?w[num]:invw[num];
        for(int j=0;j1))
        {
            ll wn=1;
            for(int k=j;kif(on==-1)for(int i=0,inv=Pow(len,mod-2);ivector solve(int l,int r)
{
    vectorA,B,C;
    if(l==r)
    {
        C.resize(a[l]+1);
        C[0]=1,C[a[l]]=-1;
        return C;
    }
    int mid=l+r>>1;
    A=solve(l,mid),B=solve(mid+1,r);
    int l1=A.size()-1,l2=B.size()-1,len=1;
    while(len<=l1+l2)len<<=1;
    for(int i=0;i<=l1;i++)f1[i]=A[i];
    for(int i=l1+1;i<=len;i++)f1[i]=0;
    for(int i=0;i<=l2;i++)f2[i]=B[i];
    for(int i=l2+1;i<=len;i++)f2[i]=0;
    NTT(f1,len,1),NTT(f2,len,1);
    for(int i=0;i1);
    for(int i=0;i<=l1+l2;i++)C.pb(f1[i]);
    return C;
}
int main()
{
    //freopen("lx.in","r",stdin);
    n=getint();
    for(int i=1;i<=n;i++)a[i]=getint(),tot+=a[i]*(i>1);
    int len=1,num=0;
    while(len<=tot)
    {
        len<<=1;
        w[++num]=Pow(g,(mod-1)/len);
        invw[num]=Pow(w[num],mod-2);
    }
    vectorA=solve(2,n);
    ll ans=0;
    for(int S=0;S<=tot;S++)
        ans=(ans+A[S]*a[1]%mod*Pow(S+a[1],mod-2))%mod;
    cout<'\n';
}


你可能感兴趣的:(容斥原理,多项式运算,概率dp)