spoj COT 可持久化数据结构 (LCA模版)

查询树链第K大 。                


每个版本的线段树维护的是 从这个节点到 根的 树链的版本, 由于树链第K大,在统计比X 小的数个数时 是可以 进行加减法运算的,所以  就可以用可持久化数据结构。

维护个数时 , sum = f(a) + f(b) - f(c) -f(d)    : c 为 a,b 的最近公共祖先, d 为 c 的父亲节点。这样就是 四个版本运算。

同时:二分可以直接在树上跑,判断 左半区域的和 是否大于K,大于K 说明第K大的值 还在 左区间, 相反在右区间里查 第K -sum 大的数。

复杂度 O(nlgn) 如果直接二分区间 复杂度是O(nlgnlgn)。

倍增 LCA 算法:

const int K = 18;
int d[maxn];
int p[maxn][K];
void dfs(int rt,int f){  
    d[rt]=d[f]+1;  
    p[rt][0]=f;  
    int pos = mp[num[rt]];
    root[rt] = update(pos,1,n,1,root[f]);
    for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];
    for(int i=head[rt];i!=-1;i= edge[i].next){
    	 int son = edge[i].v;
    	 if(son==f)continue;
    	 dfs(son,rt);
    }  
}  

int lca(int a,int b){
    if(d[a]>d[b]) swap(a,b);
    if(d[a]<d[b]){
        int del = d[b]-d[a];
        for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];
    }
    if(a!=b){
        for(int i= K-1;i>=0;i--){
            if(p[a][i]!= p[b][i]){
                a = p[a][i],b = p[b][i];
            }
        }
        a= p[a][0],b = p[b][0];
    }
    return a;
}

代码:

#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <stack>
#include <cstring>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <assert.h>
#include <queue>
#define REP(i,n) for(int i=0;i<n;i++)
#define TR(i,x) for(typeof(x.begin()) i=x.begin();i!=x.end();i++)
#define ALLL(x) x.begin(),x.end()
#define SORT(x) sort(ALLL(x))
#define CLEAR(x) memset(x,0,sizeof(x))
#define FILLL(x,c) memset(x,c,sizeof(x))
using namespace std;
const double eps = 1e-9;
#define LL long long 
#define pb push_back
const int maxn  = 101000;
const int K = 18;
int n ,m ;
int num[maxn];
int d[maxn];
int p[maxn][K];
map<int,int>mp;
map<int,int>::iterator it;
int idx[maxn];
int head[maxn];
struct Edge{
    int v;
    int next;
}edge[2*maxn];
int tot;
void init(){
    memset(head,-1,sizeof(head));
    CLEAR(d);
    CLEAR(p);
    tot = 0;
}
void add(int u,int v){
    tot ++;
    edge[tot].v= v;
    edge[tot].next = head[u];
    head[u] = tot;
}


struct Node{
	Node *l,*r;
	int sum;
}nodes[maxn*40];
Node *root[maxn];
Node *null;
int C;
void inits(){
	C= 0;
	null = &nodes[C++];
	root[0] = null;
	null->l = null->r = null;
	null->sum = 0;
}
Node *update(int pos,int left ,int right,int val,Node *root){
	 Node *rt = &nodes[C++];
	 rt->l = root->l;
	 rt->r = root->r;
	 rt->sum = root->sum;
	 if(left ==right){
	 	  rt->sum +=val;
	      return rt;
	 }
	 int mid =(left +right)/2;
	 if(pos<=mid){
	 	rt ->l =update(pos,left,mid,val,root->l);
	 }else{
	 	rt ->r = update(pos,mid+1,right,val,root->r);
	 }
	 rt->sum = rt->l->sum + rt->r->sum;
	 return rt;
}
int query(int k,int left ,int right,Node *rt,Node *rt2,Node *rt3,Node *rt4){
//	cout << left << " lr "<<right<<endl;
	 if(left ==right){
	 	return left;
	 }
	 int mid = (left +right)/2;
	// cout <<rt->sum<<" "<< rt2->sum <<"   "<<rt3->sum<<" "<<rt4->sum<<endl;
	 int sum = rt->l->sum + rt2->l->sum - rt3->l->sum - rt4->l->sum;
	// cout << sum <<" sum k " << k << " "<<mid << endl;
	 if(sum>=k){
	 	  return query(k,left,mid,rt->l,rt2->l,rt3->l,rt4->l);
	 }else{
	 	return query(k-sum,mid+1,right,rt->r,rt2->r,rt3->r,rt4->r);
	 }
}
int get(int a,int b,int c,int d,int  k){
	return query(k,1,n,root[a],root[b],root[c],root[d]);
}

void dfs(int rt,int f){  
    d[rt]=d[f]+1;  
    p[rt][0]=f;  
    int pos = mp[num[rt]];
    root[rt] = update(pos,1,n,1,root[f]);
    for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];
    for(int i=head[rt];i!=-1;i= edge[i].next){
    	 int son = edge[i].v;
    	 if(son==f)continue;
    	 dfs(son,rt);
    }  
}  
int lca(int a,int b){
	if(d[a]>d[b]) swap(a,b);
	if(d[a]<d[b]){
		int del = d[b]-d[a];
		for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];
	}
	if(a!=b){
		for(int i= K-1;i>=0;i--){
			if(p[a][i]!= p[b][i]){
				a = p[a][i],b = p[b][i];
			}
		}
		a= p[a][0],b = p[b][0];
	}
	return a;
}
void solve(){
    init();
	inits();
	for(int i =1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}

	dfs(1,0);
	for(int i=1;i<=m;i++){
		int a,b,k;
		scanf("%d%d%d",&a,&b,&k);
		int t1 = lca(a,b);
		int t2 = p[t1][0];
		int ans = get(a,b,t1,t2,k);
		printf("%d\n",idx[ans]);
	}
}


int main(){
    while(~scanf("%d%d",&n,&m)){
    	mp.clear();
    	for(int i=1;i<=n;i++){
    		scanf("%d",&num[i]);
    		mp[num[i]] = 1;
    	}
    	int tot2 = 0;
        for(it = mp.begin();it!=mp.end();it++){
        	tot2 ++ ;
        	it->second = tot2;
        	//cout << tot2 << "  "<<it->first<<endl;
        	idx[tot2] = it->first; 
        }   
        
        solve();
    }
    return 0;
}


你可能感兴趣的:(spoj COT 可持久化数据结构 (LCA模版))