LCA模板

tarjan:
#define up(i, j, k) for(int i = j; i <= k; ++i)
#define down(i, j, k) for(int i = j; i >= k; --i)
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;

int aim[20000],len[20000],nxt[20000],first[10001];
int _src[2000001],_aim[2000001],_lca[1000001],_nxt[2000001],_first[10001];
int fa[10001],visit[10001],dist[10001];
int tot,n,m,c;

void addedge(int u,int v,int w)
{
    aim[tot]=v; len[tot]=w; nxt[tot]=first[u]; first[u]=tot++;
}

void addask(int u,int v)
{
    _src[tot]=u; _aim[tot]=v; _nxt[tot]=_first[u]; _first[u]=tot++;
}

int find(int v)
{
    return fa[v]==v?v:fa[v]=find(fa[v]);
}

void tarjan(int u,int x)
{
    visit[u]=x;
    fa[u]=u;
    for (int i=first[u];~i;i=nxt[i])
        if (!visit[aim[i]])
        {
            dist[aim[i]]=dist[u]+len[i];
            tarjan(aim[i],x);
            fa[aim[i]]=u;
        }
    for (int i=_first[u];~i;i=_nxt[i])
        if (visit[_aim[i]]==x)
            _lca[i/2]=find(_aim[i]);
}

void init()
{
    memset(first,-1,sizeof(first));
    memset(_first,-1,sizeof(_first));
    memset(visit,0,sizeof(visit));
}

int main()
{
    while(scanf("%d%d%d",&n,&m,&c)!=EOF)
    {
        init();
        tot=0;
        up(i,1,m)
        {
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w); addedge(v,u,w);
        }
        tot=0;
        up(i,0,c-1)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            addask(u,v); addask(v,u);
            _lca[i]=-1;
        }
        tot=0;
        up(i,1,n)
            if (!visit[i])
            {
                dist[i]=0;
                tarjan(i,++tot);
            }
        up(i,0,c-1)
            if (~_lca[i]) printf("%d\n",dist[_src[i*2]]+dist[_aim[i*2]]-2*dist[_lca[i]]);
            else printf("Not connected\n");
    }
}
rmq:
#define up(i, j, k) for(int i = j; i <= k; ++i)
#define down(i, j, k) for(int i = j; i >= k; --i)
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;

int tot;
int aim[80010],len[80010],nxt[80010],first[40010];
int visit[40010],seq[80010],deep[80010],seqfirst[40010],dist[40010];
int dp[80010][18];

void addedge(int u,int v,int w)
{
    aim[tot]=v; len[tot]=w; nxt[tot]=first[u]; first[u]=tot++;
}

void dfs(int u,int dep)
{
    visit[u]=1; seq[++tot]=u; deep[tot]=dep; seqfirst[u]=tot;
    for(int i=first[u];~i;i=nxt[i])
    {
        if (!visit[aim[i]])
        {
            dist[aim[i]]=dist[u]+len[i];
            dfs(aim[i],dep+1);
            seq[++tot]=u; deep[tot]=dep;
        }
    }
}

void st(int n)
{
    up(i,1,n) dp[i][0]=i;
    for(int j=1;(1<r) swap(l,r);
    return seq[rmq(l,r)];
}

void init()
{
    memset(first,-1,sizeof(first));
    memset(visit,0,sizeof(visit));
    memset(seqfirst,0,sizeof(seqfirst));
    dist[1]=0;
}

int main()
{
    int t;
    scanf("%d",&t);
    while (t--)
    {
        init();
        int n,m;
        scanf("%d%d",&n,&m);
        tot=0;
        up(i,1,n-1)
        {
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w); addedge(v,u,w);
        }
        tot=0;
        dfs(1,1);
        st(2*n-1);
        up(i,1,m)
        {
            int u,v;
            scanf("%d%d",&u,&v);
            printf("%d\n",dist[u]+dist[v]-2*dist[lca(u,v)]);
        }
    }
}


你可能感兴趣的:(总结)