【树分治】 ZOJ Travel

离线处理所有询问（把每个询问 (x, d) 挂在节点 x 上），然后做树分治（点分治）：在每个重心处，统计从 x 出发、经过该重心、长度不超过 d 且点权单调（不降或不增）的路径终点个数。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

// Lowest set bit of x, the Fenwick-tree step size. This is a function rather
// than the former macro `(x&(-x))`, which did not parenthesize its argument and
// so silently miscomputed expressions such as lowbit(a + b).
inline int lowbit(int x) { return x & -x; }

// (first, second) integer pair used for (depth, node) and (d, query id) records.
typedef std::pair<int, int> pii;

// Shorthand for std::make_pair.
template <class A, class B>
inline std::pair<A, B> mp(A x, B y) { return std::make_pair(x, y); }

const int maxn = 100005;    // bound for every per-node array (nodes are 1-based)
const int maxm = 200005;    // edge pool size: each undirected edge is stored twice
const int INF = 0x3f3f3f3f;

// Adjacency lists stored as singly linked lists carved out of a fixed pool.
struct Edge
{
	int v;        // node this edge leads to
	Edge *next;   // edge added earlier from the same source node; NULL ends the list
};

Edge *H[maxn];    // H[u]: head of u's list (the most recently added edge leaving u)
Edge *edges;      // next unused slot in the pool E
Edge E[maxm];     // edge pool; every undirected edge occupies two slots

// q[x]: (d, query id) pairs for every query asked at node x.
vector<pii> q[maxn];
// (depth, node) records filled by dfs(), one list per path class:
// dis1 = values non-increasing moving away from the start (with a strict drop),
// dis2 = values non-decreasing moving away from the start (with a strict rise),
// dis  = all values on the path equal.
vector<pii> dis1;
vector<pii> dis2;
vector<pii> dis;
bool done[maxn];   // node already processed as a centroid
// NOTE: with `using namespace std;` under C++17, an unqualified `size` is
// ambiguous with std::size; refer to this array as ::size.
int size[maxn];    // subtree sizes from the most recent getroot pass
int res[maxn];     // answer per query id
int mx[maxn];      // largest piece left if the node were removed (getroot)
int a[maxn];       // node values
int tree[maxn];    // Fenwick counts by depth for class dis
int tree1[maxn];   // Fenwick counts by depth for class dis1
int tree2[maxn];   // Fenwick counts by depth for class dis2
int n, m;          // node count, query count
int root, nsize;   // best centroid found so far, size of the current piece

void addedges(int u, int v)
{
	edges->v = v;
	edges->next = H[u];
	H[u] = edges++;
}

void init()
{
	edges = E;
	memset(H, 0, sizeof H);
	memset(res, 0, sizeof res);
	memset(done, 0, sizeof done);
}

void getroot(int u, int fa)
{
	mx[u] = 0, size[u] = 1;
	for(Edge *e = H[u]; e; e = e->next) if(!done[e->v] && e->v != fa) {
		int v = e->v;
		getroot(v, u);
		size[u] += size[v];
		mx[u] = max(mx[u], size[v]);
	}
	mx[u] = max(mx[u], nsize - size[u]);
	if(mx[u] < mx[root]) root = u;
}

// Fenwick point update: adds v to the count kept for depth x in `tree`.
// Depth x lives at 1-based index x + 1 (so depth 0 fits); indices run to n + 1.
void add(int x, int v, int tree[])
{
	for(int pos = x + 1; pos <= n + 1; pos += pos & -pos)
		tree[pos] += v;
}

int sum(int x, int tree[])
{
	x++;
	int ans = 0;
	for(int i = x; i > 0; i -= lowbit(i)) ans += tree[i];
	return ans;
}

// Walks outward from u, skipping the parent fa and removed (done) nodes, and
// follows only runs of values that stay monotone. Each visited node is appended
// as (depth, node) to the list chosen by its class `flag`:
//   0 -> dis  : every value on the path so far is equal,
//   1 -> dis1 : non-increasing, with at least one strict decrease,
//   2 -> dis2 : non-decreasing, with at least one strict increase.
// A class-1 or class-2 path is cut off as soon as it would change direction.
void dfs(int u, int fa, int dep, int flag)
{
	switch(flag) {
	case 0: dis.push_back(mp(dep, u)); break;
	case 1: dis1.push_back(mp(dep, u)); break;
	case 2: dis2.push_back(mp(dep, u)); break;
	}
	for(Edge *e = H[u]; e != NULL; e = e->next) {
		int v = e->v;
		if(v == fa || done[v]) continue;
		int next;
		if(flag == 0) {
			if(a[v] == a[u]) next = 0;
			else next = a[v] < a[u] ? 1 : 2;
		}
		else if(flag == 1) next = a[v] <= a[u] ? 1 : -1;
		else next = a[v] >= a[u] ? 2 : -1;
		if(next >= 0) dfs(v, u, dep + 1, next);
	}
}

void solve(int u)
{
	done[u] = true;
	
	dis.clear();
	dis1.clear();
	dis2.clear();
	dfs(u, u, 0, 0);
	for(int i = 0; i < dis.size(); i++) add(dis[i].first, 1, tree);
	for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, 1, tree1);
	for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, 1, tree2);
	
	for(int i = 0; i < q[u].size(); i++) {
		int t = 0, d = q[u][i].first, id = q[u][i].second;
		t = sum(d, tree) + sum(d, tree1) + sum(d, tree2);
		res[id] += t;
	}

	for(Edge *e = H[u]; e; e = e->next) if(!done[e->v]) {
		int v = e->v;
		dis.clear();
		dis1.clear();
		dis2.clear();
		if(a[u] == a[v]) dfs(v, v, 1, 0);
		if(a[u] > a[v]) dfs(v, v, 1, 1);
		if(a[u] < a[v]) dfs(v, v, 1, 2);
		
		for(int i = 0; i < dis.size(); i++) add(dis[i].first, -1, tree);
		for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, -1, tree1);
		for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, -1, tree2);
		
		for(int i = 0; i < dis.size(); i++) {
			int dist = dis[i].first, x = dis[i].second;
			for(int j = 0; j < q[x].size(); j++) {
				int t = 0, d = q[x][j].first, id = q[x][j].second;
				if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree1) + sum(d - dist, tree2);
				res[id] += t;
			}
		}
		
		for(int i = 0; i < dis1.size(); i++) {
			int dist = dis1[i].first, x = dis1[i].second;
			for(int j = 0; j < q[x].size(); j++) {
				int t = 0, d = q[x][j].first, id = q[x][j].second;
				if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree2);
				res[id] += t;
			}
		}

		for(int i = 0; i < dis2.size(); i++) {
			int dist = dis2[i].first, x = dis2[i].second;
			for(int j = 0; j < q[x].size(); j++) {
				int t = 0, d = q[x][j].first, id = q[x][j].second;
				if(d >= dist) t = sum(d - dist, tree) + sum(d - dist, tree1);
				res[id] += t;
			}
		}
		
		for(int i = 0; i < dis.size(); i++) add(dis[i].first, 1, tree);
		for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, 1, tree1);
		for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, 1, tree2);
	}
	
	
	dis.clear();
	dis1.clear();
	dis2.clear();
	dfs(u, u, 0, 0);
	for(int i = 0; i < dis.size(); i++) add(dis[i].first, -1, tree);
	for(int i = 0; i < dis1.size(); i++) add(dis1[i].first, -1, tree1);
	for(int i = 0; i < dis2.size(); i++) add(dis2[i].first, -1, tree2);
	
	for(Edge *e = H[u]; e; e = e->next) if(!done[e->v]) {
		int v = e->v;
		mx[0] = nsize = size[v];
		getroot(v, root = 0);
		solve(root);
	}
}

// Handles one test case.
// Input: n and m, the node values a[1..n], n - 1 undirected edges, then m
// queries "x d". Each query is attached to its node, all queries are answered
// offline by centroid decomposition, and one answer per line is printed in
// query order.
void work()
{
	scanf("%d%d", &n, &m);
	for(int *p = a + 1; p <= a + n; ++p) scanf("%d", p);

	for(int left = n - 1; left > 0; --left) {
		int from, to;
		scanf("%d%d", &from, &to);
		addedges(from, to);
		addedges(to, from);
	}

	for(int x = n; x >= 1; --x) q[x].clear();

	for(int id = 1; id <= m; ++id) {
		int node, dist;
		scanf("%d%d", &node, &dist);
		q[node].push_back(mp(dist, id));
	}

	// Locate the first centroid of the whole tree, then decompose.
	root = 0;
	nsize = n;
	mx[0] = n;
	getroot(1, 0);
	solve(root);

	for(int id = 1; id <= m; ++id) printf("%d\n", res[id]);
}

// Reads the number of test cases, then resets the global state and solves each
// case in turn.
// If the count cannot be read (empty or malformed input), the program exits
// instead of looping on an uninitialized value. A negative count is treated
// as zero.
int main()
{
	int cases;
	if(scanf("%d", &cases) != 1) return 0;
	while(cases-- > 0) {
		init();
		work();
	}

	return 0;
}


你可能感兴趣的:(树分治)