给一棵 n n 个节点的树,现在要从树上按顺序选出 k k 条路径(可以相同),满足任意一条边要么被覆盖不超过 1 1 次,要么被覆盖恰好 k k 次,且被覆盖 k k 次的边数不能为 0 0 。问方案。
n,k≤105 n , k ≤ 10 5
先考虑暴力,我们可以枚举两个端点 u u 和 v v ,然后保证每条选出的路径都包含这两个点之间的路径。
那么现在要从这两个点为根的子树中分别选出 k k 个端点,使得这些端点到根的路径没有公共边。
设 szv s z v 表示节点 v v 的子树大小, s1,s2,...,sm s 1 , s 2 , . . . , s m 表示 u u 的所有儿子,考虑多项式 (1+szs1x)(1+szs2x)...=∑(aixi) ( 1 + s z s 1 x ) ( 1 + s z s 2 x ) . . . = ∑ ( a i x i ) ,那么选出 k k 个端点的方案就是 fv=∑axCxkx! f v = ∑ a x C k x x ! 。
而这个多项式可以用分治FFT来算,复杂度是 O(nlog2n) O ( n l o g 2 n ) 。
如果我们把原树定为有根树,然后算出每个点 v v 为根的 fv f v ,就可以通过树形dp来计算所有两端点不为祖先关系的答案。
对于那些两端点为祖先关系的答案,我们可以枚举深度较小的点 v v ,那么如果我们选了 v v 的一个儿子 u u 子树中的点作为另一个端点,节点 v v 对应的多项式就要乘上 1+(n−szv)x1+szux 1 + ( n − s z v ) x 1 + s z u x 。
显然乘或除以一个一次多项式可以在 O(degree) O ( d e g r e e ) 时间内完成,而注意到不同的 szu s z u 只有 O(n−−√) O ( n ) 种,所以复杂度就是 O(nn−−√) O ( n n ) 。
总的复杂度就是 O(nlog2n+nn−−√) O ( n l o g 2 n + n n ) 。
#include
#include
#include
#include
#include
typedef long long LL;
const int N=100005;
const int MOD=998244353;
int n,k,cnt,last[N],jc[N],ny[N],a[20][N*2],b[N],tot,rev[N*2],size[N],f[N],s[N],L,ans;
struct edge{int to,next;}e[N*2];
struct data{int x,y;}t[N];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void addedge(int u,int v)
{
e[++cnt].to=v;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].next=last[v];last[v]=cnt;
}
bool cmp(data a,data b)
{
return a.xx;
}
int calc(int *a,int n)
{
int ans=0;
for (int i=0;i<=std::min(n,k-1);i++)
(ans+=(LL)a[i]*jc[k]%MOD*ny[k-i]%MOD)%=MOD;
return ans;
}
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y&1) ans=(LL)ans*x%MOD;
x=(LL)x*x%MOD;y>>=1;
}
return ans;
}
void NTT(int *a,int f)
{
for (int i=0;iif (ifor (int i=1;i1)
{
int wn=ksm(3,f==1?(MOD-1)/i/2:MOD-1-(MOD-1)/i/2);
for (int j=0;j1))
{
int w=1;
for (int k=0;kint u=a[j+k],v=(LL)w*a[j+k+i]%MOD;
a[j+k]=(u+v)%MOD;a[j+k+i]=(u+MOD-v)%MOD;
w=(LL)w*wn%MOD;
}
}
}
int ny=ksm(L,MOD-2);
if (f==-1) for (int i=0;i*ny%MOD;
}
void solve(int l,int r,int d)
{
if (l==r) {a[d][0]=1;a[d][1]=t[l].x;return;}
int mid=(l+r)/2;
solve(l,mid,d+1);
for (int i=0;i<=mid-l+1;i++) a[d][i]=a[d+1][i];
solve(mid+1,r,d+1);
int lg=0;
for (L=1;L<=r-l+1;L<<=1,lg++);
for (int i=0;i>1]>>1)|((i&1)<<(lg-1));
for (int i=mid-l+2;i0;
for (int i=r-mid+1;i1][i]=0;
NTT(a[d],1);NTT(a[d+1],1);
for (int i=0;i*a[d+1][i]%MOD;
NTT(a[d],-1);
}
void dfs1(int x,int fa)
{
size[x]=1;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa) continue;
dfs1(e[i].to,x);
size[x]+=size[e[i].to];
(s[x]+=s[e[i].to])%=MOD;
}
tot=0;
for (int i=last[x];i;i=e[i].next)
if (e[i].to!=fa) t[++tot].x=size[e[i].to],t[tot].y=s[e[i].to];
if (!tot) {f[x]=s[x]=1;return;}
std::sort(t+1,t+tot+1,cmp);
solve(1,tot,0);
f[x]=calc(a[0],tot);(s[x]+=f[x])%=MOD;
a[0][tot+1]=0;
for (int i=tot+1;i>=1;i--) (a[0][i]+=(LL)a[0][i-1]*(n-size[x])%MOD)%=MOD;
int w;
for (int i=1;i<=tot;i++)
{
if (t[i].x==t[i-1].x) {(ans+=(LL)w*t[i].y%MOD)%=MOD;continue;}
for (int j=0;j<=tot+1;j++) b[j]=a[0][j];
for (int j=1;j<=tot+1;j++) (b[j]+=MOD-(LL)b[j-1]*t[i].x%MOD)%=MOD;
w=calc(b,tot);
(ans+=(LL)w*t[i].y%MOD)%=MOD;
}
}
void dfs2(int x,int fa)
{
int w=0;
for (int i=last[x];i;i=e[i].next)
{
if (e[i].to==fa) continue;
dfs2(e[i].to,x);
(ans+=(LL)w*s[e[i].to]%MOD)%=MOD;
(w+=s[e[i].to])%=MOD;
}
}
int main()
{
n=read();k=read();
jc[0]=jc[1]=ny[0]=ny[1]=1;
for (int i=2;i<=k;i++) jc[i]=(LL)jc[i-1]*i%MOD,ny[i]=(LL)(MOD-MOD/i)*ny[MOD%i]%MOD;
for (int i=2;i<=k;i++) ny[i]=(LL)ny[i-1]*ny[i]%MOD;
for (int i=1;iint x=read(),y=read();
addedge(x,y);
}
dfs1(1,0);
dfs2(1,0);
printf("%d",ans);
return 0;
}