题目链接:http://codeforces.com/gym/102012/problem/G
题目大意:有一棵n个结点的树,现在给出m条树上的路径。现在要从这m条路径中选出k条路径,使得这k条路径至少有一个公共交点,问你总共有多少种方案数。
题目思路:(今年徐州现场的银牌题,我们队肝到最后也没能肝出来,错失了银牌。。。QAQ,当时忘了一个重要的性质,导致正思路都错了。还是太菜了)
感慨一下,继续分析题目。
解决这个题,需要用到一个重要的性质:一个树上任意两条路径如果有交点的话,那么这些交点中肯定有一个为两条路径中的一条路径两端点的lca。
有了这个性质的话,我们可以对通过枚举路径的交点来求答案。
对于每个节点,我们假设通过这个节点的路径有M条,以这个点为LCA且通过这个节点的路径有N条。
那么在这个节点对答案的贡献为:。这个式子计算出来的是,从通过这个节点的路径中选出k条路径,且至少有一条路径的LCA为这个节点的方案数,这样选的话就不会出现重复选的情况了,因为至少有一条路径以该节点为LCA,在以其他点为交点的时候就不会重复计算了。
而通过某个结点的路径数我们可以通过树上差分计算,假设通过u这个节点的路径为sum[u]。那么在更新路径[u,v]的时候,我们就令sum[u]++,sum[v]++,sum[lca(u,v)]--,sum[fa[lca(u,v)]]--。接着再用dfs一遍即可。
具体实现看代码:
#include
#define fi first
#define se second
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define pb push_back
#define MP make_pair
#define lowbit(x) x&-x
#define clr(a) memset(a,0,sizeof(a))
#define _INF(a) memset(a,0x3f,sizeof(a))
#define FIN freopen("in.txt","r",stdin)
#define IOS ios::sync_with_stdio(false)
#define fuck(x) cout<<"["<<#x<<" "<<(x)<<"]"<pii;
typedef pairpll;
const int MX = 3e5 + 5;
const int mod = 1e9 + 7;
int n, m, k;
struct edge {int v, w, nxt;} E[MX << 1];
int head[MX], tot;
int dep[MX], ST[MX][20];
void add_edge(int u, int v) {
E[tot].v = v; E[tot].nxt = head[u];
head[u] = tot++;
}
void dfs(int u, int d, int fa) {
dep[u] = d; ST[u][0] = fa;
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (v == fa) continue;
dfs(v, d + 1, u);
}
}
void pre_solve() {
dfs(1, 0, 1);
for (int i = 1; i < 20; i++) {
for (int j = 1; j <= n; j++) {
ST[j][i] = ST[ST[j][i - 1]][i - 1];
}
}
}
int LCA(int u, int v) {
while (dep[u] != dep[v]) {
if (dep[u] < dep[v]) swap(u, v);
int d = dep[u] - dep[v];
for (int i = 0; i < 20; i++)
if (d >> i & 1)u = ST[u][i];
}
if (u == v) return u;
for (int i = 19; i >= 0; i--) {
if (ST[u][i] != ST[v][i]) {
u = ST[u][i];
v = ST[v][i];
}
}
return ST[u][0];
}
int sum[MX], lca_sum[MX];
void solve(int u, int fa) {
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (v == fa) continue;
solve(v, u);
sum[u] += sum[v];
}
}
ll f[MX], inv[MX];
ll qpow(ll a, ll b) {
ll res = 1;
while (b) {
if (b & 1) res = (res * a) % mod;
a = (a * a) % mod;
b >>= 1;
}
return res;
}
void init() {
f[1] = 1;
for (int i = 2; i < MX; i++) f[i] = (f[i - 1] * i) % mod;
inv[MX - 1] = qpow(f[MX - 1], mod - 2);
for (int i = MX - 2; i >= 1; i--) inv[i] = (inv[i + 1] * (i + 1)) % mod;
}
ll C(int n, int m) {
if (n < 0 || m < 0 || m > n) return 0;
if (m == 0 || m == n) return 1;
return f[n] * inv[n - m] % mod * inv[m] % mod;
}
int main() {
// FIN;
init();
int T; cin >> T;
while (T--) {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= n; i++) head[i] = -1, sum[i] = lca_sum[i] = 0;
tot = 0;
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
add_edge(u, v); add_edge(v, u);
}
pre_solve();
for (int i = 1; i <= m; i++) {
int u, v;
scanf("%d%d", &u, &v);
int lca = LCA(u, v); lca_sum[lca]++;
sum[u]++; sum[v]++;
sum[lca]--;
if (lca != 1) sum[ST[lca][0]]--;
}
solve(1, 0);
ll ans = 0;
for (int i = 1; i <= n; i++)
ans = (ans % mod + ((C(sum[i], k) - C(sum[i] - lca_sum[i], k)) % mod + mod) % mod) % mod;
printf("%lld\n", ans);
}
return 0;
}