组合计数训练题解

CF40E

题目链接

点击打开链接

题目解法

首先,如果 n , m n,m n,m 一奇一偶,那么答案为 0 0 0
原因是从行和列的角度分析, − 1 -1 1 个数的奇偶性不同

可以发现 k < max ⁡ { n , m } k<\max\{n,m\} k<max{n,m} 的性质很微妙,把它转化为有用的信息,即为:
n > m n>m n>m,则必有一行为空
可以发现,除了空行,其他所有行只需在满足行乘积为 − 1 -1 1 的情况下随便填,然后令空行满足列的限制即可
我们需要考虑空行在满足列的限制的情况下是否仍然满足行的限制,这个对 − 1 -1 1 的个数讨论一下不难证得

于是可以简单地计算,时间复杂度 O ( n 2 ) O(n^2) O(n2)

#include 
using namespace std;
const int N=1100;
int n,m,k,dp[N][N],pw[N];
int cov1[N],cov2[N];
int od1[N],od2[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int main(){
	n=read(),m=read(),k=read();
    if((n+m)&1){ puts("0");exit(0);}
	for(int i=1;i<=k;i++){
		int x=read(),y=read(),z=read();
		cov1[x]++,cov2[y]++;
		if(z==-1) od1[x]^=1,od2[y]^=1;
	}
	int P=read();
	for(int i=1;i<=n;i++) if(cov1[i]==m&&!od1[i]){ puts("0");exit(0);}
	for(int i=1;i<=m;i++) if(cov2[i]==n&&!od2[i]){ puts("0");exit(0);}
	pw[0]=1;
	for(int i=1;i<=max(n,m);i++) pw[i]=pw[i-1]*2%P;
	if(n>m){
		for(int i=1;i<=n;i++)
			if(!cov1[i]){
				int ans=1;
				for(int j=1;j<=n;j++){
					if(i==j) continue;
					if(m-cov1[j]>0) ans=1ll*ans*pw[m-cov1[j]-1]%P;
				}
				printf("%d",ans);exit(0);
			}
	}
	else
		for(int i=1;i<=m;i++)
			if(!cov2[i]){
				int ans=1;
				for(int j=1;j<=m;j++){
					if(i==j) continue;
					if(n-cov2[j]>0) ans=1ll*ans*pw[n-cov2[j]-1]%P;
				}
				printf("%d",ans);exit(0);
			}
	return 0;
}

CF140E

题目链接

点击打开链接

题目解法

先特殊提醒一下,每一层的球不是构成一个环,而是一个序列
可以发现这一层只关心选了几个颜色,而不关心选了那些颜色
所以显然有 d p i , j dp_{i,j} dpi,j 为到第 i i i 层,第 i i i 层有 j j j 个颜色的方案数
考虑转移
容斥一下,用总的方案数 - 不合法的方案数
所以 d p i , j = m j ‾ ∗ f l [ i ] , j ∑ d p i − 1 , k − j ! ∗ d p i − 1 , j dp_{i,j}=m^{\underline{j}}*f_{l[i],j}\sum{dp_{i-1,k}}-j!*dp_{i-1,j} dpi,j=mjfl[i],jdpi1,kj!dpi1,j
其中 f i , j f_{i,j} fi,j 为长度为 i i i 序列,相邻不同,颜色总数为 j j j,且每一种颜色都染过的方案数
f i , j f_{i,j} fi,j 好求,不细讲了,直接给转移式子: f i , j = f i − 1 , j − 1 + f i − 1 , j ∗ ( j − 1 ) f_{i,j}=f_{i-1,j-1}+f_{i-1,j}*(j-1) fi,j=fi1,j1+fi1,j(j1)
因为没有规定染 m m m 个颜色中的哪 j j j 个,所以要乘 m j ‾ m^{\underline{j}} mj
减去的数中因为之前已经规定好了选哪 j j j 个,所以只需要给 j j j 个数排个次序即可(不难发现, f f f 式子中并没有规定 j j j 个颜色的顺序)
时间复杂度 O ( n 2 ) O(n^2) O(n2)

#include 
using namespace std;
const int MX=5100,N=1000100;
int n,m,P,l[N];
int fac[MX],perm[N];
int dp[2][MX],f[MX][MX];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int main(){
	n=read(),m=read(),P=read();
	for(int i=1;i<=n;i++) l[i]=read();
	fac[0]=1;
	for(int i=1;i<MX;i++) fac[i]=1ll*fac[i-1]*i%P;
    perm[0]=1;
    for(int i=1;i<=m;i++) perm[i]=1ll*perm[i-1]*(m-i+1)%P;
	f[0][0]=1;
	for(int i=1;i<MX;i++) for(int j=1;j<=i;j++) f[i][j]=(f[i-1][j-1]+1ll*f[i-1][j]*(j-1))%P;
	int sum=0;
	for(int i=1;i<=l[1];i++) dp[1][i]=1ll*perm[i]*f[l[1]][i]%P,sum=(sum+dp[1][i])%P;
	for(int i=2;i<=n;i++){
		for(int j=1;j<=l[i];j++) dp[i&1][j]=0;
		for(int j=1;j<=l[i];j++){
			dp[i&1][j]=1ll*perm[j]%P*sum%P*f[l[i]][j]%P;
			if(j<=l[i-1]) dp[i&1][j]=(dp[i&1][j]-1ll*dp[~i&1][j]*fac[j]%P*f[l[i]][j]%P+P)%P;
		}
		sum=0;
		for(int j=1;j<=l[i];j++) sum=(sum+dp[i&1][j])%P;
	}
	printf("%d",sum);
	return 0;
}

CF482D

题目链接

点击打开链接

题目解法

显然树形 d p dp dp
考虑染的颜色只跟染的顺序的奇偶性有关
所以令 d p i , 0 / 1 dp_{i,0/1} dpi,0/1 表示在 i i i 的子树中,选了偶数还是奇数个点
考虑编号升序是好转移的,编号降序的方案和编号升序的方案一样
现在只需要计算编号升序和降序重复的方案数
不难发现,重复的条件为 ∀ p r e i − 1 ≡ s u f i + 1 \forall pre_{i-1}\equiv suf_{i+1} prei1sufi+1
可以发现这个等价于

  1. 每个子树中选的点都为偶数
  2. 每个子树中选的点都为奇数,且只有奇数个子树选到

这个可以在开一个 d p dp dp 数组 f i , 0 / 1 , 0 / 1 f_{i,0/1,0/1} fi,0/1,0/1 表示在 i i i 的子树中选出了奇数或偶数个点,所有选的子树中点数均为奇数或偶数,不难转移

时间复杂度 O ( n 2 ) O(n^2) O(n2)

#include 
using namespace std;
const int N=100100,P=1e9+7;
int n,dp[N][2],f[N][2][2];
vector<int> vec[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
void dfs(int u){
	dp[u][1]=1,f[u][0][0]=f[u][0][1]=1;
	for(int v:vec[u]){
		dfs(v);
        int t0=dp[u][0],t1=dp[u][1];
		dp[u][0]=(1ll*t0*dp[v][0]+1ll*t1*dp[v][1])%P;
		dp[u][1]=(1ll*t0*dp[v][1]+1ll*t1*dp[v][0])%P;
        t0=f[u][1][1],t1=f[u][0][1];
        f[u][0][0]=(f[u][0][0]+1ll*f[u][0][0]*(dp[v][0]-1))%P;//子树中选了偶数个点,所有选的子树中都选了偶数个点
        f[u][0][1]=(f[u][0][1]+1ll*t0*dp[v][1])%P;//子树中选了偶数个点,所有选的子树中都选了奇数个点
        f[u][1][1]=(f[u][1][1]+1ll*t1*dp[v][1])%P;//子树中选了奇数个点,所有选的子树中都选了奇数个点
	}
	dp[u][0]=(dp[u][0]*2+1)%P,dp[u][1]=dp[u][1]*2%P;
    dp[u][0]=(dp[u][0]-f[u][1][1]+P)%P,dp[u][1]=(dp[u][1]-f[u][0][0]+P)%P;
}
int main(){
	n=read();
	for(int i=2;i<=n;i++) vec[read()].push_back(i);
	dfs(1);
	printf("%d",(dp[1][0]+dp[1][1]-1)%P);
	return 0;
}

CF325E

题目链接

点击打开链接

题目解法

先考虑 n n n 为偶数的情况
先看一个比较妙的性质:对于 i i i i + n 2 i+\frac{n}{2} i+2n,它们能到达的两个点是一样的,因为 2 i ≡ 2 i + n 2i\equiv2i+n 2i2i+n 2 i + 1 ≡ 2 i + n + 1 2i+1\equiv 2i+n+1 2i+12i+n+1
所以考虑暂时钦定 i i i 连向 2 i 2i 2i i + n 2 i+\frac{n}{2} i+2n 连向 2 i + 1 2i+1 2i+1
考虑这样整张图会形成若干个简单环,现在需要考虑合并简单环
x x x x + n 2 x+\frac{n}{2} x+2n 不在一个简单环内,那么考虑把 x x x 连到 2 x + 1 2x+1 2x+1 x + n 2 x+\frac{n}{2} x+2n 连到 2 x 2x 2x,这样可以把 2 2 2 个环合并成一个环,这个可以用并查集维护

考虑 n n n 为奇数的情况,直觉告诉我们一定是无解的
考虑来证明它:可以发现到终点 0 0 0 的点只能是 n − 1 2 \frac{n-1}{2} 2n1,而到达 n − 1 n-1 n1 的点也只能是 n − 1 2 \frac{n-1}{2} 2n1,这就矛盾了
时间复杂度 O ( n ) O(n) O(n)

#include 
using namespace std;
const int N=100100;
int n,fa[N],rt,nxt[N];
bool vis[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
void dfs(int u){
	vis[u]=1,fa[u]=rt;
	if(!vis[nxt[u]]) dfs(nxt[u]);
}
int get_father(int x){ return x==fa[x]?x:fa[x]=get_father(fa[x]);}
int main(){
	n=read();
	if(n&1){ puts("-1");exit(0);}
	for(int i=0;i<n/2;i++) nxt[i]=i<<1,nxt[i+n/2]=i<<1^1;
	for(int i=0;i<n;i++) fa[i]=i;
	for(int i=0;i<n;i++) if(!vis[i]) rt=i,dfs(i);
	for(int i=0;i<n/2;i++){
		int x=get_father(i),y=get_father(i+n/2);
		if(x!=y) fa[x]=y,nxt[i]=i<<1^1,nxt[i+n/2]=i<<1;
	}
	int pos=0;
	while(nxt[pos]) printf("%d ",pos),pos=nxt[pos];
	printf("%d 0",pos);
	return 0;
}

CF896D

题目链接

点击打开链接

题目解法

先考虑 n n n 个人不包含 v i p vip vip 的情况
考虑转化,把 50 50 50 元看成 1 1 1 100 100 100 元看成 − 1 -1 1
题目就是问有多少种方案,所有的前缀和都 ≥ 0 \ge 0 0
这让我们想到了卡特兰数,于是可以考虑用推卡特兰数的方法来推到 ( x , n − x ) (x,n-x) (x,nx) 处的方案数,条件是 n − x + l ≤ x ≤ n − x + r n-x+l\le x\le n-x+r nx+lxnx+r
推断方法不细讲,大概就是对 y = x + 1 y=x+1 y=x+1 作对称,最后推出来的式子应该是 ( n x ) − ( n x + 1 ) \binom{n}{x}-\binom{n}{x+1} (xn)(x+1n)

然后考虑有 v i p vip vip 的情况,即枚举个数,最终的答案即为 ∑ i = 0 n ∑ i − x + l ≤ x ≤ i − x + r ( i x ) − ( i x + 1 ) \sum\limits_{i=0}^{n}\sum\limits_{i-x+l\le x\le i-x+r}\binom{i}{x}-\binom{i}{x+1} i=0nix+lxix+r(xi)(x+1i)
可以发现对于同一个 i i i,中间的组合数都可以消掉,所以只要算最左边和最右边的组合数即可

考虑 P P P 不是质数,我们可以把 p p p 质因数分解,然后把每个计算的数拆分成 x ∗ ∏ p k α k x*\prod p_k^{\alpha_k} xpkαk,其中 ( x , P ) = 1 (x,P)=1 (x,P)=1,这个对每个 p p p 的质因子开个桶即可,考虑 P P P 的质因子不会超过 9 个,所以可以很快维护

时间复杂度 O ( n l o g n ) O(nlogn) O(nlogn)

#include 
#define int long long
using namespace std;
const int N=100100;
int n,P,l,r;
int phi,pr[N],cnt;
int fac[N],invX[N],coef[N][35];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int qmi(int a,int b){
    int res=1;
    for(;b;b>>=1){
        if(b&1) res=1ll*res*a%P;
        a=1ll*a*a%P;
    }
    return res;
}
void calc(){
    fac[0]=invX[0]=1,fac[1]=invX[1]=1;
    for(int i=2;i<=n;i++){
        int t=i;
        for(int j=1;j<=cnt;j++){
            coef[i][j]=coef[i-1][j];
            while(t%pr[j]==0) t/=pr[j],coef[i][j]++;
        }
        fac[i]=1ll*fac[i-1]*t%P;
        invX[i]=qmi(fac[i],phi-1);
    }
}
int binom(int x,int y){
    if(x<y) return 0;
    int res=fac[x]*invX[y]%P*invX[x-y]%P;
    for(int i=1;i<=cnt;i++) res=res*qmi(pr[i],coef[x][i]-coef[y][i]-coef[x-y][i])%P;
    return res;
}
signed main(){
	n=read(),P=read(),l=read(),r=read();
    phi=P;
    int t=P;
    for(int i=2;i*i<=t;i++)
        if(t%i==0){
            while(t%i==0) t/=i;
            pr[++cnt]=i,phi=phi/i*(i-1);
        }
    if(t>1) pr[++cnt]=t,phi=phi/t*(t-1);
    calc();
	int ans=0;
	for(int i=0;i<=n;i++){
		int lb=(l+i+1)/2;
		int rb=min((r+i)/2,i);
        if(lb<=rb) ans=(ans+(binom(i,lb)-binom(i,rb+1)+P)*binom(n,i)%P)%P;
	}
	printf("%lld",ans);
	return 0;
}

CF750G

题目链接

点击打开链接

题目解法

考虑对答案的路径分 3 种情况

  1. 一个点,方案数为 1
  2. 从祖先到子孙的路径
    令祖先的编号为 u u u,延续到下面的深度为 h h h
    则编号和为 ( 2 h + 1 − 1 ) ∗ u + ∑ i = 0 h − 1 ( 2 i − 1 ) ∗ c i (2^{h+1}-1)*u+\sum_{i=0}^{h-1} (2^i-1)*c_i (2h+11)u+i=0h1(2i1)ci,其中 c i c_i ci 为这条边是否往右走
    考虑 ∑ i = 0 h − 1 2 i ∗ c i < 2 h \sum_{i=0}^{h-1} 2^i*c_i<2^h i=0h12ici<2h,那么枚举 h h h,我们可以直接计算出对应的 u u u
    然后判断是否合法即可
  3. u u u l c a lca lca v v v 的路径( u ≠ l c a ≠ v u\neq lca\neq v u=lca=v
    令祖先的编号为 u u u,延续到下面的深度分别为 h 1 , h 2 h1,h2 h1,h2
    则编号和为 ( 2 h 1 + 1 + 2 h 2 + 1 − 2 ) ∗ u + ∑ i = 0 h 1 − 1 ( 2 i − 1 ) c i + ∑ i = 0 h 2 − 1 ( 2 i − 1 ) c i ′ (2^{h1+1}+2^{h2+1}-2)*u+\sum_{i=0}^{h1-1}(2^i-1)c_i+\sum_{i=0}^{h2-1}(2^i-1)c'_i (2h1+1+2h2+12)u+i=0h11(2i1)ci+i=0h21(2i1)ci
    同理对于相同的 h 1 , h 2 h1,h2 h1,h2,我们可以唯一确定出一个 u u u
    那么考虑计算下面有多少种路径
    考虑 2 i − 1 2^i-1 2i1 − 1 -1 1 不好计算,所以考虑枚举 c i , c i ′ c_i,c'_i ci,ci 1 1 1 的数的个数
    然后考虑 d p dp dp,令 d p i , j , 0 / 1 dp_{i,j,0/1} dpi,j,0/1 表示从低到高到了第 i i i 位,已经选了 j j j 个数,是否进位的方案数
    这样可以轻松转移,要注意 h 1 − 1 h1-1 h11 h 2 − 1 h2-1 h21 位要强制选和不选(因为必须往左和往右)

时间复杂度 O ( H 5 ) O(H^5) O(H5) H H H 为树高

#include 
#define int long long
using namespace std;
int s,bs[60],dp[60][120][2];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int check(int left,int h){
	for(int i=h;i>=0;i--) if(left>=bs[i]-1) left-=bs[i]-1;
	return !left;
}
int calc(int h1,int h2,int left){
	int res=0;
	for(int i=1;i<=h1+h2;i++){
		int curleft=left+i;
		bool flg=1;
		for(int j=max(h1,h2)+2;j<=50;j++) if(curleft>>j&1){ flg=0;break;}
		if(!flg||(curleft&1)) continue;
        memset(dp,0,sizeof(dp));
        dp[0][0][0]=1;
		for(int j=1;j<=max(h1,h2)+1;j++)
            for(int k=0;k<=i;k++)
                for(int z=0;z<=1;z++){
                    if(!dp[j-1][k][z]) continue;
                    for(int x=0;x<=(j<h1);x++)
                        for(int y=(j==h2);y<=(j<=h2);y++){
                            int curbit=(x+y+z)%2,ne=(x+y+z)/2;
                            if(curbit!=(curleft>>j&1ll)) continue;
                            dp[j][k+x+y][ne]+=dp[j-1][k][z];
                        }
                }
		res+=dp[max(h1,h2)+1][i][0];
	}
	return res;
}
signed main(){
	s=read();
	bs[0]=1;
	for(int i=1;i<=51;i++) bs[i]=bs[i-1]*2;
	int ans=1;//选n
	//端点u,v为祖孙关系
	//枚举相差层数
	for(int h=1;h<=50;h++){
		int u=s/(bs[h+1]-1);//头
		if(!u) continue;
		if(check(s-u*(bs[h+1]-1),h)) ans++;
	}
	//端点u,v不是祖孙关系
	for(int h1=1;h1<=50;h1++)
		for(int h2=1;h2<=50;h2++){
			int u=s/(bs[h1+1]+bs[h2+1]-3);
			if(!u) continue;
			ans+=calc(h1,h2,s-u*(bs[h1+1]+bs[h2+1]-3));
		}
	printf("%lld",ans);
	return 0;
}
/*
dp[i][j][0/1]:到第i位,选了j个数,是否有进位的方案数
*/

CF981H

题目链接

点击打开链接

题目解法

感觉挺好的一道题,除了多项式
显然 k k k 条路径共同经过的部分一定是一段路径
然后考虑分类讨论路径端点 u , v u,v u,v 的关系

  1. u , v u,v u,v 不是祖先关系
    答案即为 s u × s v s_u\times s_v su×sv
    其中 s u s_u su 表示在 u u u 的子树中选出 k k k 个点,并组成排列的方案数
    考虑枚举 u u u 节点的选择个数,因为只有它是可以经过不限制次的, u u u 的儿子的子树中只能选出 1 1 1 个点
    所以考虑除了 u u u i i i 个点且不考虑顺序的方案数即为 [ x i ] ∏ ( 1 + s i z v x ) [x^i]\prod (1+siz_vx) [xi](1+sizvx)
    因为是要求出这个多项式卷积的每一项的,所以考虑用分治 n t t ntt ntt 求出
    然后 s u s_u su 即为 ∑ i = 0 k ( k i ) × i ! × [ x i ] ∏ ( 1 + s i z v x ) \sum\limits_{i=0}^{k}\binom{k}{i}\times i!\times [x^i]\prod (1+siz_vx) i=0k(ik)×i!×[xi](1+sizvx)
    然后所有情况的答案即为 ( ∑ s u ) 2 (\sum s_u)^2 (su)2
    还需要减去 u u u u u u 自己的代价和 u u u u u u 的子树内的点的代价
    这用一个子树和不难求出
  2. u , v u,v u,v 是祖先关系
    这种情况稍微难办一些
    不考虑 u u u 选出 i i i 个点的方案数为 [ x i ] ∑ v ∈ s o n ( u ) ( 1 + ( n − s i z u ) x ) × ∏ w ≠ v ( 1 + s i z w x ) [x^i]\sum\limits_{v\in son(u)}{(1+(n-siz_u)x)\times \prod\limits_{w\neq v} (1+siz_wx)} [xi]vson(u)(1+(nsizu)x)×w=v(1+sizwx)
    这个也可以用分治 n t t ntt ntt 求出
    计算答案的式子和第一类情况一样

记得特判一条路径的情况
因为度数之和为 n n n,所以时间复杂度是分治 n t t ntt ntt O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)

#include 
using namespace std;
const int N=300000,P=998244353;
const int G=3,Gi=332748118;
typedef pair<vector<int>,vector<int> > pvv;
typedef pair<int,int> pii;
int n,k,ans,ans2;
int rev[N],tot;
int f[N],s[N];
int siz[N],fac[N],inv[N];
vector<pii> son;
vector<int> vec[N],Gr[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
inline void inc(int &x,int y){
    x+=y;
    if(x>=P) x-=P;
}
int qmi(int a,int b){
	int res=1;
	for(;b;b>>=1){
		if(b&1) res=1ll*res*a%P;
		a=1ll*a*a%P; 
	}
	return res;
}
vector<int> ntt(vector<int> a,bool neg){
	for(int i=0;i<tot;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
	for(int mid=1;mid<tot;mid<<=1){
		int g1=qmi(neg?Gi:G,(P-1)/(mid<<1));
		for(int i=0;i<tot;i+=mid<<1){
			int gk=1;
			for(int j=0;j<mid;j++,gk=1ll*gk*g1%P){
				int x=a[i+j],y=1ll*gk*a[i+j+mid]%P;
				a[i+j]=x+y>=P?x+y-P:x+y,a[i+j+mid]=x-y<0?x-y+P:x-y;
			}
		}
	}
    return a;
}
vector<int> mul(vector<int> A,vector<int> B){
    int bit=0;
    while((1<<bit)<=A.size()+B.size()) bit++;
    tot=1<<bit;
	for(int i=0;i<tot;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    A.resize(tot),B.resize(tot);
    A=ntt(A,0),B=ntt(B,0);
    for(int i=0;i<tot;i++) A[i]=1ll*A[i]*B[i]%P;
    A=ntt(A,1);
    int iv=qmi(tot,P-2);
    for(int i=0;i<tot;i++) A[i]=1ll*A[i]*iv%P;
    while(!A.back()) A.pop_back();
    return A;
}
vector<int> add(vector<int> A,vector<int> B)
    if(A.size()<B.size()) swap(A,B);
    for(int i=0;i<B.size();i++) inc(A[i],B[i]);
    return A;
}
pvv pushup(pvv L,pvv R){
    vector<int> f=mul(L.first,R.first);
    vector<int> g=add(mul(L.first,R.second),mul(L.second,R.first));
    return make_pair(f,g);
}
pvv solve(int l,int r){
	if(l==r) return make_pair(vector<int>{1,son[l].first},vector<int>{son[l].second});
	int mid=(l+r)>>1;
	return pushup(solve(l,mid),solve(mid+1,r));
}
void predfs(int u,int fa){
    siz[u]=1;
    for(int v:vec[u])
        if(v!=fa){
            predfs(v,u);
            siz[u]+=siz[v],Gr[u].push_back(v);
        }
}
int C(int a,int b){ return 1ll*fac[a]*inv[b]%P*inv[a-b]%P;}
void dfs(int u){
	for(int v:Gr[u]) dfs(v),inc(s[u],s[v]);
	if(!Gr[u].size()) f[u]=s[u]=1;
	else{
		son.clear();
		for(int v:Gr[u]) son.push_back({siz[v],s[v]});
		pvv ret=solve(0,son.size()-1);
        vector<int> poly1=ret.first,poly2=ret.second;
		for(int i=0;i<min((int)poly1.size(),k+1);i++) inc(f[u],1ll*poly1[i]*C(k,i)%P*fac[i]%P);
        inc(s[u],f[u]);
        poly2=mul(vector<int>{1,n-siz[u]},poly2);
        for(int i=0;i<min((int)poly2.size(),k+1);i++) inc(ans2,1ll*poly2[i]*C(k,i)%P*fac[i]%P);
	}
    inc(ans,P-1ll*f[u]*(2*s[u]%P-f[u]+P)%P);
}
int main(){
	n=read(),k=read();
    if(k==1){ printf("%d\n",(1ll*n*(n-1)/2)%P);exit(0);
    fac[0]=1;
    for(int i=1;i<=n+1;i++) fac[i]=1ll*fac[i-1]*i%P;
    inv[n+1]=qmi(fac[n+1],P-2);
    for(int i=n;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%P;
	for(int i=1;i<n;i++){
		int x=read(),y=read();
		vec[x].push_back(y),vec[y].push_back(x);
	}
	predfs(1,-1),dfs(1);
    inc(ans,1ll*s[1]*s[1]%P),ans=1ll*ans*qmi(2,P-2)%P;inc(ans,ans2);
    printf("%d\n",ans);
	fprintf(stderr,"%d ms\n",int(1e3*clock()/CLOCKS_PER_SEC));
	return 0;
}

CF698F

题目链接

点击打开链接

题目解法

不是很想写,一些结论我也说不清楚也不会证,自己看题解吧
有时间可能会补,先给出代码:

#include 
using namespace std;
const int N=1000100,P=1e9+7;
int n,p[N],dy1[N],dy2[N],c1[N],c2[N];
int mulp[N],fac[N],chu[N];
int pr[N],v[N],cnt;
vector<int> factor[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int main(){
	n=read();
    for(int i=1;i<=n;i++) mulp[i]=1;
    c1[1]++;
    factor[1].push_back(1),chu[1]=1;
	for(int i=2;i<=n;i++){
		if(!v[i]){
			pr[++cnt]=i,chu[i]=n/i,c1[n/i]++;
			for(int j=i;j<=n;j+=i) factor[j].push_back(i),mulp[j]*=i;
		}
		for(int j=1;j<=cnt&&pr[j]<=n/i;j++){
			v[pr[j]*i]=pr[j];
			if(v[i]==pr[j]) break;
		}
	}
    for(int i=1;i<=n;i++) c2[mulp[i]]++;
	for(int i=1;i<=n;i++) p[i]=read();
	for(int i=1;i<=n;i++){
		if(!p[i]) continue;
		if(factor[i].size()!=factor[p[i]].size()){ puts("0");exit(0);}
		for(int j=0;j<factor[i].size();j++) if(chu[factor[i][j]]!=chu[factor[p[i]][j]]){ puts("0");exit(0);}
		int mxfac1=factor[i].back(),mxfac2=factor[p[i]].back();
		if(dy1[mxfac1]&&dy1[mxfac1]!=mxfac2){ puts("0");exit(0);}
		if(dy2[mxfac2]&&dy2[mxfac2]!=mxfac1){ puts("0");exit(0);}
		if(!dy1[mxfac1]&&!dy2[mxfac2]) c1[chu[mxfac1]]--;
		dy1[mxfac1]=mxfac2,dy2[mxfac2]=mxfac1;
		c2[mulp[p[i]]]--;
	}
    fac[0]=1;
    for(int i=1;i<=n;i++) fac[i]=1ll*fac[i-1]*i%P;
    int ans=1;
    for(int i=1;i<=n;i++) ans=1ll*ans*fac[c1[i]]%P*fac[c2[i]]%P;
    printf("%d",ans);
	return 0;
}

你可能感兴趣的:(其他,算法)