// BZOJ3757【树上莫队算法】 (Mo's algorithm on a tree)

/* I will wait for you*/

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <vector>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <string>
// Shorthand macros for std::pair helpers. They are not referenced by the
// code visible in this file.
#define make make_pair
#define fi first
#define se second

using namespace std;

// Convenience aliases; not referenced by the code visible in this file.
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;

// Capacity of the global arrays declared with it (edge pool, query list and
// all per-vertex / per-colour tables).
const int maxn = 100010;
// The remaining constants are not referenced by the code visible in this file.
const int maxm = 1010;
const int maxs = 26;
const int inf = 0x3f3f3f3f;
const int P = 1000000007;
const double error = 1e-9;

// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped. The number is
// negated only when the character immediately preceding its first digit is
// '-'. The single character that terminates the digit run is consumed.
//
// Returns 0 when end of input is reached before any digit is found. The
// character is kept in an int so that EOF can be told apart from real input;
// with a plain char the skip loop would spin forever at end of input.
inline int read()
{
	int value = 0, sign = 1;
	int ch = getchar();

	// Skip everything that is not a digit, remembering a directly
	// preceding minus sign.
	while (ch != EOF && (ch < '0' || ch > '9')) {
		sign = (ch == '-') ? -1 : 1;
		ch = getchar();
	}

	// Accumulate the digit run.
	while (ch >= '0' && ch <= '9') {
		value = value * 10 + (ch - '0');
		ch = getchar();
	}

	return value * sign;
}

// One directed half of an undirected tree edge, stored in a linked
// adjacency list: `v` is the target vertex and `next` is the index of the
// next entry leaving the same source vertex (-1 terminates the list).
//
// insert() appends two entries per undirected edge, so the pool is sized at
// twice maxn to stay in bounds for any edge count below maxn.
struct edge
{
	int v, next;
} e[maxn * 2];

// One path query as read by main().
struct query
{
	int u;    // first endpoint of the path
	int v;    // second endpoint of the path
	int a;    // first colour of the pair tested against the path
	int b;    // second colour of the pair tested against the path
	int pos;  // 1-based position in the input, used to restore output order
};

// Query list, 1-based.
query q[maxn];

int n;      // number of colour entries / edge lines read by main()
int m;      // number of queries
int blnum;  // number of blocks created so far
int ans;    // distinct colours among the currently toggled-on vertices
int cnt;    // next free slot in the edge pool e[]
int top;    // index of the top element of the stack s[] (0 means empty)
int root;   // root vertex, taken from the edge line that contains a 0
int clo;    // DFS clock used to hand out dfn[] values

int head[maxn];    // head[u]: index of the first edge entry of u, -1 if none
int res[maxn];     // res[i]: answer of the query whose input position is i
int fa[maxn][20];  // fa[u][k]: ancestor of u that is 2^k levels up
int bel[maxn];     // bel[u]: id of the block vertex u was assigned to
int deep[maxn];    // deep[u]: depth of u, the root having depth 0
int bin[maxn];     // bin[k] == 1 << k for k in [0, 16], filled in main()
int size[maxn];    // size[u]: vertices below u still waiting on s[] for a block
int dfn[maxn];     // dfn[u]: DFS entry order of u
int vis[maxn];     // vis[u]: 1 while u is toggled on, 0 otherwise
int co[maxn];      // co[u]: colour of vertex u
int in[maxn];      // in[c]: number of toggled-on vertices with colour c
int s[maxn];       // stack of vertices not yet assigned to a block

// Adds the undirected edge (u, v) to the adjacency lists by pushing one
// entry onto the front of u's list and one onto the front of v's list.
// Consumes two consecutive slots of the edge pool e[].
//
// The entries are filled member by member: the compound-literal form
// `(edge) {...}` is a compiler extension in C++, not standard syntax.
void insert(int u, int v)
{
	e[cnt].v = v;
	e[cnt].next = head[u];
	head[u] = cnt++;

	e[cnt].v = u;
	e[cnt].next = head[v];
	head[v] = cnt++;
}

void dfs(int u)
{
	dfn[u] = ++clo;

	for (int i = 1; i <= 16; i++)
		if (deep[u] >= bin[i])
			fa[u][i] = fa[fa[u][i - 1]][i - 1];
		else
			break;

	for (int i = head[u]; i != -1; i = e[i].next) {
		int v = e[i].v;

		if (v != fa[u][0]) {
			fa[v][0] = u, deep[v] = deep[u] + 1;
			dfs(v), size[u] += size[v];

			if (size[u] >= 300) {
				blnum++;
				for (int j = 1; j <= size[u]; j++)
					bel[s[top--]] = blnum;
				size[u] = 0;
			}
		}
	}

	s[++top] = u, size[u]++;
}

// Returns the lowest common ancestor of u and v using the binary-lifting
// table fa[][] and the depths deep[] prepared by dfs().
int lca(int u, int v)
{
	// Make u the deeper (or equally deep) vertex.
	if (deep[v] > deep[u])
		std::swap(u, v);

	// Lift u by the depth difference, one set bit at a time.
	const int gap = deep[u] - deep[v];
	int level = 0;
	while (level <= 16) {
		if (gap & bin[level])
			u = fa[u][level];
		level++;
	}

	// Lift both vertices together for as long as their ancestors differ.
	level = 16;
	while (level >= 0) {
		if (fa[u][level] != fa[v][level]) {
			u = fa[u][level];
			v = fa[v][level];
		}
		level--;
	}

	if (u == v)
		return u;
	return fa[u][0];
}

// Toggles the membership of vertex u in the current vertex set, keeping the
// per-colour counters in[] and the distinct-colour total ans up to date.
void reverse(int u)
{
	int &colour_count = in[co[u]];

	if (vis[u] == 0) {
		// u joins the set: its colour is new if it was absent before.
		if (colour_count++ == 0)
			ans++;
	} else if (vis[u] == 1) {
		// u leaves the set: its colour vanishes if u was the last one.
		if (--colour_count == 0)
			ans--;
	}

	vis[u] ^= 1;
}
			
// Toggles every vertex on the tree path between u and v except the vertex
// where the two upward walks meet. Each step moves the deeper of the two
// vertices one level up, toggling it on the way.
void solve(int u, int v)
{
	for (; u != v; u = fa[u][0]) {
		// Always advance from the deeper side.
		if (deep[v] > deep[u])
			std::swap(u, v);
		reverse(u);
	}
}

// Ordering used to sort the queries: primarily by the block of the first
// endpoint, and within a block by the DFS order of the second endpoint.
bool cmp(query a, query b)
{
	const int block_a = bel[a.u];
	const int block_b = bel[b.u];

	if (block_a != block_b)
		return block_a < block_b;

	return dfn[a.v] < dfn[b.v];
}

int main()
{
	for (int i = 0; i <= 16; i++)
		bin[i] = (1 << i);

	n = read(), m = read();

	for (int i = 1; i <= n; i++)
		co[i] = read();

	memset(head, -1, sizeof head);

	for (int i = 1; i <= n; i++) {
		int u = read(), v = read();

		if (!u)
			root = v;
		if (!v)
			root = u;

		if (u && v)
			insert(u, v);
	}

	dfs(root), blnum++;

	while (top)
		bel[s[top--]] = blnum;

	for (int i = 1; i <= m; i++) {
		q[i].u = read(), q[i].v = read();
		q[i].a = read(), q[i].b = read();
		q[i].pos = i;
	}
	
	sort(q + 1, q + 1 + m, cmp);

	for (int i = 1; i <= m; i++) {
		if (i == 1)
			solve(q[i].u, q[i].v);

		else {
			solve(q[i - 1].u, q[i].u);
			solve(q[i - 1].v, q[i].v);
		}
	
		int t = lca(q[i].u, q[i].v);
		
		reverse(t), res[q[i].pos] = ans;

		if (in[q[i].a] && in[q[i].b] && q[i].a != q[i].b)
			res[q[i].pos]--;

		reverse(t);
	}
	
	for (int i = 1; i <= m; i++)
		printf("%d\n", res[i]);	

	return 0;
}

// 你可能感兴趣的:(BZOJ3757【树上莫队算法】)