若路径并为 $P(u, v)$，形态如上图所示，则答案等于：在以 $u$ 为根的子树中选择 $k$ 个起点的方案数，乘上在以 $v$ 为根的子树中选择 $k$ 个终点的方案数（这里"以 $u$ 为根的子树"指去掉 $u$ 朝 $v$ 方向的那棵子树后 $u$ 所在的部分，$v$ 同理）；并且要满足这 $k$ 个起点两两的 $\mathrm{lca}$ 都是 $u$，这 $k$ 个终点两两的 $\mathrm{lca}$ 都是 $v$。
设 $F[i, j]$ 表示：将整棵树以节点 $i$ 为根时，在 $i$ 的 $j$ 棵不同子树中各放一个球（每棵被选中的子树恰好放一个球）的方案数。初值 $F[i, 0] = 1$，然后枚举 $i$ 的每个相邻点 $to$ 做 0/1 背包：$F[i, j] \leftarrow F[i, j] + F[i, j-1] \times size_{to}$（$j$ 从大到小枚举），其中 $size_{to}$ 是 $to$ 一侧那棵子树的大小（$i$ 的父亲一侧为 $N - size_i$）。
当询问 $u, v$ 时，分别从 $F[u, \cdot]$ 与 $F[v, \cdot]$ 中去除 $u$ 朝 $v$ 方向、$v$ 朝 $u$ 方向那棵子树的贡献（即撤销对应的那一次背包转移），再分别枚举 $u$、$v$ 一侧放置的点数。以 $u$ 为例：若有 $i$ 个起点分别落在 $u$ 的 $i$ 棵不同子树中，其余 $k - i$ 个起点就是 $u$ 本身，则方案数为 $\binom{k}{i} \cdot F[u, i] \cdot i!$，对 $i = 0, 1, \dots, k$ 求和即得 $u$ 一侧的总方案数。
得到两侧的方案数后，相乘即为答案。
#include <algorithm>
#include <cctype>
#include <cstdio>
// `register` was deprecated in C++11 and removed in C++17 (clang rejects it),
// so `reg` now expands to nothing; every existing `reg int i` stays valid.
#define reg
const int maxn = 100005;    // capacity of all per-vertex arrays and factorial tables
const int maxl = 502;       // coefficient slots kept per vertex in F (indices 0..maxl-1)
const int mod = 998244353;  // prime modulus used for all counting
// Reads the next (optionally '-'-prefixed) decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' immediately
// followed by digits makes the value negative. Returns 0 when end of input is
// reached before any digit (the old version looped forever at EOF).
int read(){
    int c = getchar();  // int, not char: must hold EOF and be a valid isdigit() argument
    int s = 0, flag = 1;
    while(c != EOF && !isdigit(c)){
        if(c == '-'){ flag = -1; c = getchar(); break; }
        c = getchar();
    }
    while(c != EOF && isdigit(c)) s = s*10 + (c - '0'), c = getchar();
    return s * flag;
}
// Vertex count, a third header value read by main, and the query count.
int N, M, Q_;
// Number of directed edges stored in `edge`; ids start at 1, 0 ends a list.
int num0;

// Per-vertex bookkeeping (vertex id is the index; vertex 0 is the "no parent" sentinel).
int du[maxn], dep[maxn];               // neighbour components folded into F[v]; depth from the root
int fac[maxn], inv[maxn], ifac[maxn];  // i!, i^{-1} and (i!)^{-1}, all modulo `mod`
int size[maxn], head[maxn];            // subtree size; first edge id of v's adjacency list
int Fk[maxn][20];                      // binary lifting: Fk[v][i] is the 2^i-th ancestor of v
int F[maxn][maxl];                     // F[v][j]: j-th elementary symmetric sum of v's component sizes

// Forward-star adjacency storage: each tree edge is stored twice (once per direction).
struct Edge{
    int nxt;  // id of the next edge in the same adjacency list (0 = end)
    int to;   // endpoint of this directed edge
};
Edge edge[maxn << 1];
// Prepends the directed edge from -> to to `from`'s adjacency list.
// Fields are assigned directly: the previous `(Edge){...}` was a C99 compound
// literal, which standard C++ does not provide (GNU extension only).
void Add(int from, int to){
    ++num0;
    edge[num0].nxt = head[from];
    edge[num0].to = to;
    head[from] = num0;
}
// Binomial coefficient C(n, m) modulo `mod`, read from the factorial tables
// (requires n < maxn). Returns 0 when m lies outside [0, n] instead of
// indexing ifac with a negative subscript.
int C(int n, int m){
    if(m < 0 || m > n) return 0;
    return 1ll*fac[n]*ifac[n-m]%mod*ifac[m]%mod;
}
void add(int k, int s){
du[k] ++;
for(reg int i = du[k]; i >= 1; i --)
F[k][i] = (F[k][i] + 1ll*F[k][i-1]*s%mod) % mod;
}
void del(int k, int s){
for(reg int i = 1; i <= du[k]; i ++)
F[k][i] = (F[k][i] - 1ll*F[k][i-1]*s%mod + mod) % mod;
du[k] --;
}
void DFS(int k, int fa){
size[k] = 1, dep[k] = dep[fa] + 1;
for(reg int i = 1; i <= 19; i ++) Fk[k][i] = Fk[Fk[k][i-1]][i-1];
F[k][0] = 1;
for(reg int i = head[k]; i; i = edge[i].nxt){
int to = edge[i].to;
if(to == fa) continue ;
Fk[to][0] = k; DFS(to, k);
size[k] += size[to], add(k, size[to]);
}
if(fa) add(k, N - size[k]);
}
// Lowest common ancestor of a and b via the binary-lifting table Fk.
// Relies on dep[0] == 0 so that jumps past the root are never taken.
int Lca(int a, int b){
    int x = a, y = b;
    if(dep[x] < dep[y]) std::swap(x, y);  // x is now the deeper (or equally deep) vertex
    // Phase 1: raise x to the depth of y.
    for(int j = 19; j >= 0; --j){
        const int up = Fk[x][j];
        if(dep[up] >= dep[y]) x = up;
    }
    if(x == y) return x;
    // Phase 2: raise both while their ancestors differ; they stop just below the LCA.
    for(int j = 19; j >= 0; --j){
        if(Fk[x][j] == Fk[y][j]) continue;
        x = Fk[x][j];
        y = Fk[y][j];
    }
    return Fk[x][0];
}
// Climbs from x to its highest ancestor whose depth is still greater than lim
// (x itself if dep[x] <= lim). With lim = dep[u] for an ancestor u of x, this
// is the child of u on the path to x.
int jump(int x, int lim){
    int cur = x;
    for(int j = 19; j >= 0; --j){
        const int up = Fk[cur][j];
        if(dep[up] > lim) cur = up;
    }
    return cur;
}
// Counts placements of s labelled points around vertex k: for each i, choose
// which i points go into distinct neighbour components (C(s, i)), pick their
// vertices (F[k][i]) and match points to components (i!); the remaining s - i
// points sit on k itself. Result is modulo `mod`.
int calc(int k, int s){
    const int top = s < du[k] ? s : du[k];
    long long total = 0;
    for(int i = top; i >= 0; --i){
        const long long term = 1ll*C(s, i)*F[k][i]%mod*fac[i]%mod;
        total = (total + term) % mod;
    }
    return (int)total;
}
int main(){
N = read(), Q_ = read(), M = read();
inv[1] = 1; for(reg int i = 2; i < maxn; i ++) inv[i] = ((-1ll*mod/i*inv[mod%i])%mod + mod) % mod;
fac[0] = 1; for(reg int i = 1; i < maxn; i ++) fac[i] = 1ll*fac[i-1]*i % mod;
ifac[0] = 1; for(reg int i = 1; i < maxn; i ++) ifac[i] = 1ll*ifac[i-1]*inv[i] % mod;
for(reg int i = 1; i < N; i ++){
int u = read(), v = read();
Add(u, v), Add(v, u);
}
DFS(1, 0);
while(Q_ --){
int u = read(), v = read(), k = read();
if(dep[u] > dep[v]) std::swap(u, v);
int lca = Lca(u, v), fuck;
if(lca == u) fuck = size[jump(v, dep[u])];
else fuck = N - size[u];
del(u, fuck), del(v, N - size[v]);
printf("%lld\n", 1ll*calc(u, k)*calc(v, k)%mod);
add(u, fuck), add(v, N - size[v]);
}
return 0;
}