CDOJ 92 – Journey 【LCA】

【题意】给出一棵树,有n个点(2≤N≤105),每条边有权值,现在打算新修一条路径,给出新路径u的起点v,终点和权值,下面给出Q(1≤Q≤105)个询问(a,b)问如果都按照最短路径走,从a到b节省了多少距离。

 

咱不妨把新修路径的一个端点u设为根结点,然后建树。

这样新路径另一端v一定连着它的子树的一个点。

 

从祖先u到孩子v再加上(u,v)构成一个环,这个是树中唯一的一个环,

如果新加的路径比不加新路径时从u到v的距离短:

对于询问a到b的路程改变量,如果a是v的孩子,而b是u的祖先或者是在另一棵子树中(与b不在同一棵子树中),显然经过新路径更划算。

那么如果a是v的孩子,而b在未加新路径时(u,v)路径的某一个点上,那么需要判断,先从v沿新路到u,再从u到b更近,还是从v直接到b更近。

 

我们发现,如果b是靠近u的一些点,采用前者更近,如果b是靠近v的一些点,采用后者更近,显然中间一定有一条边作为这两类点的分界线,无论b在什么位置,都一定不会经过这条边。

 

一开始在原树上做一次LCA,求出各个询问的路径长度。然后加上新边,查找出要删去的那条边,重新做一次LCA,求出询问的路径长度。求出差就是这种情况的答案。

 

但是这种算法不适用与a与b都在原树的u,v路径上的情况,这时有可能不走新路径长度更短,那么第二次求最短路径就调整为求两次路径长度中的最小值,这样就确保了差是最小值。

 

#include<cstdio>

#include<cstring>

#include<cmath>

#include<iostream>

#include<algorithm>

#include<set>

#include<map>

#include<stack>

#include<vector>

#include<queue>

#include<string>

#include<sstream>

#define eps 1e-9

#define ALL(x) x.begin(),x.end()

#define INS(x) inserter(x,x.begin())

#define FOR(i,j,k) for(int i=j;i<=k;i++)

#define MAXN 100015

#define MAXM 200015

#define INF 0x3fffffff

using namespace std;

typedef long long LL;

int i,j,k,n,m,x,y,T,ans,big,cas,num,len;

bool flag;

int u,v,w,d;

int edge,head[MAXN],u2,v2,w2;



int fa[MAXN],pre[MAXN],query[MAXN][4],dist[MAXN];

bool vis[MAXN];



struct node

{

    int v,id;

    node (int _v,int _id):v(_v),id(_id){}

};



int find(int x)

{

    if (x==fa[x]) return fa[x];

    return fa[x]=find(fa[x]);

}



vector <vector<node> > mp;



struct edgenode

{

    int from,to,next,w;

} G[MAXM];



void add_edge(int x,int y,int w)

{

    G[edge].from=x;

    G[edge].to=y;

    G[edge].w=w;

    G[edge].next=head[x];

    head[x]=edge++;

}





void tarjan(int u,int p)

{

    vis[u]=true;

    for (int i=0;i<mp[u].size();i++)

    {

        int v=mp[u][i].v,id=mp[u][i].id;

        if (vis[v]) query[id][p]=find(v);

    }

    

    for (int i=head[u];i!=-1;i=G[i].next)

    {

        int v=G[i].to,w=G[i].w;

        if (w==-1) continue;

        if (!vis[v])

        {

            dist[v]=dist[u]+w;

            tarjan(v,p);

            fa[v]=u;

            

            pre[v]=i;

        }

    }

}





int main()

{

    scanf("%d",&T);

    while (T--)

    {    

        memset(head,-1,sizeof(head));

        edge=0;

        

        scanf("%d%d",&n,&m);

        for (i=0;i<n-1;i++)

        {

            scanf("%d%d%d",&u,&v,&w);

            add_edge(u,v,w);

            add_edge(v,u,w);

        }

        scanf("%d%d%d",&u2,&v2,&w2);

        

        mp.clear();

        mp.resize(n+4); 

        

        for (i=0;i<m;i++)

        {

            scanf("%d%d",&u,&v);

            query[i][0]=u;

            query[i][1]=v;

            mp[u].push_back(node(v,i));

            mp[v].push_back(node(u,i));

        }

        for (i=1;i<=n;i++) fa[i]=i;

        memset(vis,0,sizeof(vis));

        dist[u2]=0;

        tarjan(u2,2);

        

        for (i=0;i<m;i++)

        {

            int u=query[i][0];

            int v=query[i][1];

            int d=query[i][2];

            query[i][2]=dist[u]+dist[v]-2*dist[d];

        }

        

        int p=v2;

        int t1=w2,t2=dist[v2];

        printf("Case #%d:\n",++cas);

        

        if (t1<t2)

        {    

            while (p!=u2)

            {

                int v=G[pre[p]].from;

                int w=G[pre[p]].w;

                

                t1+=w;

                t2-=w;

                if (t1>t2)

                {

                    G[pre[p]].w=-1;

                    G[pre[p]^1].w=-1;

                    break;

                }

                p=v;

            }

            memset(vis,0,sizeof(vis));

            dist[u2]=0;

            for (int i=1;i<=n;i++) fa[i]=i;

            add_edge(u2,v2,w2);

            add_edge(v2,u2,w2);

            tarjan(u2,3);

            

            for (int i=0;i<m;i++)

            {

                int u=query[i][0];

                int v=query[i][1];

                int d=query[i][3];

                int dis=dist[u]+dist[v]-2*dist[d];

                if (dis<query[i][2])

                printf("%d\n",query[i][2]-dis);

                else printf("0\n");

            }

        }else

        {

            for(int i=0;i<m;i++)

            {

                 printf("0\n");

            }

        }

        

    }

    return 0;

}

 

你可能感兴趣的:(ca)