给一棵 n n n 个点的有根树和 m m m 条祖先-后代链,要求给每条边赋值 0 0 0 或 1 1 1,问有多少种方案满足每条链上至少有一条边的值为 1 1 1。
n , m ≤ 5 ∗ 1 0 5 n,m\le 5*10^5 n,m≤5∗105
考虑容斥。强制让若干条链不满足,贡献就是链上的边只能取 0 0 0,其余边的值可以随便取的方案数,容斥系数为 ( − 1 ) k (-1)^k (−1)k,其中 k k k 为选择的链数量。
树形dp。令 f i , j f_{i,j} fi,j 表示以 i i i 为根的子树,从 i i i 到深度为 j j j 的祖先路径上的每条边都要赋值为 0 0 0 的方案数。加入儿子 t o to to 的子树时,先让所有 f t o , j ( j ≥ d e p t o ) f_{to,j}(j\ge dep_{to}) fto,j(j≥depto) 乘以 2 2 2,表示 i i i 到 t o to to 这条边可以随便取,然后转移为 f x , j ∗ f t o , k → f x , min ( j , k ) f_{x,j}*f_{to,k}\to f_{x,\min(j,k)} fx,j∗fto,k→fx,min(j,k)
注意到若只有 s s s 条链的下端点位于 i i i 的子树中,则 f i , j ≠ 0 f_{i,j}\neq 0 fi,j=0 的 j j j 只有 O ( s ) O(s) O(s) 种取值。用线段树维护这 O ( s ) O(s) O(s) 种取值,就可以每次 O ( log n ) O(\log n) O(logn) 实现区间乘法,同时在线段树合并的过程中完成转移。总的时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
#include
using namespace std;
typedef long long LL;
const int N = 500005;
const int MOD = 998244353;
int n, m, lin[N], dep[N], sz, rt[N];
vector<int> e[N];
struct tree{int l, r, fa, s, tag;}t[N * 60];
void dfs(int x, int fa)
{
dep[x] = dep[fa] + 1;
for (int to : e[x]) if (to != fa) dfs(to, x);
}
int newnode() {t[++sz].tag = 1; return sz;}
void mark(int d, int w)
{
t[d].s = (LL)t[d].s * w % MOD;
t[d].tag = (LL)t[d].tag * w % MOD;
}
void pushdown(int d)
{
if (t[d].tag == 1) return;
int w = t[d].tag; t[d].tag = 1;
mark(t[d].l, w); mark(t[d].r, w);
}
void ins(int & d, int l, int r, int x, int y)
{
if (!d) d = newnode();
(t[d].s += y) %= MOD;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) ins(t[d].l, l, mid, x, y);
else ins(t[d].r, mid + 1, r, x, y);
}
void mul(int d, int l, int r, int x, int y)
{
if (!d) return;
if (x <= l && r <= y) {mark(d, 2); return;}
int mid = (l + r) / 2;
pushdown(d);
if (x <= mid) mul(t[d].l, l, mid, x, y);
if (y > mid) mul(t[d].r, mid + 1, r, x, y);
t[d].s = (t[t[d].l].s + t[t[d].r].s) % MOD;
}
int merge(int x, int y, int l, int r, int s1, int s2)
{
if (!x || !y) {mark(x, s2); mark(y, s1); return x ^ y;}
if (l == r) {t[x].s = ((LL)t[x].s * s2 + (LL)t[y].s * s1 + (LL)t[x].s * t[y].s) % MOD; return x;}
pushdown(x); pushdown(y);
int mid = (l + r) / 2;
t[x].l = merge(t[x].l, t[y].l, l, mid, (s1 + t[t[x].r].s) % MOD, (s2 + t[t[y].r].s) % MOD);
t[x].r = merge(t[x].r, t[y].r, mid + 1, r, s1, s2);
t[x].s = (t[t[x].l].s + t[t[x].r].s) % MOD;
return x;
}
void solve(int x, int fa)
{
if (lin[x]) ins(rt[x], 1, n, lin[x], MOD - 1);
ins(rt[x], 1, n, n, 1);
for (int to : e[x]) if (to != fa)
{
solve(to, x);
mul(rt[to], 1, n, dep[to], n);
rt[x] = merge(rt[x], rt[to], 1, n, 0, 0);
}
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int x, y; scanf("%d%d", &x, &y);
e[x].push_back(y); e[y].push_back(x);
}
dfs(1, 0);
scanf("%d", &m);
for (int i = 1; i <= m; i++)
{
int x, y; scanf("%d%d", &x, &y);
lin[y] = max(lin[y], dep[x]);
}
solve(1, 0);
printf("%d\n", t[rt[1]].s);
return 0;
}