// 树链剖分 (heavy-light decomposition): JAG Summer 2012 Day 4, Problem D "Do use segment tree"

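// 树链剖分裸题 — a textbook heavy-light decomposition problem: assign a value to every vertex on a path, and query the maximum sum of a non-empty run of consecutive vertices on a path.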

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;

// Segment-tree child helpers. They expand in terms of the caller's locals
// `o` (node index), `L`, `R` and `mid`; every expansion is parenthesized so
// the macros stay correct inside larger expressions (e.g. `ls + 1`).
#define lson (o << 1), L, mid
#define rson (o << 1 | 1), mid + 1, R
#define ls (o << 1)
#define rs (o << 1 | 1)
// Sentinel meaning "no pending assignment" in node::lazy.
// NOTE(review): assumes no real assigned value ever equals this — confirm against input bounds.
const int INF = 0x3f3f3f3f * 2;
const int maxn = 500005;
// Each undirected tree edge is stored in both directions, so a tree with up
// to maxn vertices needs 2 * (maxn - 1) adjacency slots.
const int maxm = maxn << 1;

// Per-vertex heavy-light decomposition data (vertices are 1-indexed):
//   size: subtree size, son: heavy child (0 = none), dep: depth,
//   top: head of the heavy chain containing the vertex, fa: parent.
// NOTE: with `using namespace std;` under C++17 an unqualified use of `size`
// is ambiguous with std::size; write it as `::size` at use sites.
int size[maxn], son[maxn], dep[maxn], top[maxn], fa[maxn];
// mpp: segment-tree position -> vertex; w: vertex -> segment-tree position.
int mpp[maxn], w[maxn];
// a: vertex weights as read from input.
int a[maxn];
// n: vertex count, m: query count, z: last segment-tree position handed out.
int n, m, z;

// Singly linked adjacency-list entry allocated from a static pool.
struct Edge
{
	int v;       // target vertex
	Edge *next;  // next edge leaving the same source vertex
};
Edge E[maxm];    // edge pool
Edge *H[maxn];   // head of each vertex's adjacency list (null = empty)
Edge *edges;     // next free slot in E

// Segment-tree node summarizing a range of positions.
struct node
{
	int tmax;  // best sum of a non-empty contiguous run in the range
	int lmax;  // best non-empty prefix sum
	int rmax;  // best non-empty suffix sum
	int sum;   // sum of the whole range
	int lazy;  // pending "assign every element" value, INF when none
	node() {}
	node(int tmax_, int lmax_, int rmax_, int sum_, int lazy_)
		: tmax(tmax_), lmax(lmax_), rmax(rmax_), sum(sum_), lazy(lazy_) {}
};
node tree[maxn << 2];

// Reset per-test-case state: empty adjacency lists, rewind the edge pool,
// and restart segment-tree position numbering.
void init()
{
	memset(H, 0, sizeof H);
	edges = E;
	z = 0;
}

void addedges(int u, int v)
{
	edges->v = v;
	edges->next = H[u];
	H[u] = edges++;
}

void pushup(int o)
{
	tree[o].sum = tree[ls].sum + tree[rs].sum;
	tree[o].tmax = max(tree[ls].tmax, tree[rs].tmax);
	tree[o].tmax = max(tree[o].tmax, tree[ls].rmax + tree[rs].lmax);
	
	tree[o].lmax = max(tree[ls].lmax, tree[ls].sum + tree[rs].lmax);
	tree[o].rmax = max(tree[rs].rmax, tree[rs].sum + tree[ls].rmax);
}

// If node o (covering [L, R]) carries a pending assignment, apply it to both
// children (each becomes a uniform range of that value) and clear it on o.
void pushdown(int o, int L, int R)
{
	const int tag = tree[o].lazy;
	if (tag == INF) return;
	const int mid = (L + R) >> 1;
	const int sumL = tag * (mid - L + 1);
	const int sumR = tag * (R - mid);
	// All elements equal: the best non-empty run is a single element when the
	// value is negative, otherwise the whole range.
	const int bestL = max(tag, sumL);
	const int bestR = max(tag, sumR);
	tree[ls] = node(bestL, bestL, bestL, sumL, tag);
	tree[rs] = node(bestR, bestR, bestR, sumR, tag);
	tree[o].lazy = INF;
}

// Combine the summaries of two adjacent ranges, lhs immediately before rhs.
// The result's lazy field is left unset; callers only read the aggregates.
node merge(node lhs, node rhs)
{
	node out;
	out.sum = lhs.sum + rhs.sum;
	out.lmax = max(lhs.lmax, lhs.sum + rhs.lmax);
	out.rmax = max(rhs.rmax, rhs.sum + lhs.rmax);
	out.tmax = max(max(lhs.tmax, rhs.tmax), lhs.rmax + rhs.lmax);
	return out;
}

// Build the segment tree over positions [L, R]; position p holds the weight
// of vertex mpp[p]. No node starts with a pending assignment.
void build(int o, int L, int R)
{
	if (L == R) {
		const int val = a[mpp[L]];
		tree[o] = node(val, val, val, val, INF);
		return;
	}
	tree[o] = node(0, 0, 0, 0, INF);
	int mid = (L + R) >> 1;
	build(lson);
	build(rson);
	pushup(o);
}

// Assign val to every position in [ql, qr] within node o's range [L, R].
void update(int o, int L, int R, int ql, int qr, int val)
{
	if (L >= ql && R <= qr) {
		// Fully covered: the node becomes a uniform range tagged with val.
		const int total = val * (R - L + 1);
		const int best = max(val, total);
		tree[o] = node(best, best, best, total, val);
		return;
	}
	pushdown(o, L, R);
	const int mid = (L + R) >> 1;
	if (ql <= mid) update(lson, ql, qr, val);
	if (mid < qr) update(rson, ql, qr, val);
	pushup(o);
}

// Return the summary of positions [ql, qr] (assumed ql <= qr and inside
// [L, R]) from node o. Pending assignments on the visited path are pushed
// down, but no aggregates are rewritten. The original re-ran pushup(o) here,
// which was redundant: pushdown never changes o's own aggregates, and the
// children it writes are consistent with them.
node query(int o, int L, int R, int ql, int qr)
{
	if (ql <= L && qr >= R) return tree[o];
	pushdown(o, L, R);
	int mid = (L + R) >> 1;
	if (qr <= mid) return query(lson, ql, qr);
	if (ql > mid) return query(rson, ql, qr);
	// The range straddles the midpoint: left part precedes right part.
	return merge(query(lson, ql, qr), query(rson, ql, qr));
}

void dfs1(int u)
{
	size[u] = 1, son[u] = 0;
	for(Edge *e = H[u]; e; e = e->next) {
		if(e->v != fa[u]) {
			dep[e->v] = dep[u] + 1;
			fa[e->v] = u;
			dfs1(e->v);
			size[u] += size[e->v];
			if(size[son[u]] < size[e->v]) son[u] = e->v;
		}
	}
}

void dfs2(int u, int tp)
{
	w[u] = ++z, top[u] = tp;
	if(son[u]) dfs2(son[u], tp);
	for(Edge *e = H[u]; e; e = e->next) {
		if(e->v != fa[u] && e->v != son[u]) dfs2(e->v, e->v);
	}
}

void solve1(int a, int b, int c)
{
	int f1 = top[a], f2 = top[b];
	while(f1 != f2) {
		if(dep[f1] < dep[f2]) {
			swap(a, b);
			swap(f1, f2);
		}
		update(1, 1, n, w[f1], w[a], c);
		a = fa[f1], f1 = top[a];
	}
	if(dep[a] > dep[b]) swap(a, b);
	update(1, 1, n, w[a], w[b], c);
}

void solve2(int a, int b)
{
	int f1 = top[a], f2 = top[b];
	node ans1, ans2;
	int flag1 = 0, flag2 = 0;
	while(f1 != f2) {
		if(dep[f1] > dep[f2]) {
			if(flag1 == 0) flag1 = 1, ans1 = query(1, 1, n, w[f1], w[a]);
			else ans1 = merge(query(1, 1, n, w[f1], w[a]), ans1);
			a = fa[f1], f1 = top[a];
		}
		else {
			if(flag2 == 0) flag2 = 1, ans2 = query(1, 1, n, w[f2], w[b]);
			else ans2 = merge(query(1, 1, n, w[f2], w[b]), ans2);
			b = fa[f2], f2 = top[b];
		}
	}
	if(dep[a] > dep[b]) {
		if(flag1 == 0) flag1 = 1, ans1 = query(1, 1, n, w[b], w[a]);
		else ans1 = merge(query(1, 1, n, w[b], w[a]), ans1);
	}
	else {
		if(flag2 == 0) flag2 = 1, ans2 = query(1, 1, n, w[a], w[b]);
		else ans2 = merge(query(1, 1, n, w[a], w[b]), ans2);
	}
	
	int res = 0;
	if(flag1 == 0 && flag2 == 1) {
		res = ans2.tmax;
	}
	else if(flag1 == 1 && flag2 == 0) {
		res = ans1.tmax;
	}
	else {
		res = max(ans1.tmax, ans2.tmax);
		res = max(res, ans1.lmax + ans2.lmax);
	}
	printf("%d\n", res);
}


// Handle one test case: read n weights and n - 1 edges, build the
// decomposition and segment tree, then process m queries in order.
void work()
{
	for (int v = 1; v <= n; ++v)
		scanf("%d", &a[v]);
	for (int k = 1; k < n; ++k) {
		int x, y;
		scanf("%d%d", &x, &y);
		// Undirected edge: store both directions.
		addedges(x, y);
		addedges(y, x);
	}
	dfs1(1);
	dfs2(1, 1);
	for (int v = 1; v <= n; ++v)
		mpp[w[v]] = v;  // position -> vertex, used by build()
	build(1, 1, n);
	while (m--) {
		int op, x, y, c;
		// Every query line carries four integers; c is unused by op 2.
		scanf("%d%d%d%d", &op, &x, &y, &c);
		switch (op) {
		case 1:
			solve1(x, y, c);
			break;
		case 2:
			solve2(x, y);
			break;
		default:
			break;
		}
	}
}

int main()
{
//freopen("data", "r", stdin);
	// Process test cases while a full "n m" header can be read. Comparing with
	// 2 (rather than EOF) also stops on malformed or truncated input. On such
	// input scanf returns 0 or 1 without consuming the bad token, so the old
	// loop never terminated.
	while (scanf("%d%d", &n, &m) == 2) {
		init();
		work();
	}

	return 0;
}


// 你可能感兴趣的 (you may also be interested in): 树链剖分 (heavy-light decomposition)