Address
Solution
- 先考虑如果以某个点(下面定为 1 1 1 )为根时,如果所有的限制二元组 ( u , v ) (u,v) (u,v) 都满足 u u u 是 v v v 的父亲(即 u u u 向 v v v 连边构成外向树)怎么做
- 显然,对于任意一个点 u u u , u u u 必须是在 u u u 的子树内第一个被翻到的
- 如果每个点的 W W W 已经确定,则这个概率就等于
- ∏ u = 1 n W u ∑ v ∈ s u b t r e e ( u ) W v \prod_{u=1}^n\frac{W_u}{\sum_{v\in subtree(u)W_v}} u=1∏n∑v∈subtree(u)WvWu
- 容易设计一个 DP 状态
- f [ u ] [ i ] f[u][i] f[u][i] 表示 u u u 的子树内的点全部满足条件,并且这些点的 W W W 之和为 i i i 的概率
- 如果只有一个点,则
- f [ u ] [ i ] = p u , i p u , 1 + p u , 2 + p u , 3 f[u][i]=\frac{p_{u,i}}{p_{u,1}+p_{u,2}+p_{u,3}} f[u][i]=pu,1+pu,2+pu,3pu,i
- 边 ( u , v ) (u,v) (u,v) 合并两个连通块
- f [ u ] [ i + j ] + = i i + j × f ′ [ u ] [ i ] × f [ v ] [ j ] f[u][i+j]+=\frac i{i+j}\times f'[u][i]\times f[v][j] f[u][i+j]+=i+ji×f′[u][i]×f[v][j]
- 这里 i i + j \frac i{i+j} i+ji 表示 u u u 子树内的 W W W 之和由 i i i 变成 i + j i+j i+j 时对应概率的分母也需要变化
- 答案显然为 ∑ i = 1 3 n f [ 1 ] [ i ] \sum_{i=1}^{3n}f[1][i] ∑i=13nf[1][i]
- 根据某经典的树形背包复杂度分析,以上算法复杂度为 O ( n 2 ) O(n^2) O(n2)
- 回到原问题,不是外向树的情况,考虑容斥
- 也就是说,每条反向边,我们有两种处置方法
- (1)不计这条边的限制,即把这条边删掉
- (2)强行让这条边的限制变为正向
- 对每条反向边进行这两种处理后共 2 反 向 边 条 数 2^{反向边条数} 2反向边条数 种情况
- 在每种情况下,对分离出的每个连通块求一下概率并相乘
- 如果这种情况种强心变为正向的反向边有偶数条则计入答案,否则从答案中扣除
- 这样我们有了一个 O ( 2 反 向 边 条 数 × n 2 ) O(2^{反向边条数}\times n^2) O(2反向边条数×n2) 的
优秀做法
- 我们考虑把容斥的过程放进 DP 里
- f [ u ] [ i ] f[u][i] f[u][i] 表示 i i i 的子树,把子树内的反向边进行处理的所有情况下,子树内所有连通块符合要求,并且 u u u 所在连通块的 W W W 之和为 i i i 的概率,偶数条反向边进行(2)处理则计入否则扣除得到的结果
- 边界相同
- 如果 ( u , v ) (u,v) (u,v) 不是反向边,还是一样
- f [ u ] [ i + j ] + = i i + j × f ′ [ u ] [ i ] × f [ v ] [ j ] f[u][i+j]+=\frac i{i+j}\times f'[u][i]\times f[v][j] f[u][i+j]+=i+ji×f′[u][i]×f[v][j]
- 否则
- f [ u ] [ i ] + = f ′ [ u ] [ i ] × f [ v ] [ j ] f[u][i]+=f'[u][i]\times f[v][j] f[u][i]+=f′[u][i]×f[v][j]
- f [ u ] [ i + j ] − = i i + j × f ′ [ u ] [ i ] × f [ v ] [ j ] f[u][i+j]-=\frac i{i+j}\times f'[u][i]\times f[v][j] f[u][i+j]−=i+ji×f′[u][i]×f[v][j]
- 答案还是
- ∑ i = 1 3 n f [ 1 ] [ i ] \sum_{i=1}^{3n}f[1][i] i=1∑3nf[1][i]
- O ( n 2 ) O(n^2) O(n2)
Code
#include
template <class T>
inline void read(T &res)
{
res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
if (bo) res = ~res + 1;
}
const int N = 1005, L = 2005, M = 3005, E = 3e6 + 5, ZZQ = 998244353;
int n, a[N][4], inv[E], ecnt, nxt[L], adj[N], go[L], col[L], f[N][M], sze[N],
tmp[M], ans;
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v; col[ecnt] = 0;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u; col[ecnt] = 1;
}
void dfs(int u, int fu)
{
sze[u] = 1;
int fr = inv[a[u][1] + a[u][2] + a[u][3]];
for (int i = 1; i <= 3; i++)
f[u][i] = 1ll * i * a[u][i] * fr % ZZQ;
for (int e = adj[u], v; e; e = nxt[e])
{
if ((v = go[e]) == fu) continue;
dfs(v, u);
for (int i = 1; i <= (sze[u] + sze[v]) * 3; i++) tmp[i] = 0;
for (int i = 1; i <= sze[u] * 3; i++)
for (int j = 1; j <= sze[v] * 3; j++)
{
int delta = 1ll * f[u][i] * f[v][j] % ZZQ;
if (col[e]) tmp[i + j] = (tmp[i + j] - delta + ZZQ) % ZZQ,
tmp[i] = (tmp[i] + delta) % ZZQ;
else tmp[i + j] = (tmp[i + j] + delta) % ZZQ;
}
sze[u] += sze[v];
for (int i = 1; i <= sze[u] * 3; i++) f[u][i] = tmp[i];
}
for (int i = 1; i <= sze[u] * 3; i++)
f[u][i] = 1ll * f[u][i] * inv[i] % ZZQ;
}
int main()
{
int x, y;
read(n);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= 3; j++)
read(a[i][j]);
for (int i = 1; i < n; i++)
read(x), read(y), add_edge(x, y);
inv[1] = 1;
for (int i = 2; i <= 3000000; i++)
inv[i] = 1ll * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ;
dfs(1, 0);
for (int i = 1; i <= n * 3; i++)
ans = (ans + f[1][i]) % ZZQ;
return std::cout << ans << std::endl, 0;
}
```