BZOJ1146 CTSC2008 网络管理network

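丧心病狂的数据结构题。

思路：先树链剖分，把询问路径拆成若干段连续的 DFS 序区间；再用线段树套一棵 BST（任意一种都可以）保存每个线段树区间内的所有值。询问路径上第 K 大值时，二分答案 mid，统计路径上不超过 mid 的值的个数来判定。注意：下面代码中的 BST 没有做平衡，数据有序时可能退化成链。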

 

Code:

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <vector>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <complex>

// Bring in only the std names this file actually uses. A blanket
// `using namespace std;` makes every unqualified use of the global `next`
// array ambiguous with std::next (C++11 and later), which fails to compile.
using std::pair;
using std::make_pair;
using std::getchar;
using std::puts;
using std::printf;

// Loop shorthands: rep counts a up from b to c inclusive, per counts down.
#define rep(a,b,c) for(int a=(b);a<=(c);a++)
#define per(a,b,c) for(int a=(b);a>=(c);a--)
// Each argument is parenthesized so expressions such as max(x&1,y) compare
// the whole argument; without it, operator precedence splits the argument.
#define max(a,b) (((a)>(b))?(a):(b))
#define min(a,b) (((a)<(b))?(a):(b))
#define pb push_back
#define mp make_pair
#define PII pair<int,int>
#define X first
#define Y second

// One vertex of the plain (unbalanced) BSTs attached to segment-tree nodes.
struct node{
	int key;  // value stored at this vertex
	int ls;   // index of left child in tr, 0 if none
	int rs;   // index of right child in tr, 0 if none
	int mt;   // how many copies of key are stored here
	int sz;   // number of stored values (with multiplicity) in this subtree
};
// Node pool. Entry 0 is never handed out and keeps sz == 0, so tr[0] acts as
// the empty child.
// NOTE(review): building stores roughly n values per segment-tree level, and
// each modification may add one node per level; confirm 2000000 covers the
// worst case of the input limits.
node tr[2000000];

#define MAXN 80010

// ---- per-vertex data ----
bool leaf[MAXN];      // set by dfs1 when the vertex's subtree size is 1
int v[MAXN];          // visited flag used by dfs1
int fa[MAXN][20];     // fa[x][j]: the 2^j-th ancestor of x (filled in main)
int h[MAXN];          // depth; the root (vertex 1) has depth 0
int a[MAXN];          // current value on each vertex
int sz[MAXN];         // subtree size
int head[MAXN];       // top vertex of the heavy chain containing x
int id[MAXN];         // id[pos]: vertex placed at DFS position pos
int wh[MAXN];         // wh[x]: DFS position of vertex x

// ---- adjacency lists ----
// g[x]/e[x]: first/last edge leaving x; next[k]/t[k]: following edge and
// target of edge k. An unqualified `next` becomes ambiguous with std::next
// when `using namespace std;` is in effect under C++11 or later.
int g[MAXN];
int e[MAXN];
int next[MAXN*2];
int t[MAXN*2];

// ---- segment tree over DFS positions ----
int lm[MAXN*4];       // left end of the node's range
int rm[MAXN*4];       // right end of the node's range
int root[MAXN*4];     // root of the node's BST in tr

// ---- scalars ----
int n;
int Q;
int l;
int r=0;              // edge counter in addedge; reused as binary-search bound in main
int T=0;              // DFS counter in dfs2; reset in main and reused to allocate tr nodes

// Number of binary digits of x (0 when x == 0). main uses logg(n) as the
// highest doubling level it fills in fa.
inline int logg(int x){
	int digits=0;
	for(;x!=0;x/=2) ++digits;
	return digits;
}
// Append a directed edge aa -> bb to the end of aa's edge list.
// g[aa] holds the first edge, e[aa] the last; next links them in order.
inline void addedge(int aa,int bb){
	int k=++r;
	t[k]=bb;
	if(g[aa]) next[e[aa]]=k;  // list already non-empty: hang after the tail
	else g[aa]=k;             // first edge of aa
	e[aa]=k;
}

// First HLD pass from x: records parent (fa[.][0]), depth (h) and subtree
// size (sz) for every unvisited vertex below x, and marks size-1 subtrees
// as leaves.
// NOTE(review): recursive; a path-shaped tree recurses about n levels deep.
inline void dfs1(int x){
	v[x]=1;
	sz[x]=1;
	for(int u=g[x];u;u=next[u]){
		int y=t[u];
		if(v[y]) continue;   // already visited: the parent side
		fa[y][0]=x;
		h[y]=h[x]+1;
		dfs1(y);
		sz[x]+=sz[y];
	}
	if(sz[x]==1) leaf[x]=true;
}

// Second HLD pass: assigns DFS positions (wh / id) so every heavy chain is
// contiguous, and records each vertex's chain top in head.
// The heavy child is the first child with the largest sz.
inline void dfs2(int x,int top){
	head[x]=top;
	id[++T]=x;
	wh[x]=T;
	if(leaf[x]) return;

	int heavy=0,best=-1;
	for(int u=g[x];u;u=next[u]){
		int y=t[u];
		if(y!=fa[x][0] && sz[y]>best){
			best=sz[y];
			heavy=y;
		}
	}

	dfs2(heavy,top);            // heavy child continues the current chain
	for(int u=g[x];u;u=next[u]){
		int y=t[u];
		if(y!=heavy && y!=fa[x][0]) dfs2(y,y);   // light child starts a chain
	}
}

// Allocate the next BST vertex from the tr pool holding one copy of x;
// returns its index.
inline int newnode(int x){
	const int p=++T;
	tr[p].key=x;
	tr[p].ls=0;
	tr[p].rs=0;
	tr[p].mt=1;
	tr[p].sz=1;
	return p;
}

// Add one copy of cur to the BST rooted at x (no rebalancing).
// Every vertex on the search path gets its sz bumped; an equal key only
// increases its multiplicity, otherwise a new leaf is attached.
inline void insert(int x,int cur){
	for(int p=x;;){
		++tr[p].sz;
		if(tr[p].key==cur){
			++tr[p].mt;
			return;
		}
		int &child=(cur>tr[p].key) ? tr[p].rs : tr[p].ls;
		if(!child){
			child=newnode(cur);
			return;
		}
		p=child;
	}
}

// Remove one copy of cur from the BST rooted at x. Vertices are never
// unlinked; only sz along the path and mt at the key are decremented.
// Precondition: cur is currently stored in the tree (callers pass the old
// value of a vertex).
inline void del(int x,int cur){
	int p=x;
	for(;;){
		--tr[p].sz;
		if(tr[p].key==cur) break;
		p=(cur<tr[p].key) ? tr[p].ls : tr[p].rs;
	}
	--tr[p].mt;
}

// Number of stored values (with multiplicity) that are <= cur in the BST
// rooted at x. At a vertex, sz minus the right subtree's sz is the left
// subtree plus the vertex's own copies.
inline int Count(int x,int cur){
	int total=0;
	int p=x;
	while(p!=0){
		if(cur<tr[p].key){
			p=tr[p].ls;
			continue;
		}
		total+=tr[p].sz-tr[tr[p].rs].sz;
		if(cur==tr[p].key) break;
		p=tr[p].rs;
	}
	return total;
}

inline void build(int cur,int ll,int rr){
	lm[cur]=ll;rm[cur]=rr;
	if	(ll==rr){
		root[cur]=newnode(a[id[ll]]);
		return	;
	}
	build(cur*2,ll,(ll+rr)/2);
	build(cur*2+1,(ll+rr)/2+1,rr);
	root[cur]=newnode(a[id[ll]]);
	rep(i,ll+1,rr)	insert(root[cur],a[id[i]]);
	return	;
}

// Replace one copy of value K with KK in every segment-tree node whose
// range contains DFS position pp, walking down from node cur.
inline void modify(int cur,int pp,int K,int KK){
	if(pp<lm[cur] || pp>rm[cur]) return;
	del(root[cur],K);
	insert(root[cur],KK);
	if(lm[cur]==rm[cur]) return;               // reached the leaf for pp
	int child=(pp<=rm[cur*2]) ? cur*2 : cur*2+1;
	modify(child,pp,K,KK);
}

// Count values <= K stored at DFS positions [ll, rr], searching from
// segment-tree node cur.
inline int query(int cur,int ll,int rr,int K){
	if(rr<lm[cur] || ll>rm[cur]) return 0;                    // disjoint
	if(ll<=lm[cur] && rm[cur]<=rr) return Count(root[cur],K); // fully covered
	int left=query(cur*2,ll,rr,K);
	return left+query(cur*2+1,ll,rr,K);
}

inline int lca(int x,int y){
	int th;
	while	(h[x]>h[y]){
		th=0;
		while	(h[fa[x][th]]>h[y])	th++;
		if	(h[fa[x][th]]<h[y])	th--;
		x=fa[x][th];
	}
	while	(h[x]<h[y]){
		th=0;
		while	(h[fa[y][th]]>h[x])	th++;
		if	(h[fa[y][th]]<h[x])	th--;
		y=fa[y][th];
	}
	while	(x!=y){
		th=0;
		while	(fa[x][th]!=fa[y][th])	th++;
		if	(th!=0)	th--;
		x=fa[x][th];y=fa[y][th];
	}
	return	x;
}

// Count values <= K on the vertical path from be up to en, inclusive.
// en must be an ancestor of be. The path is covered chain by chain: each
// heavy chain is a contiguous run of DFS positions.
inline int check(int be,int en,int K){
	int total=0;
	for(;head[be]!=head[en];be=fa[head[be]][0])
		total+=query(1,wh[head[be]],wh[be],K);
	return total+query(1,wh[en],wh[be],K);
}

// Read the next non-negative decimal integer from stdin into x. Any
// non-digit characters before it are skipped, and the single character
// after it is consumed.
// If input ends before a digit is found, x is set to 0 and the function
// returns. Previously the skip loop spun forever at EOF, and `char c` could
// not reliably hold EOF.
inline void scan(int &x){
	int c=getchar();
	while(c!=EOF && (c<'0' || c>'9')) c=getchar();
	x=0;
	if(c==EOF) return;
	for(;c>='0' && c<='9';c=getchar()) x=x*10+(c-'0');
}

int main(){
	scan(n);scan(Q);
	rep(i,1,n)	scan(a[i]);
	rep(i,1,n-1){
		int aa,bb;
		scan(aa);scan(bb);
		addedge(aa,bb);
		addedge(bb,aa);
	}
	dfs1(1);dfs2(1,1);
	T=0;
	build(1,1,n);
	fa[1][0]=1;
	rep(j,1,logg(n))	rep(i,1,n)	fa[i][j]=fa[fa[i][j-1]][j-1];
	while	(Q--){
		int k,A,b;
		scan(k);scan(A);scan(b);
		if	(!k)	modify(1,wh[A],a[A],b),a[A]=b;
		else{
			int Fa=lca(A,b);
			if	(h[A]+h[b]-h[Fa]*2+1<k){
				puts("invalid request!");
				continue;
			}
			k=(h[A]+h[b]-h[Fa]*2+1-k+1);
			l=0,r=100000000;
			while	(l<r){
				int C=check(A,Fa,(l+r)/2)+check(b,Fa,(l+r)/2);
				if (a[Fa]<=(l+r)/2)	C--;
				if	(C>=k)	r=(l+r)/2;else	l=(l+r)/2+1;
			}
			printf("%d\n",l);
		}
	}
	return	0;
}

  

你可能感兴趣的:(NetWork)