bzoj5004 开锁魔法II【概率dp&&生成函数】

题目大意:

给一个n的点的图,每个点只有一条出边和入边(组成了若干环),现在从中选k个点,问每个环至少选中一个点的概率。
n、k<=300;

解题思路:

大概的思路是求可行方案数除以总方案数 (nk) ( n k )

f[i][j] f [ i ] [ j ] 表示前i个环选了j个点的方案数,则有 f[i][j]=f[i1][js](size[i]s) f [ i ] [ j ] = f [ i − 1 ] [ j − s ] ∗ ( s i z e [ i ] s ) size[i] s i z e [ i ] 为第i个环的大小。

其实如果可以也能用生成函数,这样n、k可以开到100000.
考虑 Fi F i 表示第i个环选不同的点的方案的生成函数,则 Fi=j=1(size[i]j)xj F i = ∑ j = 1 ( s i z e [ i ] j ) x j ,将所有 Fi F i 相乘后的 xk x k 的系数即为方案数,使用分治NTT即可。

普通代码:

#include
#define ll long long
using namespace std;

int getint()
{
    int i=0,f=1;char c;
    for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
    if(c=='-')f=-1,c=getchar();
    for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
    return i*f;
}

const int N=305;
int T,n,m,k,p[N],size[N];
double c[N][N],f[N][N],ans;
bool vis[N];

int dfs(int x)
{
    if(vis[x])return 0;
    vis[x]=1;return dfs(p[x])+1;
}

int main()
{
    //freopen("lx.in","r",stdin);
    for(int i=0;i0]=1;
        for(int j=1;j<=i;j++)
            c[i][j]=c[i-1][j-1]+c[i-1][j];
    }
    T=getint();
    while(T--)
    {
        n=getint(),k=getint();m=0;
        for(int i=1;i<=n;i++)p[i]=getint(),vis[i]=0;
        for(int i=1;i<=n;i++)
            if(!vis[i])size[++m]=dfs(i);
        for(int i=0;i<=m;i++)
            for(int j=0;j<=k;j++)
                f[i][j]=0;
        f[0][0]=1;
        for(int i=1;i<=m;i++)
            for(int j=1;j<=k;j++)
                for(int s=1;s<=j&&s<=size[i];s++)
                    f[i][j]=f[i][j]+f[i-1][j-s]*c[size[i]][s];
        ans=f[m][k]/c[n][k];
        printf("%0.9lf\n",ans);
    }
    return 0;
}

生成函数:

#include
#define ll long long
using namespace std;

int getint()
{
    int i=0,f=1;char c;
    for(c=getchar();(c!='-')&&(c<'0'||c>'9');c=getchar());
    if(c=='-')f=-1,c=getchar();
    for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
    return i*f;
}

const int N=200005,mod=998244353,g=3;
int T,n,m,k,p[N],size[N];
int fac[N],fac_inv[N],w[N],w_inv[N],pos[N*10],A[N*10],B[N*10];
bool vis[N];
vector<int>a[N];

int C(int x,int y){return 1ll*fac[x]*fac_inv[y]%mod*fac_inv[x-y]%mod;}

int Pow(int x,int y)
{
    int res=1;
    for(;y;y>>=1,x=1ll*x*x%mod)
        if(y&1)res=(ll)res*x%mod;
    return res;
}

int dfs(int x)
{
    if(vis[x])return 0;
    vis[x]=1;return dfs(p[x])+1;
}

void rev(int k)
{
    for(int i=1;ipos[i]=(i&1)?pos[i>>1]>>1|(k>>1):pos[i>>1]>>1;
}

void NTT(int *f,int len,int on)
{
    for(int i=0;iif(i<pos[i])swap(f[i],f[pos[i]]);
    for(int i=1,num=1;i1,num++)
    {
        int wn=(on==1?w[num]:w_inv[num]);
        for(int j=0;j1))
        {
            int wi=1;
            for(int k=j;kint u=f[k],v=(ll)f[k+i]*wi%mod;
                f[k]=(u+v)%mod,f[k+i]=(u-v+mod)%mod;
                wi=(ll)wi*wn%mod;
            }
        }
    }
    if(on==-1)
        for(int i=0;i*w_inv[0]%mod;
}

int multi(int *A,int *B,int len)
{
    w_inv[0]=Pow(len,mod-2);
    rev(len);
    NTT(A,len,1),NTT(B,len,1);
    for(int i=0;i*B[i]%mod;
    NTT(A,len,-1);--len;
    while(!A[len])--len;
    return len;
}

int solve(int l,int r)
{
    if(l==r)return a[l].size()-1;
    int mid=l+r>>1;
    int l1=solve(l,mid),l2=solve(mid+1,r);
    int len=1;
    while(len<=l1+l2)len<<=1;
    for(int i=0;i<=l1;i++)A[i]=a[l][i];
    for(int i=l1+1;i0;
    for(int i=0;i<=l2;i++)B[i]=a[mid+1][i];
    for(int i=l2+1;i0;
    a[l].clear(),a[mid+1].clear();
    len=multi(A,B,len);
    for(int i=0;i<=len;i++)a[l].push_back(A[i]);
    return len;
}

int main()
{
    //freopen("lx.in","r",stdin);
    fac[0]=1;
    for(int i=1;i1]*i%mod;
    fac_inv[N-1]=Pow(fac[N-1],mod-2);
    for(int i=N-2;i>=0;i--)fac_inv[i]=(ll)fac_inv[i+1]*(i+1)%mod;
    int len=1,num=0;
    while(len<(N<<1))len<<=1,w[++num]=Pow(g,(mod-1)/len),w_inv[num]=Pow(w[num],mod-2);
    T=getint();
    while(T--)
    {
        n=getint(),k=getint(),m=0;
        for(int i=1;i<=n;i++)p[i]=getint(),vis[i]=0;
        for(int i=1;i<=n;i++)if(!vis[i])size[++m]=dfs(i);
        for(int i=1;i<=m;i++)
        {
            a[i].push_back(0);
            for(int j=1;j<=size[i];j++)a[i].push_back(C(size[i],j));
        }
        solve(1,m);a[1].clear();
        printf("%d\n",(ll)a[1][k]*Pow(C(n,k),mod-2)%mod);
    }
    return 0;
}

你可能感兴趣的:(概率dp,多项式运算,生成函数,bzoj)