// CF #326 (Div. 2) E Duff in the Army

//	CF #326 (Div. 2) E  Duff in the Army
//
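//	Problem summary:
//		A tree has n nodes and p people live on its nodes (person i has ID i).
//	For each of q queries (u, v, k), output the k smallest IDs of people living
//	on the path from u to v (all of them if fewer than k live on that path).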
//
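//	Approach:
//		Persistent (functional) segment tree. First build an empty tree over
//	person IDs 1..P; then, walking the tree in DFS order, derive for every node
//	a version that adds that node's people to its parent's version, so each
//	version counts the people on the root-to-node path. The answer for (u, v)
//	then depends only on the four versions rt[u], rt[v], rt[LCA(u,v)] and
//	rt[father[LCA(u,v)]]. It is worth solving the simpler SPOJ 10628 first.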
//
//	感悟:
//
//		一开始就是为了做这题才去学习函数式线段树,感觉这类计数
//	问题都可以用函数式线段树解决.十分巧妙.感谢大神们的博客.小子
//	受教啦~~~~继续加油吧~~~FIGHTING!!!


#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <vector>
#define For(x,a,b,c) for (int x = a; x <= b; x += c)
#define Ffor(x,a,b,c) for (int x = a; x >= b; x -= c)
#define cls(x,a) memset(x,a,sizeof(x))
using namespace std;
typedef long long ll;

const double PI = acos(-1.0);
const double eps = 1e-9;
const int MAX_N = 1e5 + 10000;

const int INF = 1e9 + 7;
const ll MOD = 1e9 + 7;

// N: tree nodes, P: people, Q: queries of the current test case.
int N,P,Q;
int a[MAX_N];            // only read by the debug helper print()
vector<int> g[MAX_N];    // g[u]: neighbours of node u (nodes are 1-based)

vector<int> c[MAX_N];    // c[x]: IDs of the people living at node x

// Heavy-light decomposition tables, filled by dfs(u,fa,d) and dfs_2():
int top[MAX_N];      // head of the heavy chain containing the node
int idx[MAX_N];      // not referenced elsewhere in this file
int siz[MAX_N];      // subtree size
int son[MAX_N];      // heavy child (0 = none)
int dep[MAX_N];      // depth, the root (node 1) has depth 1
int father[MAX_N];   // parent (0 for the root)

/// One node of the persistent segment tree kept in IntervalTree's pool.
/// ls / rs are pool slots of the two children; cnt is how many person IDs
/// fall inside this node's range.
struct node{
	int ls, rs;   // child slots in the node pool
	int cnt;      // number of person IDs counted in this range
};
/// Persistent ("functional") segment tree over positions 1..P.
/// An update copies only the nodes on one root-to-leaf path, so every older
/// version stays intact and is addressed by its root slot (kept in rt[]).
struct IntervalTree{
	node p[MAX_N * 20];   // node pool; slot 0 is never handed out by init/build/update
	int rt[MAX_N];        // rt[x]: root slot of the version stored for index x
	int Siz;              // next unused slot in p

	/// Reset the allocator so the next node is placed in slot 1.
	void init(){
		Siz = 1;
	}

	/// Allocate a zero-count subtree covering [L, R]; returns its root slot.
	int build(int L,int R){
		const int cur = Siz++;
		p[cur].cnt = 0;
		if (L != R){
			const int mid = (L + R) >> 1;
			p[cur].ls = build(L,mid);
			p[cur].rs = build(mid + 1,R);
		}
		return cur;
	}

	/// Return the root of a new version equal to version `rt` except that the
	/// count at position q is increased by v. Version `rt` is left unchanged.
	int update(int rt,int L,int R,int q,int v){
		const int cur = Siz++;
		p[cur] = p[rt];
		p[cur].cnt += v;
		if (L != R){
			const int mid = (L + R) >> 1;
			if (q > mid)
				p[cur].rs = update(p[rt].rs,mid + 1,R,q,v);
			else
				p[cur].ls = update(p[rt].ls,L,mid,q,v);
		}
		return cur;
	}

	/// Return the k-th smallest position of the combined counts
	/// version(rtl) + version(rtr) - version(rlca) - version(rf_lca)
	/// within [L, R]. Assumes the combined count over [L, R] is at least k.
	int query(int rtl,int rtr,int rlca,int rf_lca,int L,int R,int k){
		while (L != R){
			const int mid = (L + R) >> 1;
			const int inLeft = p[p[rtl].ls].cnt + p[p[rtr].ls].cnt
			                 - p[p[rlca].ls].cnt - p[p[rf_lca].ls].cnt;
			if (inLeft >= k){
				rtl = p[rtl].ls;
				rtr = p[rtr].ls;
				rlca = p[rlca].ls;
				rf_lca = p[rf_lca].ls;
				R = mid;
			} else {
				// Skip everything in the left half and continue on the right.
				k -= inLeft;
				rtl = p[rtl].rs;
				rtr = p[rtr].rs;
				rlca = p[rlca].rs;
				rf_lca = p[rf_lca].rs;
				L = mid + 1;
			}
		}
		return L;
	}

}it;

/// First heavy-light pass over the subtree of u (parent fa, depth d):
/// records depth and parent, accumulates subtree sizes and picks the heavy
/// child (the child with the largest subtree; 0 when u has no children).
void dfs(int u,int fa,int d){
	father[u] = fa;
	dep[u] = d;
	siz[u] = 1;
	son[u] = 0;
	const int deg = g[u].size();
	for (int i = 0; i < deg; ++i){
		const int to = g[u][i];
		if (to != fa){
			dfs(to,u,d + 1);
			siz[u] += siz[to];
			// siz[0] is never written, so any child beats "no heavy child yet".
			if (siz[to] > siz[son[u]])
				son[u] = to;
		}
	}
}

/// Second heavy-light pass: tp is the head of the heavy chain u belongs to.
/// The heavy child continues u's chain; every light child starts a new one.
void dfs_2(int u,int tp){
	top[u] = tp;
	if (son[u] != 0)
		dfs_2(son[u],tp);
	const int deg = g[u].size();
	for (int i = 0; i < deg; ++i){
		const int to = g[u][i];
		if (to != father[u] && to != son[u])
			dfs_2(to,to);
	}
}

/// Build the heavy-light tables used by LCA(), rooting the tree at node 1
/// (whose parent is recorded as 0 and whose depth is 1).
/// Always returns 0; the value carries no information. The original version
/// fell off the end of this int function, which is undefined behavior.
int LCA_init(){
	dfs(1,0,1);   // depth, parent, subtree size, heavy child
	dfs_2(1,1);   // heavy-chain heads
	return 0;
}

/// Lowest common ancestor of u and v using the heavy-light tables.
/// While u and v sit on different chains, jump the endpoint whose chain head
/// is deeper to the parent of that head; once both share a chain, the
/// shallower endpoint is the answer (u on ties).
int LCA(int u,int v){
	while (top[u] != top[v]){
		if (dep[top[u]] < dep[top[v]])
			swap(u,v);
		u = father[top[u]];
	}
	return dep[u] > dep[v] ? v : u;
}

/// Derive the persistent version of every node below u from its parent's
/// version: rt[child] = rt[parent] plus one count for each person at child.
/// it.rt[u] must already be set. When started at the root (as input() does),
/// rt[x] ends up counting the people on the root-to-x path.
void dfs(int u,int fa){
	const int deg = g[u].size();
	for (int i = 0; i < deg; ++i){
		const int child = g[u][i];
		if (child == fa)
			continue;
		int ver = it.rt[u];
		const int people = c[child].size();
		for (int j = 0; j < people; ++j)
			ver = it.update(ver,1,P,c[child][j],1);
		it.rt[child] = ver;
		dfs(child,u);
	}
}
void print(){
	For(i,1,N,1)
		printf("%d ",a[i]);
	cout << endl;
}
/// Read one test case body: N-1 tree edges, then the home node of each
/// person 1..P. Then build the persistent segment-tree versions (one per
/// node) and the heavy-light tables used by LCA().
void input(){
	For(i,1,N-1,1){
		int u,v;
		scanf("%d%d",&u,&v);
		g[u].push_back(v);
		g[v].push_back(u);
	}
	// Person i lives at node x; c[x] collects the IDs of people living at x.
	For(i,1,P,1){
		int x;
		scanf("%d",&x);
		c[x].push_back(i);
	}
	it.init();
	// rt[0]: all-zero version. It doubles as the version of the root's parent (0).
	it.rt[0] = it.build(1,P);
	// The root's version is the empty version plus the people living at node 1.
	// dfs() then derives every other node's version from its parent's.
	it.rt[1] = it.rt[0];
	for (size_t j = 0; j < c[1].size(); j++)
		it.rt[1] = it.update(it.rt[1],1,P,c[1][j],1);
	dfs(1,-1);
	LCA_init();
}


void solve(){



	For(i,1,Q,1){
		int u,v,k;
		scanf("%d%d%d",&u,&v,&k);
		int t = LCA(u,v);
		vector ans;

		//cout << i << endl;
		int x = it.p[it.rt[u]].cnt + it.p[it.rt[v]].cnt - it.p[it.rt[t]].cnt - it.p[it.rt[father[t]]].cnt; 

		for (int j = 1;j <= k;j ++){
			
			if (x < j)
				break;

			int tmp = it.query(it.rt[u],it.rt[v],it.rt[t],it.rt[father[t]],1,P,j);
			ans.push_back(tmp);
		}
		printf("%d",ans.size());
		if (ans.size()){
			For(i,0,ans.size()-1,1)
				printf(" %d",ans[i]);
		}
		puts("");
	}
}


void init(){
	For(i,1,N,1){
		g[i].clear();
		c[i].clear();
	}
	cls(dep,0);
}

/// Process test cases until a complete "N P Q" header can no longer be read.
/// Checking for exactly 3 conversions (not just != EOF) stops cleanly on a
/// malformed header instead of looping forever on stale N/P/Q values.
int main(){
	//freopen("1.in","r",stdin);
	//freopen("1.out","w",stdout);
	while(scanf("%d%d%d",&N,&P,&Q) == 3){
		init();
		input();
		solve();
	}
	return 0;
}

// Related topics: Codeforces, segment tree, data structure