[JSOI2019]神经网络

题目

ddy讲的牛逼题。

由于树和树之间是完全图,所以我们要做的就是把树拆成一堆路径,之后把这些路径合并起来,就能得到哈密顿回路了;

所以首先对每棵树求一个链划分,设\(dp_{i,j,0/1/2}\)表示在子树\(i\)中划分出了\(j\)条链,\(0\)表示点\(i\)已经划分好了,\(1\)表示点\(i\)自己在一条链中,\(2\)表示点\(i\)在一条还能继续加点的长度大于\(1\)的链中,注意到长度大于\(1\)的链有两个方向,计算贡献的时候需要乘\(2\);大力树上背包即可,复杂度\(O(n^2)\)

再来考虑合并的问题,现在我们把问题转化成了一共有\(n\)种颜色,每种颜色的有\(a_i\)个可区分的点,求使得不存在两个相邻同色点的环排列的个数;

首先考虑不是环的情况,考虑容斥,我们枚举一下第\(i\)中颜色至多分成\(b_i\)段,大概是

\[\sum \frac{(\sum b_i)!}{\prod b_i!}\prod_{i=1}^n (-1)^{a_i-b_i}\binom{a_i-1}{b_i-1}a_i!\]

\((-1)^{a_i-b_i}\)是容斥系数,分成\(b_i\)段就用组合数插板一下;由于乘上\(a_i!\),所以所有颜色排列在一起的时候需要保证相对顺序,就是一个有重复元素的排列问题。

不难发现\(\frac{1}{b_i!}\)可以拆到里面来,于是不难想到搞一个\(\rm EGF\)

于是某一棵树的\(\rm EGF\)就是

\[\sum_{i=1}^nf_ii!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^j}{j!}\]

\(f_i\)是把这棵树拆成\(i\)条链的贡献。

再来考虑环的情况,不妨钦定第一棵树的\(1\)号节点作为开头,于是第一棵树的生成函数就是

\[\sum_{i=1}^nf_i(i-1)!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^{j-1}}{(j-1)!}\]

由于首位不能是都是第一棵树,于是我们直接减掉这种情况,直接钦定开头是第一棵树的\(1\)号节点,结尾是第一棵树的某个节点,就有

\[\sum_{i=1}^nf_i(i-1)!\sum_{j=1}^i (-1)^{i-j}\binom{i-1}{j-1}\frac{x^{j-2}}{(j-2)!}\]

最后把所有\(\rm EGF\)大力卷起来就好了;

代码

#include
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=5e3+5;const int mod=998244353;
inline int dqm(int x) {return x<0?x+mod:x;}
inline int qm(int x) {return x>=mod?x-mod:x;}
inline int ksm(int a,int b) {
    int S=1;for(;b;b>>=1,a=1ll*a*a%mod)if(b&1)S=1ll*S*a%mod;return S;
}
struct E{int v,nxt;}e[maxn<<1];
int fac[maxn],ifac[maxn],lim=5000;
int m,n,sum,num,head[maxn],ans[maxn],f[maxn],g[maxn];
int dp[maxn][maxn][3],tmp[maxn][3],sz[maxn];
inline void add(int x,int y) {
    e[++num].v=y;e[num].nxt=head[x];head[x]=num;
}
void dfs(int x,int fa) {
    sz[x]=1;dp[x][0][1]=1;
    for(re int i=head[x];i;i=e[i].nxt) {
        if(e[i].v==fa) continue;
        dfs(e[i].v,x);int v=e[i].v;
        for(re int j=0;j<=sz[v];j++) 
            for(re int k=0;k<=sz[x];++k) {
                if(dp[x][k][0]) {
                    tmp[k+j][0]=qm(tmp[j+k][0]+1ll*dp[v][j][0]*dp[x][k][0]%mod);
                    tmp[k+j+1][0]=qm(tmp[j+k+1][0]+1ll*dp[v][j][1]*dp[x][k][0]%mod);
                    tmp[k+j+1][0]=qm(tmp[j+k+1][0]+2ll*dp[v][j][2]*dp[x][k][0]%mod);
                }
                if(dp[x][k][1]) {
                    tmp[k+j][1]=qm(tmp[j+k][1]+1ll*dp[v][j][0]*dp[x][k][1]%mod);
                    tmp[k+j+1][1]=qm(tmp[j+k+1][1]+1ll*dp[v][j][1]*dp[x][k][1]%mod);
                    tmp[k+j][2]=qm(tmp[k+j][2]+1ll*dp[v][j][1]*dp[x][k][1]%mod);
                    tmp[k+j+1][1]=qm(tmp[k+j+1][1]+2ll*dp[v][j][2]*dp[x][k][1]%mod);
                    tmp[k+j][2]=qm(tmp[k+j][2]+1ll*dp[v][j][2]*dp[x][k][1]%mod);    
                }
                if(dp[x][k][2]) {
                    tmp[k+j][2]=qm(tmp[j+k][2]+1ll*dp[v][j][0]*dp[x][k][2]%mod);
                    tmp[k+j+1][2]=qm(tmp[j+k+1][2]+1ll*dp[v][j][1]*dp[x][k][2]%mod);
                    tmp[k+j+1][0]=qm(tmp[k+j+1][0]+2ll*dp[v][j][1]*dp[x][k][2]%mod);
                    tmp[k+j+1][2]=qm(tmp[k+j+1][2]+2ll*dp[v][j][2]*dp[x][k][2]%mod);
                    tmp[k+j+1][0]=qm(tmp[k+j+1][0]+2ll*dp[v][j][2]*dp[x][k][2]%mod);
                }   
            }
        sz[x]+=sz[e[i].v];
        for(re int j=0;j<=sz[x];++j) 
            for(re int k=0;k<3;k++) dp[x][j][k]=tmp[j][k],tmp[j][k]=0;
    }
}
inline int C(int n,int m) {return 1ll*fac[n]*ifac[n-m]%mod*ifac[m]%mod;}
inline void solve(int id) {
    for(re int i=1;i<=n;i++) 
        for(re int j=0;j<=sz[i];++j) 
            dp[i][j][0]=dp[i][j][1]=dp[i][j][2]=0;
    for(re int i=1;i<=n;i++)head[i]=0;
    n=read();sum+=n;num=0;
    for(re int x,y,i=1;i=0;ans[i]=nw,nw=0,i--) 
        for(re int j=0;j<=i&&j<=n;j++) nw=qm(nw+1ll*ans[i-j]*g[j]%mod);
}
int main() {
    m=read();fac[0]=ifac[0]=1;ans[0]=1;
    for(re int i=1;i<=lim;i++)fac[i]=1ll*fac[i-1]*i%mod;
    ifac[lim]=ksm(fac[lim],mod-2);
    for(re int i=lim-1;i;i--)ifac[i]=1ll*ifac[i+1]*(i+1)%mod;
    for(re int i=1;i<=m;++i)solve(i==1);int cnt=0;
    for(re int i=0;i<=sum;i++)cnt=qm(cnt+1ll*ans[i]*fac[i]%mod);
    printf("%d\n",cnt);
    return 0;
}

你可能感兴趣的:([JSOI2019]神经网络)