zoj 3649 lca 倍增 DP

下课后直接奔实验室,月赛都快结束了,我擦,随便挑了一道题开始搞起来,是我喜欢的图论,心里暗暗欣喜,可是还是有点绕人,比赛结束刚好敲完,吃晚饭后,调了调,AC,呵呵

比赛的时候A的貌似很少,其实也不太难

链接:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3649

下面是我做过的倍增法求LCA以及倍增的DP的两个练习题

http://blog.csdn.net/haha593572013/article/details/7796497

http://blog.csdn.net/haha593572013/article/details/7855282

题意:最大生成树的那个就不说了,两个算法的叠加,没啥意思,抽象一下就是给你一棵树,每个点都有点权,然后有很多询问

每个询问是两个数x   y,然后要你求出最大差值,最大差值是这样的:在这棵树上从x走到y会得到一个点权的序列


 c1, c2, c3, ... ,ci

 find the maximum ck-cj (ck >= cj, j <= k). 

关键的难点在于一定要用后面的数减去前面的数

解法:

开始的时候一直在纠结最大减最小,还要最大的在最小的后面,如果没有后面这个限制,那就直接树链剖分或者倍增都可以解决,后面的限制是本题的亮点,仔细想想,应该会需要用到求lca,求某个点到lca点权的最大值或者最小值,然后再仔细一想,想在log(n)的时间内求出u到lca的“最大差值”,除了数据结构还能有什么(数据结构显然有点无力),可能是我太弱,实在想不出用什么数据结构可以搞定,然后我就转向倍增DP的思路,还是和上面第二个链接的DP类似,只不过这道题需要一系列的DP数组

p[u][i]表示u的2的i次个祖先

mx[u][i]表示u到u的2的i次个祖先之间的最大点权值

mi[u][i]表示u到u的2的i次个祖先之间的最小点权值

dp[u][i]表示u到u的2的i次个祖先之间的最大差值

dp2[u][i]表示u的2的i次个祖先到u之间的最大差值

之所以需要dp2是因为从x到lca再从lca到y的路线是两个相反的方向

求mx mi数组很简单,和求p数组一模一样

dp数组的求法也不难

路径利用二进制被分成了一段一段,除了每一段的信息可以更新当前状态,段与段之间也需要考虑

然后在求答案的时候也是类似的方法

具体见代码吧,应该能看懂的,都很简单- -

#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int inf = ~0u>>2;
const int maxn = 30010;
const int POW = 16;
int mi[maxn][POW],mx[maxn][POW],p[maxn][POW];
int f[maxn];
int find(int x) {return x==f[x] ? x : f[x]=find(f[x]);}
struct EDGE{
	int s,t,w;
}e[50010];
int cmp(EDGE x,EDGE y){
	return x.w>y.w;
}
vector<int> edge[maxn];
int n,m;
int val[maxn];
bool vis[maxn];
int d[maxn];

int dp[maxn][POW],dp2[maxn][POW];
inline int max(int a,int b) {
	return  a>b?a:b;
}
inline int min(int a,int b) {
	return a<b?a:b;
}
void dfs(int u,int f){  
	d[u]=d[f]+1;
	vis[u]=true;
    int sz=edge[u].size(),j;  
    for(int i=0;i<sz;i++){  
        int v=edge[u][i];  
		if(vis[v])  continue;
        p[v][0]=u; 
		mi[v][0]=min(val[v],val[u]);
        mx[v][0]=max(val[v],val[u]);  
		dp[v][0]=val[u]-val[v];
		dp2[v][0]=val[v]-val[u];
        for(j=1;j<POW;j++) {
			p[v][j]=p[p[v][j-1]][j-1]; 
			mx[v][j]=max(mx[v][j-1],mx[p[v][j-1]][j-1]);  
			mi[v][j]=min(mi[v][j-1],mi[p[v][j-1]][j-1]);

			dp[v][j]=max(dp[v][j-1],dp[p[v][j-1]][j-1]);
			dp[v][j]=max(dp[v][j]  ,mx[p[v][j-1]][j-1]-mi[v][j-1]) ;

			dp2[v][j]=max(dp2[v][j-1],dp2[p[v][j-1]][j-1]);
			dp2[v][j]=max(dp2[v][j]  ,mx[v][j-1] - mi[p[v][j-1]][j-1]) ;
		}
        dfs(v,u);  
    }  
}  
int LCA( int a, int b ){  
	int i;
    if( d[a] > d[b] ) a ^= b, b ^= a, a ^= b;  
    if( d[a] < d[b] ){  
        int del = d[b] - d[a];  
        for(i = 0; i < POW; i++ ) if(del&(1<<i)) b=p[b][i];  
    }  
    if( a != b ){  
        for(i = POW-1; i >= 0; i-- )   
            if( p[a][i] != p[b][i] )   
                 a = p[a][i] , b = p[b][i];  
        a = p[a][0], b = p[b][0];  
    }  
    return a;  
}  
void init(int n)  {
	 memset(vis,false,sizeof(vis));
	 fill(p[0],p[n+1],0);
	 fill(mx[0],mx[n+1],-inf);
	 fill(mi[0],mi[n+1],inf);
	 fill(dp[0],dp[n+1],-inf);
	 fill(dp2[0],dp2[n+1],-inf);
	 d[0]=0;
     dfs(1,0);
}
int getmin(int u,int lca,int dp[][POW]) {
	int ans=inf;
    int del=d[u] - d[lca];
    for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=min(ans,dp[u][i]);
		u=p[u][i];
	}
    return ans;
}
int getmax(int x,int lca,int dp[][POW]){
    int ans=0;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)){
		ans=max(ans,dp[x][i]);
		x=p[x][i];
	}
	return ans;
}
int gao1(int x,int lca,int dp[][POW]) {
	int ans=0,tmp=0;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=max(ans,dp[x][i]);
		ans=max(ans,tmp-mi[x][i]);
		tmp=max(tmp,mx[x][i]);
		x=p[x][i];
	}
	return ans;
}
int gao2(int x,int lca,int dp[][POW]) {
		int ans=0,tmp=inf;
	int del=d[x]-d[lca];
	for(int i=POW-1;i>=0;i--) if(del & (1<<i)) {
		ans=max(ans,dp[x][i]);
		ans=max(ans,-(tmp-mx[x][i]));
		tmp=min(tmp,mi[x][i]);
		x=p[x][i];
	}
	return ans;
}
void solve(int x,int y) {
     int lca=LCA(x,y);
	 int a,b,c,d;
	 a=gao2(x,lca,dp);
	 b=gao1(y,lca,dp2);
	 c=getmax(y,lca,mx);
	 d=getmin(x,lca,mi);
	 int ans=max(max(a,b),c-d);
	 printf("%d\n",ans);
}
int main() {
	int i,j,k,x,y,q;
	while(scanf("%d",&n)!=EOF){
		for(i=1;i<=n;i++) scanf("%d",&val[i]),f[i]=i,edge[i].clear();
		scanf("%d",&m);
		for(i=0;i<m;i++)
			scanf("%d%d%d",&e[i].s,&e[i].t,&e[i].w);
		sort(e,e+m,cmp);
		int sum=0;
		for(i=0;i<m;i++)	{
             x=find(e[i].s);
			 y=find(e[i].t);
             if(x!=y) {
				 edge[e[i].s].push_back(e[i].t);
				 edge[e[i].t].push_back(e[i].s);
				 f[x]=y;
				 sum+=e[i].w;
			 }
		}
		printf("%d\n",sum);
		init(n);
		scanf("%d",&q);
		while(q--)	{
			scanf("%d%d",&x,&y);
            solve(x,y);
		}
	}
	return 0;
}


你可能感兴趣的:(数据结构,c,算法,struct)