// 离线每个询问,然后做树分治。 (Answer all queries offline, then run centroid decomposition on the tree.)
// Centroid decomposition over a tree whose vertices carry values a[1..n].
// Each query (x, d) counts the vertices y (x itself included) with
// dist(x, y) <= d whose path x -> y is monotone: non-increasing or
// non-decreasing along the way. All queries are read first and answered offline.
//
// Input:  T test cases; each is "n m", then a[1..n], then n-1 edges "u v",
//         then m queries "x d". Output: one answer per query, in input order.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;

using pii = pair<int, int>;

const int maxn = 100005;
const int maxm = 200005;

struct Edge {
    int v;
    Edge* next;
} *H[maxn], *edges, E[maxm];

// q[x]: the queries asked at vertex x, stored as (d, query id).
vector<pii> q[maxn];
// (depth, vertex) pairs collected by dfs(), bucketed by the class of the path
// from the DFS start: dis = all values equal, dis1 = non-increasing (not all
// equal), dis2 = non-decreasing (not all equal).
vector<pii> dis, dis1, dis2;
bool done[maxn];     // vertex has already been used as a centroid
int sub_size[maxn];  // was `size`: that name is ambiguous with std::size in C++17
vector<int> res;     // res[id] = answer of query id (1-based), sized per test
int mx[maxn];        // largest component left if the vertex were removed
int a[maxn];
// Fenwick trees over depth (stored at index depth + 1), one per path class.
int tree[maxn], tree1[maxn], tree2[maxn];
int n, m, root, nsize;

constexpr int lowbit(int x) { return x & -x; }

void addedges(int u, int v) {
    edges->v = v;
    edges->next = H[u];
    H[u] = edges++;
}

void init() {
    edges = E;
    memset(H, 0, sizeof H);
    res.clear();
    memset(done, 0, sizeof done);
}

// Finds the centroid of the component containing u (component size nsize);
// the best vertex so far is kept in `root`, with vertex 0 as the sentinel.
void getroot(int u, int fa) {
    mx[u] = 0;
    sub_size[u] = 1;
    for (Edge* e = H[u]; e; e = e->next) {
        int v = e->v;
        if (done[v] || v == fa) continue;
        getroot(v, u);
        sub_size[u] += sub_size[v];
        mx[u] = max(mx[u], sub_size[v]);
    }
    mx[u] = max(mx[u], nsize - sub_size[u]);
    if (mx[u] < mx[root]) root = u;
}

// Size of the not-yet-decomposed component hanging off u (coming from fa).
static int calc_size(int u, int fa) {
    sub_size[u] = 1;
    for (Edge* e = H[u]; e; e = e->next)
        if (!done[e->v] && e->v != fa) sub_size[u] += calc_size(e->v, u);
    return sub_size[u];
}

// Adds v at depth x (index x + 1 so that depth 0 is representable).
void add(int x, int v, int bit[]) {
    for (int i = x + 1; i <= n + 1; i += lowbit(i)) bit[i] += v;
}

// Number of stored entries with depth in [0, x]; 0 when x < 0. x is clamped to
// n because no depth exceeds n - 1; this keeps large query radii in bounds.
int sum(int x, int bit[]) {
    if (x > n) x = n;
    int ans = 0;
    for (int i = x + 1; i > 0; i -= lowbit(i)) ans += bit[i];
    return ans;
}

// Path class after one edge from value `from` to value `to`:
// 0 = equal, 1 = decreasing, 2 = increasing.
static int step_flag(int from, int to) {
    if (from == to) return 0;
    return from > to ? 1 : 2;
}

static vector<pii>& bucket(int flag) {
    return flag == 0 ? dis : (flag == 1 ? dis1 : dis2);
}

// Records u at depth dep in the bucket for `flag`, then extends the path to
// undecomposed neighbours while it stays in its class. An all-equal path
// (flag 0) switches to class 1 or 2 at its first change of value.
void dfs(int u, int fa, int dep, int flag) {
    bucket(flag).push_back(pii(dep, u));
    for (Edge* e = H[u]; e; e = e->next) {
        int v = e->v;
        if (v == fa || done[v]) continue;
        if (flag == 0) {
            dfs(v, u, dep + 1, step_flag(a[u], a[v]));
        } else if (flag == 1) {
            if (a[u] >= a[v]) dfs(v, u, dep + 1, 1);
        } else if (a[u] <= a[v]) {
            dfs(v, u, dep + 1, 2);
        }
    }
}

// Clears the three buckets and refills them with a DFS from `start`.
static void collect(int start, int fa, int dep, int flag) {
    dis.clear();
    dis1.clear();
    dis2.clear();
    dfs(start, fa, dep, flag);
}

// Adds `delta` at every collected depth into its class's Fenwick tree.
static void update_all(int delta) {
    for (const pii& p : dis) add(p.first, delta, tree);
    for (const pii& p : dis1) add(p.first, delta, tree1);
    for (const pii& p : dis2) add(p.first, delta, tree2);
}

// For each query (d, id) at vertex x, which lies `dist` edges from the current
// centroid: adds the number of stored vertices within depth d - dist whose
// class is selected by use_eq / use_dec / use_inc.
static void answer_at(int x, int dist, bool use_eq, bool use_dec, bool use_inc) {
    for (const pii& qr : q[x]) {
        int d = qr.first, id = qr.second;
        if (d < dist) continue;
        int r = d - dist;
        int t = 0;
        if (use_eq) t += sum(r, tree);
        if (use_dec) t += sum(r, tree1);
        if (use_inc) t += sum(r, tree2);
        res[id] += t;
    }
}

static void answer(const vector<pii>& nodes, bool use_eq, bool use_dec, bool use_inc) {
    for (const pii& p : nodes) answer_at(p.second, p.first, use_eq, use_dec, use_inc);
}

// Counts every pair (x, y) whose path passes through centroid u, then recurses
// into the remaining components.
void solve(int u) {
    done[u] = true;
    // Insert the whole component; u itself sits at depth 0 in `tree`.
    collect(u, u, 0, 0);
    update_all(+1);
    // Queries asked at u: any monotone path starting at u qualifies.
    answer_at(u, 0, true, true, true);

    for (Edge* e = H[u]; e; e = e->next) {
        int v = e->v;
        if (done[v]) continue;
        // Temporarily remove v's subtree: pairs inside it do not pass through u
        // and are counted later by a deeper centroid.
        collect(v, u, 1, step_flag(a[u], a[v]));
        update_all(-1);
        // x -> u is the reverse of u -> x. An equal prefix may continue into any
        // class. A non-increasing u -> x path is non-decreasing from x to u, so it
        // may only continue into equal or non-decreasing branches; the
        // non-decreasing case is symmetric.
        answer(dis, true, true, true);
        answer(dis1, true, false, true);
        answer(dis2, true, true, false);
        update_all(+1);
    }

    // Restore the Fenwick trees to all zeros before recursing.
    collect(u, u, 0, 0);
    update_all(-1);

    for (Edge* e = H[u]; e; e = e->next) {
        int v = e->v;
        if (done[v]) continue;
        // Recompute the component size: sub_size[v] from the parent's getroot
        // pass was rooted elsewhere and can overstate this component.
        mx[0] = nsize = calc_size(v, u);
        getroot(v, root = 0);
        solve(root);
    }
}

void work() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int u, v;
        scanf("%d%d", &u, &v);
        addedges(u, v);
        addedges(v, u);
    }
    for (int i = 1; i <= n; i++) q[i].clear();
    res.assign(m + 1, 0);
    for (int i = 1; i <= m; i++) {
        int x, d;
        scanf("%d%d", &x, &d);
        q[x].push_back(pii(d, i));
    }
    mx[0] = nsize = n;
    getroot(1, root = 0);
    solve(root);
    for (int i = 1; i <= m; i++) printf("%d\n", res[i]);
}

int main() {
    int cases;
    if (scanf("%d", &cases) != 1) return 0;
    while (cases--) {
        init();
        work();
    }
    return 0;
}