[传送门]
题目即求所有的三元组,相对大小关系同 $p_1,p_2,p_3$。
题解说都很清楚,这里写一下过程整理一下思路。
如果我们枚举中间这个元素,那么就是统计子树内外有多少个大于这个数和小于这个数的个数。
假设$a_1$,$a_3$的$LCA$不是$a_2$,那么就是一个在$a_2$子树内一个在子树外。
设$S_u$, $B_u$分别为$u$子树内小于$u$和大于$u$的节点个数,$S_t$, $B_t$分别为整棵树小于$u$和大于$u$的节点个数。
当$p_2 = 1$时,对答案的贡献为$B_u \times (B_t - B_u)$
当$p_2 = 2$时,对答案的贡献为$B_u \times (S_t - S_u) + S_u \times (B_t - B_u)$
当$p_2 = 1$时,对答案的贡献为$S_u \times (S_t - S_u)$
当$a_1$和$a_3$的$LCA$是$a_2$时,枚举$u$的子节点。
设$S_v$为$u$的子节点$v$的子树中,小于$u$的节点个数,$B_v$为$u$的子节点$v$的子树中,大于$u$的节点个数。
当$p_2 = 1$时,对答案的贡献为$B_v \times (B_u - B_v)$
当$p_2 = 2$时,对答案的贡献为$S_v \times (B_u - B_v) + B_v \times (S_u - S_v)$
当$p_2 = 1$时,对答案的贡献为$S_v \times (S_u - S_v)$
这部分的贡献会被算两次,所以最后得除以二。
查子树内都多少节点大于/小于该节点的,题解用了dfs序+树状数组,但是对于第二部分求答案很麻烦。所以我用了线段树合并,不过会卡常,那么只求一遍小于的,再用子树的size减去这个值就得到大于的。
#include#define pii pair #define ll long long using namespace std; namespace IO { const int MAXSIZE = 1 << 20; char buf[MAXSIZE], *p1, *p2; #define gc() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MAXSIZE, stdin), p1 == p2) ? EOF : *p1++) template<class T> inline void read(T &x) { x = 0; T f = 1; char c = gc(); while (!isdigit(c)) { if (c == '-') f = -1; c = gc(); } while (isdigit(c)) x = x * 10 + (c ^ 48), c = gc(); x *= f; } char pbuf[1 << 20], *pp = pbuf; inline void push(const char &c) { if (pp - pbuf == 1 << 20) fwrite(pbuf, 1, 1 << 20, stdout), pp = pbuf; *pp++ = c; } inline void write(int x) { static int sta[35]; int top = 0; do { sta[top++] = x % 10, x /= 10; } while (x); while (top) push(sta[--top] + '0'); } } using namespace IO; const int N = 1e5 + 7; int root[N], p[3], n, sz[N]; vector<int> G[N]; struct Seg { struct Tree { int lp, rp, sum; } tree[N * 50]; int tol; inline void clear() { tol = 0; memset(tree, 0, sizeof(tree)); } inline void pushup(int p) { tree[p].sum = tree[tree[p].lp].sum + tree[tree[p].rp].sum; } void update(int &p, int l, int r, int pos) { if (!p) p = ++tol; if (l == r) { tree[p].sum++; return; } int mid = l + r >> 1; if (pos <= mid) update(tree[p].lp, l, mid, pos); else update(tree[p].rp, mid + 1, r, pos); pushup(p); } int merge(int p, int q, int l, int r) { if (!p || !q) return p | q; int u = ++tol; int mid = l + r >> 1; tree[u].lp = merge(tree[p].lp, tree[q].lp, l, mid); tree[u].rp = merge(tree[p].rp, tree[q].rp, mid + 1, r); pushup(u); return u; } int query(int p, int l, int r, int x, int y) { if (x > y) return 0; if (!p) return 0; if (x <= l && y >= r) return tree[p].sum; int mid = l + r >> 1; int ans = 0; if (x <= mid) ans += query(tree[p].lp, l, mid, x, y); if (y > mid) ans += query(tree[p].rp, mid + 1, r, x, y); return ans; } } seg; ll ans; inline void init() { ans = 0; seg.clear(); for (int i = 1; i <= n; i++) G[i].clear(), root[i] = 0; } void dfs(int u, int fa) { seg.update(root[u], 1, n, u); vector vec; int su = 0, bu = 0; int st = u - 1, bt = n - u; sz[u] = 1; for (auto v: G[u]) { if (v == fa) continue; dfs(v, u); sz[u] += sz[v]; int sv = seg.query(root[v], 1, n, 1, u - 1), bv = sz[v] - sv; su += sv, bu += bv; vec.push_back(pii(sv, bv)); root[u] = seg.merge(root[u], root[v], 1, n); } if (p[1] == 1) ans += 1LL * bu * (bt - bu); else if (p[1] == 2) ans += 1LL * bu * (st - su) + 1LL * su * (bt - bu); else ans += 1LL * su * (st - su); ll res = 0; for (auto pp: vec) { if (p[1] == 1) res += 1LL * pp.second * (bu - pp.second); else if (p[1] == 2) res += 1LL * pp.first * (bu - pp.second) + 1LL * pp.second * (su - pp.first); else res += 1LL * pp.first * (su - pp.first); } ans += res / 2; } int main() { // freopen("in.txt", "r", stdin); int T; read(T); while (T--) { read(n); init(); for (int i = 0; i < 3; i++) read(p[i]); for (int i = 1; i < n; i++) { int u, v; read(u), read(v); G[u].push_back(v); G[v].push_back(u); } dfs(1, 0); printf("%lld\n", ans); } return 0; }