// BZOJ 4012 — solved with dynamic centroid decomposition.

/* I will wait for you*/

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <vector>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <string>
// Shorthand macros from the author's template. None of them (nor ull, pii,
// maxm, maxs, P, error) is referenced elsewhere in this file.
#define make make_pair
#define fi first
#define se second

using namespace std;

using ll = long long;
using ull = unsigned long long;
using pii = pair<int, int>;

// Capacities and numeric constants.
constexpr int maxn = 200010;    // vertex capacity; the edge pool holds 2 * maxn half-edges
constexpr int maxm = 1010;
constexpr int maxs = 26;
constexpr int inf = 0x3f3f3f3f; // sentinel stored in mx[0] before a centroid search
constexpr int P = 1000000007;
constexpr double error = 1e-9;

/*
 * Reads the next integer from stdin.
 *
 * Non-digit characters before the number are skipped; a '-' makes the value
 * negative only when it is the character immediately before the first digit.
 * Returns 0 if end of input is reached before any digit is seen.
 */
inline ll read()
{
	ll x = 0, f = 1;
	int ch = getchar(); // int, not char, so EOF can be told apart from data
	while (ch != EOF && (ch < '0' || ch > '9')) {
		f = (ch == '-') ? -1 : 1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}

// One directed half of an undirected tree edge in the adjacency list:
// target vertex `v`, weight `w`, and the index of the next half-edge leaving
// the same source vertex (-1 ends the list).
struct edge
{
	ll v;
	ll w;
	ll next;
};

// Half-edge pool; each undirected edge occupies two consecutive slots.
edge e[2 * maxn];

// Entry of a centroid's bucket: tree vertex `seat` with age `age`.
// Once a bucket is sorted, `pos` is the entry's 1-based rank, and `dis` / `pis`
// hold prefix sums of distances to the bucket's centroid / to that centroid's
// parent in the centroid tree.
struct node
{
	ll age;
	ll seat;
	ll pos;
	ll dis;
	ll pis;
};

// Bucket entries are ordered by age alone (used by sort / lower_bound /
// upper_bound).
bool operator < (node lhs, node rhs)
{
	return lhs.age < rhs.age;
}

// Global state; all of it has static storage and therefore starts zeroed.
ll n, m;        // vertex count, query count
ll ans;         // answer of the current (and, for decoding, previous) query
ll cnt;         // next free slot in the edge pool
ll root, sum;   // centroid-search result and size of the component searched
ll A, L, R;     // age modulus and decoded age range of the current query

ll bin[maxn];       // bin[i] = 2^i
ll head[maxn];      // first half-edge out of each vertex (-1 = none)
ll dis[maxn];       // weighted distance from vertex 1
ll fa[maxn][30];    // fa[v][j] = 2^j-th ancestor of v (0 if none)
ll deep[maxn];      // number of edges between vertex 1 and v
// NOTE: with `using namespace std`, an unqualified `size` is ambiguous with
// std::size from C++17 onward; writing `::size` avoids that.
ll size[maxn];      // subtree sizes during a centroid search
ll mx[maxn];        // largest remaining piece if the vertex were removed
ll fath[maxn];      // parent in the centroid tree (0 for the top centroid)
ll del[maxn];       // 1 once the vertex has been used as a centroid
ll age[maxn];       // age of each vertex

vector<node> di[maxn]; // per-centroid bucket covering its whole component

void insert(int u, ll v, ll w)
{
	e[cnt] = (edge) {v, w, head[u]}, head[u] = cnt++;
	e[cnt] = (edge) {u, w, head[v]}, head[v] = cnt++;
}

void dfs1(int u)
{
	for (int i = head[u]; i != -1; i = e[i].next) {
		ll v = e[i].v;
		if (v != fa[u][0]) {
			deep[v] = deep[u] + 1;
			dis[v] = dis[u] + e[i].w;
			fa[v][0] = u, dfs1(v);
		}
	}
}

// Builds the binary-lifting tables: bin[j] = 2^j, and
// fa[i][j] = the 2^j-th ancestor of vertex i (left 0 if it does not exist).
// Requires deep[] and fa[][0] to be filled in already (dfs1).
//
// The level loop must be the outer one. fa[i][j] reads fa[a][j - 1] for the
// ancestor a = fa[i][j - 1], and vertex numbers are not ordered by depth.
// Level j - 1 therefore has to be finished for every vertex before level j
// is built; otherwise entries for higher-numbered ancestors are read before
// they are computed.
void init()
{
	for (int j = 0; j <= 20; j++)
		bin[j] = 1 << j;

	for (int j = 1; j <= 20; j++)
		for (int i = 1; i <= n; i++)
			if (deep[i] >= bin[j])
				fa[i][j] = fa[fa[i][j - 1]][j - 1];
}

// Lowest common ancestor of u and v using the binary-lifting tables.
ll lca(ll u, ll v)
{
	// Make u the deeper endpoint.
	if (deep[v] > deep[u]) {
		ll t = u;
		u = v;
		v = t;
	}

	// Lift u up to v's depth, trying the largest jumps first.
	for (int k = 20; k >= 0; k--) {
		if (deep[u] - bin[k] < deep[v])
			continue;
		u = fa[u][k];
	}

	if (u == v)
		return u;

	// Lift both while their ancestors differ; they end as children of the LCA.
	for (int k = 20; k >= 0; k--) {
		if (fa[u][k] == fa[v][k])
			continue;
		u = fa[u][k];
		v = fa[v][k];
	}

	return fa[u][0];
}

// Weighted tree distance between u and v:
// dis[u] + dis[v] - 2 * dis[lca(u, v)].
ll Dis(int u, int v)
{
	const ll through = dis[lca(u, v)];
	return (dis[u] - through) + (dis[v] - through);
}

void find(int u, int p)
{
	size[u] = 1, mx[u] = 0;

	for (int i = head[u]; i != -1; i = e[i].next) {
		ll v = e[i].v;
		if (v != p && !del[v]) {
			find(v, u), size[u] += size[v];
			mx[u] = max(mx[u], size[v]);
		}
	}

	mx[u] = max(mx[u], sum - size[u]);

	if (mx[u] < mx[root])
		root = u;
}

void divide(int u, int p)
{
	fath[u] = p, del[u] = 1;

	di[u].push_back((node) {age[u], u});

	for (int i = head[u]; i != -1; i = e[i].next) {
		ll v = e[i].v, tmp;
		if (v != p && !del[v]) {
			root = 0, sum = size[v], find(v, u);
			tmp = root, divide(root, u);

			for (int j = 0; j < di[tmp].size(); j++) {
				node v = di[tmp][j];
				node t = (node) {v.age, v.seat};
				di[u].push_back(t);
			}
		}
	}

	sort(di[u].begin(), di[u].end());

	for (int i = 0; i < di[u].size(); i++) {
		di[u][i].pos = i + 1;
		di[u][i].dis = Dis(u, di[u][i].seat);
		di[u][i].pis = Dis(p, di[u][i].seat);
	}

	for (int i = 1; i < di[u].size(); i++) {
		di[u][i].dis += di[u][i - 1].dis;
		di[u][i].pis += di[u][i - 1].pis;
	}

}

void solve(int u, int p, ll su, ll si)
{
	vector<node> :: iterator l
		= lower_bound(di[p].begin(), di[p].end(), (node) {L});
	vector<node> :: iterator r
	       	= upper_bound(di[p].begin(), di[p].end(), (node) {R});

	ll ldis, rdis, lpos, rpos, lpis, rpis;

	if (l == di[p].begin())
		ldis = lpos = lpis = 0;
	else {
		l--;
		ldis = l -> dis;
		lpos = l -> pos;
		lpis = l -> pis;
	}

	if (r == di[p].begin())
		rdis = rpos = rpis = 0;
	else {
		r--;
		rdis = r -> dis;
		rpos = r -> pos;
		rpis = r -> pis;
	}

	ll size = rpos - lpos;
	ll sum = rdis - ldis;
	ll pum = rpis - lpis;

	ans += (sum - su) + (size - si) * Dis(u, p);

	if (fath[p])
		solve(u, fath[p], pum, size);
}

int main()
{
	n = read(), m = read(), A = read();

	for (int i = 1; i <= n; i++)
		age[i] = read();

	memset(head, -1, sizeof head);
	memset(fa, 0, sizeof fa);
	
	for (int i = 1; i < n; i++) {
		int u = read(), v = read(), w = read();
		insert(u, v, w);
	}
	
	dfs1(1), init();

	root = 0, sum = n, mx[0] = inf;
	find(1, 0), divide(root, 0);

	for (int i = 1; i <= m; i++) {
		int u = read(), a = read(), b = read();

		L = min((a + ans) % A, (b + ans) % A);
		R = max((a + ans) % A, (b + ans) % A);

		ans = 0, solve(u, u, 0, 0);
		printf("%lld\n", ans);
	}
	
	return 0;
}

// (Scraped web-page footer: "You may also be interested in: BZOJ 4012 — dynamic centroid decomposition".)