【树链剖分】 HDOJ 4718 The LCIS on the Tree

树链剖分 + 线段树区间合并，调试起来比较难。

#include <iostream>
#include <queue> 
#include <stack> 
#include <map> 
#include <set> 
#include <bitset> 
#include <cstdio> 
#include <algorithm> 
#include <cstring> 
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
// Array bound for node-indexed arrays (nodes are numbered 1..n).
#define maxn 100005
// Size of the edge pool.
#define maxm 200005
#define eps 1e-10
#define mod 1000000007
#define INF 0x3f3f3f3f
#define PI (acos(-1.0))
// Lowest set bit of x. Argument and body are fully parenthesized so that
// expressions such as lowbit(a + b) expand correctly.
#define lowbit(x) ((x)&(-(x)))
#define mp make_pair
// Segment-tree child shorthands. They expect locals named o, L, R and mid
// at the point of use.
#define ls (o<<1)
#define rs (o<<1 | 1)
#define lson (o<<1), L, mid
#define rson (o<<1 | 1), mid+1, R
//#pragma comment(linker, "/STACK:16777216")
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL;
using namespace std;

// a^b using plain LL arithmetic (no modulus; silently overflows for large results).
LL qpow(LL a, LL b){LL res=1,base=a;while(b){if(b%2)res=res*base;base=base*base;b/=2;}return res;}

// a^b mod `mod` by binary exponentiation, for b >= 0.
// The base is first normalized into [0, mod). Without that, base*base can overflow
// (undefined behaviour) for large a, and a negative a would give a negative result.
LL powmod(LL a, LL b)
{
	LL res = 1;
	LL base = ((a % mod) + mod) % mod;
	while (b) {
		if (b % 2) res = res * base % mod;
		base = base * base % mod;
		b /= 2;
	}
	return res;
}

// Reads one decimal int from stdin into x.
// Leading characters that are neither digits nor '-' are skipped (this covers
// spaces, '\n', '\r' and '\t'), and a leading '-' is honoured. On EOF before any
// digit, x is 0. The character is kept in an int so that EOF stays distinguishable.
// Not called in the code shown: every read there uses the C scanf with a format string.
void scanf(int &x)
{
	x = 0;
	int ch = getchar();
	while (ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
	bool negative = false;
	if (ch == '-') {
		negative = true;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	if (negative) x = -x;
}

// Greatest common divisor (Euclid, recursive).
LL gcd(LL _a, LL _b){if(!_b) return _a;else return gcd(_b, _a%_b);}
// head

// Adjacency-list node. read() adds one edge from each parent to its child.
struct Edge
{
	int v;       // endpoint the edge leads to
	Edge *next;  // next edge in the same head list, or null
};
// E: static edge pool; H[u]: head of u's edge list; edges: next free slot in E.
Edge E[maxm], *H[maxn], *edges;

// Segment tree over HLD positions. Each node o covers some range [L, R].
// Arrays whose names contain "max" describe strictly increasing runs in position
// order; arrays whose names contain "min" describe strictly decreasing runs:
//   seg* : longest run inside the range
//   *num : start position of that run
//   l*   : longest run starting at L
//   r*   : longest run ending at R
int segmax[maxn << 2], maxnum[maxn << 2];
int segmin[maxn << 2], minnum[maxn << 2];
int lmin[maxn << 2], rmin[maxn << 2];
int lmax[maxn << 2], rmax[maxn << 2];

// Heavy-light decomposition data, indexed by node id.
// size = subtree size. NOTE: with `using namespace std;` under C++17, an
// unqualified `size` is ambiguous with std::size; refer to it as ::size.
// son = heavy child (0 if none), dep = depth.
int size[maxn], son[maxn], dep[maxn];
// top = head node of the node's heavy chain, fa = parent,
// w = the node's position in the segment tree.
int top[maxn], fa[maxn], w[maxn];
// val: node values indexed by position; val1: node values indexed by node id, as read.
int val[maxn], val1[maxn];
// z: next position handed out by dfs2; n, m: node and query counts;
// ok, res, last: shared state for the query_l* / query_r* run walks.
int z, n, m, ok, res, last;

void addedges(int u, int v)
{
	edges->v = v;
	edges->next = H[u];
	H[u] = edges++;
}

void dfs1(int u)
{
	size[u] = 1, son[u] = 0;
	for(Edge *e = H[u]; e; e = e->next) {
		dep[e->v] = dep[u] + 1;
		fa[e->v] = u;
		dfs1(e->v);
		size[u] += size[e->v];
		if(size[son[u]] < size[e->v]) son[u] = e->v;
	}
}

// Second heavy-light pass: numbers nodes in DFS pre-order (heavy child first),
// so each heavy chain occupies consecutive positions. tp is the head of u's chain;
// every light child starts a new chain headed by itself.
void dfs2(int u, int tp)
{
	top[u] = tp;
	w[u] = ++z;
	if (son[u] != 0)
		dfs2(son[u], tp);
	for (Edge *e = H[u]; e != 0; e = e->next) {
		const int child = e->v;
		if (child == son[u])
			continue;
		dfs2(child, child);
	}
}

// Per-test-case reset: empties the edge pool and adjacency heads, clears heavy
// children, and restarts position numbering.
void init(void)
{
	edges = E;
	memset(H, 0, sizeof(H));
	memset(son, 0, sizeof(son));
	z = 0;
}

// Reads one tree: n, then the n node values, then the parent of each node 2..n.
void read(void)
{
	scanf("%d", &n);
	for (int i = 1; i <= n; ++i)
		scanf("%d", &val1[i]);
	for (int child = 2; child <= n; ++child) {
		int parent;
		scanf("%d", &parent);
		addedges(parent, child);
	}
}

// Recomputes node o, which covers positions [L, R], from its two children.
// Increasing runs: lmax (prefix), rmax (suffix), segmax (best) with its start in maxnum.
// Decreasing runs: the same fields with "min" instead of "max".
// A run crosses the midpoint only if val[mid] -> val[mid+1] continues it.
void pushup(int o, int L, int R)
{
	const int mid = (L + R) >> 1;
	const int lc = o << 1, rc = o << 1 | 1;
	const int lenLeft = mid - L + 1, lenRight = R - mid;
	const bool rising = val[mid] < val[mid + 1];
	const bool falling = val[mid] > val[mid + 1];

	// strictly increasing runs
	lmax[o] = (lmax[lc] == lenLeft && rising) ? lmax[lc] + lmax[rc] : lmax[lc];
	rmax[o] = (rmax[rc] == lenRight && rising) ? rmax[lc] + rmax[rc] : rmax[rc];
	segmax[o] = 0;
	if (segmax[o] < segmax[lc]) {
		segmax[o] = segmax[lc];
		maxnum[o] = maxnum[lc];
	}
	if (segmax[o] < segmax[rc]) {
		segmax[o] = segmax[rc];
		maxnum[o] = maxnum[rc];
	}
	if (rising && segmax[o] < rmax[lc] + lmax[rc]) {
		segmax[o] = rmax[lc] + lmax[rc];
		maxnum[o] = mid - rmax[lc] + 1;
	}

	// strictly decreasing runs
	lmin[o] = (lmin[lc] == lenLeft && falling) ? lmin[lc] + lmin[rc] : lmin[lc];
	rmin[o] = (rmin[rc] == lenRight && falling) ? rmin[lc] + rmin[rc] : rmin[rc];
	segmin[o] = 0;
	if (segmin[o] < segmin[lc]) {
		segmin[o] = segmin[lc];
		minnum[o] = minnum[lc];
	}
	if (segmin[o] < segmin[rc]) {
		segmin[o] = segmin[rc];
		minnum[o] = minnum[rc];
	}
	if (falling && segmin[o] < rmin[lc] + lmin[rc]) {
		segmin[o] = rmin[lc] + lmin[rc];
		minnum[o] = mid - rmin[lc] + 1;
	}
}

// Builds the run segment tree for positions [L, R] under node o, reading val[].
void build(int o, int L, int R)
{
	if (L != R) {
		const int mid = (L + R) >> 1;
		build(o << 1, L, mid);
		build(o << 1 | 1, mid + 1, R);
		pushup(o, L, R);
		return;
	}
	// A single position is a run of length 1 in both directions.
	segmax[o] = lmax[o] = rmax[o] = 1;
	segmin[o] = lmin[o] = rmin[o] = 1;
	maxnum[o] = minnum[o] = L;
}

// Length of the longest strictly increasing run (in position order) inside [ql, qr].
int query_max(int o, int L, int R, int ql, int qr)
{
	if (L >= ql && R <= qr) return segmax[o];
	const int mid = (L + R) >> 1;
	if (ql > mid) return query_max(o << 1 | 1, mid + 1, R, ql, qr);
	if (qr <= mid) return query_max(o << 1, L, mid, ql, qr);
	int best = max(query_max(o << 1, L, mid, ql, qr), query_max(o << 1 | 1, mid + 1, R, ql, qr));
	if (val[mid] < val[mid + 1]) {
		// The run crossing the midpoint, clipped to the query range.
		const int from = max(ql, mid - rmax[o << 1] + 1);
		const int to = min(qr, mid + lmax[o << 1 | 1]);
		best = max(best, to - from + 1);
	}
	return best;
}

// Length of the longest strictly decreasing run (in position order) inside [ql, qr].
int query_min(int o, int L, int R, int ql, int qr)
{
	if (L >= ql && R <= qr) return segmin[o];
	const int mid = (L + R) >> 1;
	if (ql > mid) return query_min(o << 1 | 1, mid + 1, R, ql, qr);
	if (qr <= mid) return query_min(o << 1, L, mid, ql, qr);
	int best = max(query_min(o << 1, L, mid, ql, qr), query_min(o << 1 | 1, mid + 1, R, ql, qr));
	if (val[mid] > val[mid + 1]) {
		// The run crossing the midpoint, clipped to the query range.
		const int from = max(ql, mid - rmin[o << 1] + 1);
		const int to = min(qr, mid + lmin[o << 1 | 1]);
		best = max(best, to - from + 1);
	}
	return best;
}

// Visits the fully covered nodes of [ql, qr] from left to right and adds to `res`
// the length of the strictly increasing run that starts at ql.
// `last` holds the value just before the current piece. `ok` becomes 1 once the run
// has ended. Callers reset ok = res = 0 and set last below val[ql] first.
void query_lmax(int o, int L, int R, int ql, int qr)
{
	if (ok) return;
	if (L >= ql && R <= qr) {
		if (val[L] <= last) {  // this piece does not continue the run
			ok = 1;
			return;
		}
		res += lmax[o];
		last = val[R];
		if (lmax[o] != R - L + 1) ok = 1;  // the run stops inside this piece
		return;
	}
	const int mid = (L + R) >> 1;
	if (ql <= mid) query_lmax(o << 1, L, mid, ql, qr);
	if (qr > mid) query_lmax(o << 1 | 1, mid + 1, R, ql, qr);
}

// Visits the fully covered nodes of [ql, qr] from left to right and adds to `res`
// the length of the strictly decreasing run that starts at ql.
// `last` holds the value just before the current piece. `ok` becomes 1 once the run
// has ended. Callers reset ok = res = 0 and set last above val[ql] first.
void query_lmin(int o, int L, int R, int ql, int qr)
{
	if (ok) return;
	if (L >= ql && R <= qr) {
		if (val[L] >= last) {  // this piece does not continue the run
			ok = 1;
			return;
		}
		res += lmin[o];
		last = val[R];
		if (lmin[o] != R - L + 1) ok = 1;  // the run stops inside this piece
		return;
	}
	const int mid = (L + R) >> 1;
	if (ql <= mid) query_lmin(o << 1, L, mid, ql, qr);
	if (qr > mid) query_lmin(o << 1 | 1, mid + 1, R, ql, qr);
}

// Visits the fully covered nodes of [ql, qr] from right to left and adds to `res`
// the length of the strictly increasing run (in position order) that ends at qr.
// `last` holds the value just after the current piece. `ok` becomes 1 once the run
// has ended. Callers reset ok = res = 0 and set last above val[qr] first.
void query_rmax(int o, int L, int R, int ql, int qr)
{
	if (ok) return;
	if (L >= ql && R <= qr) {
		if (val[R] >= last) {  // this piece does not continue the run
			ok = 1;
			return;
		}
		res += rmax[o];
		last = val[L];
		if (rmax[o] != R - L + 1) ok = 1;  // the run stops inside this piece
		return;
	}
	const int mid = (L + R) >> 1;
	if (qr > mid) query_rmax(o << 1 | 1, mid + 1, R, ql, qr);
	if (ql <= mid) query_rmax(o << 1, L, mid, ql, qr);
}

// Visits the fully covered nodes of [ql, qr] from right to left and adds to `res`
// the length of the strictly decreasing run (in position order) that ends at qr.
// `last` holds the value just after the current piece. `ok` becomes 1 once the run
// has ended. Callers reset ok = res = 0 and set last below val[qr] first.
void query_rmin(int o, int L, int R, int ql, int qr)
{
	if (ok) return;
	if (L >= ql && R <= qr) {
		if (val[R] <= last) {  // this piece does not continue the run
			ok = 1;
			return;
		}
		res += rmin[o];
		last = val[L];
		if (rmin[o] != R - L + 1) ok = 1;  // the run stops inside this piece
		return;
	}
	const int mid = (L + R) >> 1;
	if (qr > mid) query_rmin(o << 1 | 1, mid + 1, R, ql, qr);
	if (ql <= mid) query_rmin(o << 1, L, mid, ql, qr);
}

int solve(int a, int b)
{
	int f1 = top[a], f2 = top[b];
	int ans = 0, lasta = a, lastb = b;
	int prea = 0, preb = 0;
	int lminv, rminv, lmaxv, rmaxv;
	while(f1 != f2) {
		if(dep[f1] < dep[f2]) {
			ans = max(ans, query_max(1, 1, n, w[f2], w[b]));
			ok = res = 0, last = val[w[b]] + 1;
			query_rmax(1, 1, n, w[f2], w[b]);
			rmaxv = res;
			ok = res = 0, last = val[w[f2]] - 1;
			query_lmax(1, 1, n, w[f2], w[b]);
			lmaxv = res;
			if(val[w[lastb]] > val[w[b]]) ans = max(ans, preb + rmaxv);
			if(rmaxv == w[b] - w[f2] + 1 && val[w[b]] < val[w[lastb]]) preb = rmaxv + preb;
			else preb = lmaxv;
			lastb = f2;
			b = fa[f2], f2 = top[b];	
		}
		else {
			ans = max(ans, query_min(1, 1, n, w[f1], w[a]));
			ok = res = 0, last = val[w[a]] - 1;
			query_rmin(1, 1, n, w[f1], w[a]);
			rminv = res;
			ok = res = 0, last = val[w[f1]] + 1;
			query_lmin(1, 1, n, w[f1], w[a]);
			lminv = res;
			if(val[w[lasta]] < val[w[a]]) ans = max(ans, prea + rminv);
			if(rminv == w[a] - w[f1] + 1 && val[w[a]] > val[w[lasta]]) prea = rminv + prea;
			else prea = lminv;
			lasta = f1;
			a = fa[f1], f1 = top[a];
		}
	}
	if(a == b) {
		if(lasta == a) {
			if(prea == 0) prea = 1;
			if(val[w[a]] < val[w[lastb]]) ans = max(ans, prea + preb);
		}
		else if(lastb == b) {
			if(preb == 0) preb = 1;
			if(val[w[b]] > val[w[lasta]]) ans = max(ans, prea + preb);
		}
		else {
			if(val[w[a]] > val[w[lasta]]) ans = max(ans, prea+1);
			if(val[w[a]] < val[w[lastb]]) ans = max(ans, preb+1);
			if(val[w[a]] > val[w[lasta]] && val[w[a]] < val[w[lastb]]) ans = max(ans, prea + preb + 1);
		}
		return ans;
	}
	if(a != b) {
		if(dep[a] < dep[b]) {
			ans = max(ans, query_max(1, 1, n, w[a], w[b]));
			ok = res = 0, last = val[w[a]] - 1;
			query_lmax(1, 1, n, w[a], w[b]);
			lmaxv = res;
			ok = res = 0, last = val[w[b]] + 1;
			query_rmax(1, 1, n, w[a], w[b]);
			rmaxv = res;
			if(val[w[lastb]] > val[w[b]]) ans = max(ans, preb + rmaxv);
			if(rmaxv == w[b] - w[a] + 1 && val[w[b]] < val[w[lastb]]) preb = rmaxv + preb;
			else preb = lmaxv;
			lastb = a;
			b = a;
		}
		else {
			ans = max(ans, query_min(1, 1, n, w[b], w[a]));
			ok = res = 0, last = val[w[b]] + 1;
			query_lmin(1, 1, n, w[b], w[a]);		
			lminv = res;
			ok = res = 0, last = val[w[a]] - 1;
			query_rmin(1, 1, n, w[b], w[a]);
			rminv = res;
			if(val[w[lasta]] < val[w[a]]) ans = max(ans, prea + rminv);
			if(rminv == w[a] - w[b] + 1 && val[w[a]] > val[w[lasta]]) prea = rminv + prea;
			else prea = lminv;
			lasta = b;
			a = b;
		}
	}
	if(a == lasta) {
		if(preb == 0) preb = 1;
		if(val[w[lastb]] > val[w[b]]) ans = max(ans, prea + preb);
	}
	else if(b == lastb) {
		if(prea == 0) prea = 1;
		if(val[w[lasta]] < val[w[a]]) ans = max(ans, prea + preb);
	}
	else {

	}
	
	return ans;
}
// Handles one test case after read():
// builds the heavy-light decomposition rooted at node 1, loads node values into
// position order, builds the segment tree, then answers m path queries.
void work(void)
{
	fa[1] = 1;
	dfs1(1);
	dfs2(1, 1);
	for (int node = 1; node <= n; ++node)
		val[w[node]] = val1[node];
	build(1, 1, n);

	scanf("%d", &m);
	int a, b;
	while (m--) {
		scanf("%d%d", &a, &b);
		// A single-node path always has answer 1; solve() assumes a != b.
		printf("%d\n", a == b ? 1 : solve(a, b));
	}
}

// Reads batches, each headed by a test-case count, until input ends.
// Each case gets a "Case #k:" header, and cases within a batch are separated by a blank line.
// The loop continues only while scanf actually parses the count (== 1). With the
// original `!= EOF`, malformed input makes scanf return 0, leaves the counter at -1,
// and `while (cases--)` then runs essentially forever.
// The reserved identifier `__` was renamed.
int main(void)
{
	int cases, caseNo;
	while (scanf("%d", &cases) == 1) {
		caseNo = 0;
		while (cases--) {
			init();
			read();
			printf("Case #%d:\n", ++caseNo);
			work();
			if (cases) printf("\n");
		}
	}

	return 0;
}


你可能感兴趣的:(HDU)