Luogu SP10707 COT2 - Count on a tree II —— 树上莫队

题目大意:

给定一棵有 $n$ 个点的树,每个点有一个点权;共有 $m$ 个询问,每个询问给出 $(u,v)$,求 $u$ 到 $v$ 的简单路径上(含两个端点)有多少种不同的点权。

$n \le 40000,\ m \le 100000$

分析:

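树上莫队利用树的欧拉序(这里用的是每个点在进入和离开时各记录一次的括号序,长度为 $2n$)把树上路径转化为序列上的一段区间,从而可以直接套用普通莫队,解决很多树上路径问题。

设 $in_x, out_x$ 分别为点 $x$ 在序列中第一次、第二次出现的位置,并令 $in_u \le in_v$:
- 若 $u = \mathrm{lca}(u,v)$,则路径对应区间 $[in_u, in_v]$;
- 否则路径对应区间 $[out_u, in_v]$,此时 lca 不在区间内,需要额外单独加上 $\mathrm{lca}(u,v)$ 这一个点。

区间中出现两次的点不在路径上,只出现一次的点恰好在路径上,所以只需用一个 vis 标记在每次经过某个位置时“翻转”该点的状态即可。这题其实就是一个树上莫队的板子,直接上即可。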

代码:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

// Inclusive ascending / descending loop helpers used throughout the file.
#define rep(i, st, ed) for (int i = st; i <= ed; i++)
#define rwp(i, ed, st) for (int i = ed; i >= st; i--)

// Capacity bound for per-vertex arrays (the statement gives n <= 40000).
#define N 50005

// Relied on by the unqualified sort/unique/lower_bound/swap/sqrt calls below.
using namespace std;

// NOTE(review): `ll` is not used anywhere in the visible code.
typedef long long ll;

// Adjacency-list half-edge: `To` is the endpoint, `nxt` is the index of the
// next edge in the same vertex's list (0 terminates; list heads live in ls[]).
struct edge {
    int To, nxt;
}e[N*4];

// C[x].l / C[x].r: positions of vertex x's first (entry) and second (exit)
// occurrence in seq[], filled by dfs().
struct Code {
	int l, r;
}C[N];
// One path query mapped to the range [l, r] of seq[]; `lca` is an extra
// vertex toggled separately (0 when none is needed); `id` is the query's
// 1-based input index, used to store its answer in ans[].
struct Node {
	int l, r, lca, id;
}q[N*2];
// fa[x][j] : 2^j-th ancestor of x (0 above the root, from zero-initialisation)
// deep[x]  : depth of x, the root (vertex 1) has depth 1
// seq[]    : entry/exit sequence built by dfs(); each vertex appears twice
// sum[c]   : number of currently toggled-on vertices whose compressed weight is c
// ans[id]  : answer of query id
// beln[p]  : Mo block index of sequence position p (sized past len because
//            main() fills whole blocks)
// vis[x]   : 1 if vertex x is currently counted in the window, else 0
// val[x]   : compressed weight of x; d[] is the sorted copy used to compress
// ls[x]    : head edge index of x's adjacency list
int fa[N][30], deep[N], seq[N*4], sum[N], ans[N*2], beln[N*5], vis[N], val[N], ls[N], d[N];
// now: distinct weights in the current window; len: used length of seq[];
// siz: Mo block size; cnt: number of half-edges stored in e[]
int n, m, now, len, siz, cnt;

// Reads the next (optionally negative) decimal integer from stdin into x.
// Any '-' seen before the first digit makes the result negative.
// If the input ends before a digit is found, x is left as 0 instead of
// looping forever: getchar() returns an int, and EOF must be compared before
// narrowing to char (a `char` EOF never compares equal and, being either -1
// or 255, always looks like a non-digit).
void read(int &x) {
    int f = 1;
    x = 0;
    int s = getchar();
    while (s != EOF && (s < '0' || s > '9')) {
        if (s == '-') f = -1;
        s = getchar();
    }
    while (s >= '0' && s <= '9') {
        x = x * 10 + (s - '0');
        s = getchar();
    }
    x *= f;
}

// Stores the undirected tree edge (u, v) as two directed half-edges, each
// prepended to the adjacency list of its tail vertex (heads in ls[]).
void Addedge(int u, int v) {
    // half-edge u -> v
    ++cnt;
    e[cnt].To = v;
    e[cnt].nxt = ls[u];
    ls[u] = cnt;

    // half-edge v -> u
    ++cnt;
    e[cnt].To = u;
    e[cnt].nxt = ls[v];
    ls[v] = cnt;
}

// Builds the entry/exit (bracket) sequence of the subtree rooted at x: every
// vertex v is appended to seq[] once when it is entered (position C[v].l) and
// once when it is left (position C[v].r). Along the way it fills deep[] and
// the binary-lifting table fa[][] of every descendant; deep[x] must already
// be set by the caller.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree with
// n = 40000 the recursive form nests 40000 frames, which can exhaust a small
// default thread stack. Children are visited in the same adjacency-list order
// as the recursive version, so every array it writes is identical.
void dfs(int x) {
    // Each frame: (vertex, next edge of that vertex's list still to examine).
    std::vector<std::pair<int, int>> stk;
    seq[++len] = x;
    C[x].l = len;
    stk.emplace_back(x, ls[x]);
    while (!stk.empty()) {
        const int u = stk.back().first;
        int &i = stk.back().second;
        // Skip the half-edge leading back to the parent.
        while (i && e[i].To == fa[u][0]) i = e[i].nxt;
        if (i == 0) {
            // All children finished: record the exit occurrence.
            seq[++len] = u;
            C[u].r = len;
            stk.pop_back();
            continue;
        }
        const int y = e[i].To;
        // Advance before emplace_back below, which may invalidate `i`.
        i = e[i].nxt;
        deep[y] = deep[u] + 1;
        fa[y][0] = u;
        for (int j = 1; (1 << j) <= deep[y]; j++) {
            fa[y][j] = fa[fa[y][j - 1]][j - 1];
        }
        seq[++len] = y;
        C[y].l = len;
        stk.emplace_back(y, ls[y]);
    }
}

// Lowest common ancestor of u and v by binary lifting over fa[][] / deep[]
// (as filled by dfs()); ancestor jumps of up to 2^20 levels are probed.
int Get_lca(int u, int v) {
    // Make u the deeper of the two.
    if (deep[u] < deep[v]) std::swap(u, v);

    // Lift u up to v's depth.
    for (int k = 20; k >= 0; --k) {
        const int up = fa[u][k];
        if (deep[up] >= deep[v]) u = up;
    }
    if (u == v) return u;

    // Lift both until they sit just below their common ancestor.
    for (int k = 20; k >= 0; --k) {
        if (fa[u][k] == fa[v][k]) continue;
        u = fa[u][k];
        v = fa[v][k];
    }
    return fa[u][0];
}

// Mo's ordering: by block of the left end; inside a block, the right end goes
// ascending in odd-numbered blocks and descending in even-numbered ones
// (odd-even ordering, which shortens right-pointer travel between blocks).
bool cmp(Node aa, Node bb) {
    const int blockA = beln[aa.l];
    const int blockB = beln[bb.l];
    if (blockA != blockB) return blockA < blockB;
    if (blockA & 1) return aa.r < bb.r;
    return aa.r > bb.r;
}

// Adds one occurrence of compressed weight x; the distinct count `now` grows
// only when the weight goes from absent (0) to present.
void ins(int x) {
    now += (sum[x]++ == 0);
}

// Removes one occurrence of compressed weight x; the distinct count `now`
// shrinks only when the last occurrence disappears.
void del(int x) {
    now -= (sum[x]-- == 1);
}

// Toggles vertex x in or out of the current window: a vertex reached an even
// number of times is off the path, an odd number of times is on it.
void work(int x) {
    const int weight = val[x];
    if (!vis[x]) {
        ins(weight);
    } else {
        del(weight);
    }
    vis[x] ^= 1;
}

// Prints the decimal digits of x (no sign, no newline) to stdout via putchar.
void write(int x) {
    char digits[12];
    int count = 0;
    do {
        digits[count++] = static_cast<char>('0' + x % 10);
        x /= 10;
    } while (x > 0);
    // Digits were collected least-significant first; emit them reversed.
    while (count > 0) putchar(digits[--count]);
}

int main() {
    read(n); read(m); 
    rep(i, 1, n) read(val[i]), d[i] = val[i];
    sort(d + 1, d + n + 1);
    int tot = unique(d + 1, d + n + 1) - d - 1;	
    rep(i, 1, n) val[i] = lower_bound(d + 1, d + tot + 1, val[i]) - d;
    int u, v, uv;
    rep(i, 1, n - 1) read(u), read(v), Addedge(u, v);
    deep[1] = 1; dfs(1);
    siz = sqrt(len);
    rep(i, 1, ceil((double) len / siz))
        rep(j, (i - 1) * siz + 1, i * siz) beln[j] = i;
    rep(i, 1, m) {
        read(u), read(v), uv = Get_lca(u, v);
        if (C[u].l > C[v].l) swap(u, v);
        if (u == uv) q[i].l = C[u].l, q[i].r = C[v].l; else q[i].l = C[u].r, q[i].r = C[v].l, q[i].lca = uv;
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp);
    int l = 1, r = 0;
    rep(i, 1, m) {
        int ql = q[i].l, qr = q[i].r, lca = q[i].lca;
        while (l < q[i].l) work(seq[l]), l++;
        while (l > q[i].l) l--, work(seq[l]);
        while (r < q[i].r) r++, work(seq[r]);
        while (r > q[i].r) work(seq[r]), r--;
        if (lca) work(lca);
        ans[q[i].id] = now;
        if (lca) work(lca);
    }
    rep(i, 1, m) write(ans[i]), printf("\n");
    return 0;
}

你可能感兴趣的:(C++,莫队)