给定一棵有 n 个点的树和 m 个询问,每个询问给出 (u, v),求 u 到 v 的简单路径上不同点权的个数。
数据范围:$n \le 40000,\ m \le 100000$。
树上莫队是利用欧拉序的性质实现的莫队,可以解决很多在树上的问题
这题其实就是一个树上莫队的板子,直接上即可
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
// rep: inclusive ascending loop, i = st, st+1, ..., ed.
// The bound `ed` is pasted into the loop condition, so it is re-evaluated on every iteration.
#define rep(i, st, ed) for (int i = st; i <= ed; i++)
// rwp: inclusive descending loop, i = ed, ed-1, ..., st.
#define rwp(i, ed, st) for (int i = ed; i >= st; i--)
// Base capacity used to size the fixed global arrays of this program.
#define N 50005
using namespace std;
// Shorthand for long long.
typedef long long ll;
// One directed arc of the array-backed adjacency list.
struct edge {
    int To;   // vertex this arc leads to
    int nxt;  // index in e[] of the next arc leaving the same vertex; 0 terminates the chain
};
// Arc pool. Slot 0 is never handed out, so index 0 can act as the end-of-chain marker.
edge e[N * 4];
// Positions of a vertex inside the Euler tour seq[].
struct Code {
    int l;  // position recorded when dfs enters the vertex
    int r;  // position recorded when dfs leaves the vertex
};
// C[x] holds the two tour positions of vertex x.
Code C[N];
// One offline query, expressed as an inclusive range [l, r] over the Euler tour.
struct Node {
    int l;    // left end of the range in seq[]
    int r;    // right end of the range in seq[]
    int lca;  // vertex that must be counted in addition to the range; 0 when there is none
    int id;   // 1-based position of the query in the input, used to index ans[]
};
// Query storage, filled from index 1.
Node q[N * 2];
// Global state (zero-initialised because it has static storage duration).
int fa[N][30];    // fa[x][j]: the 2^j-th ancestor of x; 0 when that ancestor does not exist
int deep[N];      // depth of each vertex; main sets the root's depth to 1
int seq[N * 4];   // Euler tour: every vertex is written once on entry and once on exit
int sum[N];       // sum[c]: how many currently counted vertices carry compressed weight c
int ans[N * 2];   // answers, indexed by the original query number
int beln[N * 5];  // beln[p]: block number of tour position p
int vis[N];       // vis[x]: 1 while vertex x is currently counted, 0 otherwise
int val[N];       // vertex weights, replaced by their compressed ranks in main
int ls[N];        // ls[x]: index in e[] of the first arc leaving x; 0 when x has none
int d[N];         // scratch copy of the weights used for coordinate compression
int n;            // number of vertices
int m;            // number of queries
int now;          // number of distinct weights among the currently counted vertices
int len;          // number of entries written to seq[] so far
int siz;          // block size used to bucket tour positions
int cnt;          // number of arcs allocated in e[] so far
/**
 * Reads the next decimal integer from stdin into x.
 *
 * Every character that is not a digit is skipped; if a '-' is seen while
 * skipping, the parsed value is negated. The first character after the digit
 * run is consumed as well.
 *
 * If the input ends before any digit is found, x is set to 0 and the function
 * returns instead of spinning forever on EOF.
 *
 * @param x receives the parsed value (0 on end of input).
 */
void read(int &x) {
    int sign = 1;
    x = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from real data.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    // EOF is outside '0'..'9', so this loop also stops cleanly at end of input.
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    x *= sign;
}
// Records the undirected edge {u, v} as two arcs (u -> v, then v -> u).
// Each new arc is pushed at the front of its source vertex's chain.
void Addedge(int u, int v) {
    const int from[2] = {u, v};
    const int to[2] = {v, u};
    for (int k = 0; k < 2; ++k) {
        ++cnt;
        e[cnt].To = to[k];
        e[cnt].nxt = ls[from[k]];  // old head becomes the successor
        ls[from[k]] = cnt;         // new arc becomes the head
    }
}
// Depth-first walk from x that builds the Euler tour and the ancestor table.
// x is appended to seq[] on entry (position stored in C[x].l) and again on exit
// (position stored in C[x].r). For every child, deep[] and fa[][] are filled
// before recursing into it.
void dfs(int x) {
    C[x].l = ++len;
    seq[len] = x;

    for (int arc = ls[x]; arc != 0; arc = e[arc].nxt) {
        const int child = e[arc].To;
        if (child == fa[x][0]) continue;  // do not walk back to the parent

        deep[child] = deep[x] + 1;
        fa[child][0] = x;
        // Fill only the levels that fit within the child's depth; the rest stay 0.
        for (int j = 1; (1 << j) <= deep[child]; ++j) {
            fa[child][j] = fa[fa[child][j - 1]][j - 1];
        }

        dfs(child);
    }

    C[x].r = ++len;
    seq[len] = x;
}
// Returns the lowest common ancestor of u and v using the fa[][] jump table.
int Get_lca(int u, int v) {
    // Make u the deeper (or equally deep) vertex.
    if (deep[u] < deep[v]) swap(u, v);

    // Lift u while the jump target is still at least as deep as v.
    // Missing ancestors are 0 and deep[0] is 0, so such jumps are never taken.
    for (int k = 20; k >= 0; --k) {
        const int up = fa[u][k];
        if (deep[up] >= deep[v]) u = up;
    }
    if (u == v) return u;

    // Lift both vertices together as long as their ancestors differ.
    for (int k = 20; k >= 0; --k) {
        if (fa[u][k] == fa[v][k]) continue;
        u = fa[u][k];
        v = fa[v][k];
    }
    return fa[u][0];
}
// Query ordering for the offline sweep: primarily by the block of the left end;
// inside a block, the right end ascends when the block number is odd and
// descends when it is even.
bool cmp(Node aa, Node bb) {
    const int blockA = beln[aa.l];
    const int blockB = beln[bb.l];
    if (blockA != blockB) return blockA < blockB;
    if (blockA & 1) return aa.r < bb.r;
    return aa.r > bb.r;
}
// Counts one more occurrence of compressed weight x; a weight seen for the
// first time raises the distinct counter `now`.
void ins(int x) {
    if (sum[x]++ == 0) ++now;
}
// Removes one occurrence of compressed weight x; when its last occurrence
// disappears the distinct counter `now` drops.
void del(int x) {
    if (--sum[x] == 0) --now;
}
// Toggles vertex x: a vertex that is currently counted is removed, one that is
// not counted is added. Calling it twice for the same vertex restores the state.
void work(int x) {
    int &counted = vis[x];
    if (counted != 0) {
        del(val[x]);
    } else {
        ins(val[x]);
    }
    counted ^= 1;
}
/**
 * Prints x to stdout in decimal, without a trailing newline.
 *
 * Negative values are printed with a leading '-'. This keeps the routine
 * symmetric with read(), which accepts negative input.
 *
 * @param x value to print.
 */
void write(int x) {
    // Take the magnitude in unsigned arithmetic so that the most negative int
    // can be negated without overflow.
    unsigned int magnitude = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        magnitude = 0u - magnitude;
    }

    // Digits come out least-significant first; buffer them and emit in reverse.
    char digits[12];
    int count = 0;
    do {
        digits[count++] = static_cast<char>('0' + magnitude % 10);
        magnitude /= 10;
    } while (magnitude != 0);

    while (count > 0) putchar(digits[--count]);
}
int main() {
read(n); read(m);
rep(i, 1, n) read(val[i]), d[i] = val[i];
sort(d + 1, d + n + 1);
int tot = unique(d + 1, d + n + 1) - d - 1;
rep(i, 1, n) val[i] = lower_bound(d + 1, d + tot + 1, val[i]) - d;
int u, v, uv;
rep(i, 1, n - 1) read(u), read(v), Addedge(u, v);
deep[1] = 1; dfs(1);
siz = sqrt(len);
rep(i, 1, ceil((double) len / siz))
rep(j, (i - 1) * siz + 1, i * siz) beln[j] = i;
rep(i, 1, m) {
read(u), read(v), uv = Get_lca(u, v);
if (C[u].l > C[v].l) swap(u, v);
if (u == uv) q[i].l = C[u].l, q[i].r = C[v].l; else q[i].l = C[u].r, q[i].r = C[v].l, q[i].lca = uv;
q[i].id = i;
}
sort(q + 1, q + m + 1, cmp);
int l = 1, r = 0;
rep(i, 1, m) {
int ql = q[i].l, qr = q[i].r, lca = q[i].lca;
while (l < q[i].l) work(seq[l]), l++;
while (l > q[i].l) l--, work(seq[l]);
while (r < q[i].r) r++, work(seq[r]);
while (r > q[i].r) work(seq[r]), r--;
if (lca) work(lca);
ans[q[i].id] = now;
if (lca) work(lca);
}
rep(i, 1, m) write(ans[i]), printf("\n");
return 0;
}