codevs 2370 小机房的树 (lca)

学习了 Tarjan 离线求 LCA。两点间距离为 dep[u] + dep[v] - 2*dep[lca],其中 dep 是从根出发的带权深度(路径上边权之和)。

#include <cstdio>
#include <cstring>

const int MAXN = 50000 + 10; // capacity of node-indexed arrays
const int MAXQ = 80000 + 10; // capacity of query-indexed arrays

// Disjoint-set parent links; F[x] == -1 marks x as the root of its set.
int F[MAXN];

// Return the root of x's set. Every node visited on the way up is re-pointed
// directly at that root (path compression).
int getf(int x)
{
    int root = x;
    while (F[root] != -1)
        root = F[root];
    while (x != root)
    {
        const int up = F[x];
        F[x] = root;
        x = up;
    }
    return root;
}

// Merge the sets containing u and v by hanging the root of u's set under
// the root of v's set. Does nothing if they already share a root.
void bing(int u, int v)
{
    const int ru = getf(u);
    const int rv = getf(v);
    if (ru == rv)
        return;
    F[ru] = rv;
}

// vis[u]: set once the Tarjan DFS has entered u.
bool vis[MAXN];
// ancestor[r]: for a disjoint-set root r, the node that the Tarjan DFS
// currently uses as the common ancestor of r's whole set.
int ancestor[MAXN];
// Forward-star adjacency list: to = endpoint, next = index of the next edge
// out of the same node (-1 ends the list), w = edge weight.
struct Edge
{
    int to,next,w;
}edge[MAXN*2];   // each undirected tree edge is stored once per direction
int head[MAXN],tot;   // head[u]: first edge out of u (-1 if none); tot: edges used
// dep[u]: weighted distance from the DFS root to u.
int dep[MAXN];

void addedge(int u, int v, int w)
{
    edge[tot].to = v;
    edge[tot].w = w;
    edge[tot].next = head[u];
    head[u] = tot++;
}

// Offline query list, stored per node like the edge list:
// q = the other endpoint, next = next query of the same node (-1 ends the
// list), index = query number (1-based) used as the slot in answer[].
struct Query
{
    int q,next;
    int index;
}query[MAXQ*2];   // each query (u,v) is stored under both u and v
int answer[MAXQ];   // answer[i]: path length for the i-th query
// h[u]: first query entry attached to node u (-1 if none).
// NOTE(review): indexed by node id but sized by MAXQ; safe only because MAXQ > MAXN.
int h[MAXQ];
int tt;   // number of query entries used so far
int Q;    // number of queries in the current test case

void add_query(int u, int v, int index)
{
    query[tt].q = v;
    query[tt].next = h[u];
    query[tt].index = index;
    h[u] = tt++;
}

// Reset all per-test-case state: edge/query counters, list heads (-1 = empty),
// disjoint-set parents, visitation flags, depths and set ancestors.
void init()
{
    tot = 0;
    tt = 0;
    memset(head, -1, sizeof head);
    memset(h, -1, sizeof h);
    memset(F, -1, sizeof F);
    memset(vis, 0, sizeof vis);
    memset(dep, 0, sizeof dep);
    memset(ancestor, 0, sizeof ancestor);
}

// Tarjan's offline LCA over the subtree entered at u.
// depth is the weighted distance from the root to u. Each query (u, other)
// whose other endpoint was already entered is answered with
// dep[u] + dep[other] - 2 * dep[lca(u, other)].
void LCA(int u, int depth)
{
    dep[u] = depth;
    ancestor[u] = u;
    vis[u] = true;

    // Descend into every not-yet-entered neighbour. Afterwards, fold that
    // subtree's set into u's set and make u the set's ancestor again.
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (!vis[child])
        {
            LCA(child, depth + edge[e].w);
            bing(u, child);
            ancestor[getf(u)] = u;
        }
    }

    // Answer the queries whose other endpoint has already been reached.
    for (int k = h[u]; k != -1; k = query[k].next)
    {
        const int other = query[k].q;
        if (!vis[other])
            continue;
        const int lca = ancestor[getf(other)];
        answer[query[k].index] = dep[u] + dep[other] - 2 * dep[lca];
    }
}

// Reads test cases until input is exhausted. Each case is n, then n-1 weighted
// edges "u v w", then Q queries "u v"; one path length is printed per query.
// Ids are shifted by +1 (presumably 0-based input) so that node 1 is the root.
int main()
{
    int n,u,v,w;
    // Compare with 1, not EOF: on a malformed token scanf returns 0 without
    // consuming it, so a `!= EOF` test would spin forever.
    while(scanf("%d",&n) == 1)
    {
        init();
        for(int i = 1; i <= n-1; ++i)
        {
            scanf("%d %d %d",&u,&v,&w);
            u++;
            v++;
            addedge(u,v,w);
            addedge(v,u,w);
        }
        scanf("%d",&Q);
        // Attach each query to both endpoints. The Tarjan pass answers it at
        // whichever endpoint is processed after the other was entered.
        for(int i = 1; i <= Q; ++i)
        {
            scanf("%d %d",&u,&v);
            u++;
            v++;
            add_query(u,v,i);
            add_query(v,u,i);
        }
        LCA(1,0);
        for(int i = 1; i <= Q; ++i)
            printf("%d\n",answer[i]);
    }
    return 0;
}

学习了倍增 LCA。WA 了很久,最后发现是写错了变量。

#include 
using namespace std;
typedef long long LL;
const int MAXN = 1e5+10;   // capacity of node-indexed arrays
// Number of binary-lifting levels: fa[x][k] for k < DEG, i.e. jumps of up to 2^(DEG-1) edges.
const int DEG = 20;

// Forward-star adjacency list: to = endpoint, next = next edge out of the
// same node (-1 ends the list), w = edge weight.
struct Edge
{
    int to,next,w;
}edge[MAXN*2];   // undirected edges are stored once per direction
int head[MAXN],tot;   // head[u]: first edge out of u (-1 if none); tot: edges used
void addedge(int u, int v, int w)
{
    edge[tot].to = v;
    edge[tot].w = w;
    edge[tot].next = head[u];
    head[u] = tot++;
}
// Empty the adjacency list: every head becomes -1 and no edges are in use.
void init()
{
    memset(head, -1, sizeof head);
    tot = 0;
}
int fa[MAXN][DEG];    // fa[x][k]: the 2^k-th ancestor of x (the root maps to itself); filled by BFS
LL cost[MAXN][DEG];   // cost[x][k]: total edge weight of the climb from x to fa[x][k]
int deg[MAXN];        // deg[x]: number of edges between x and the BFS root (depth, not node degree)

void BFS(int root)
{
    queue<int> que;
    deg[root] = 0;
    fa[root][0] = root;
    cost[root][0] = 0;
    que.push(root);
    while(!que.empty())
    {
        int tmp = que.front();
        que.pop();
        for(int i = 1; i < DEG; ++i)
        {
            fa[tmp][i] = fa[fa[tmp][i-1]][i-1];
            cost[tmp][i] = cost[tmp][i-1] + cost[fa[tmp][i-1]][i-1];
        }

        for(int i = head[tmp]; i != -1; i = edge[i].next)
        {
            int v = edge[i].to;
            if(v == fa[tmp][0]) continue;
            deg[v] = deg[tmp]+1;
            fa[v][0] = tmp;
            cost[v][0] = edge[i].w;
            que.push(v);
        }
    }
}

// Despite its name, returns the weighted length of the tree path between u
// and v (the sum of edge weights), computed with binary lifting.
LL LCA(int u, int v)
{
    if (deg[u] > deg[v])
        swap(u, v);   // v is now the deeper endpoint
    LL total = 0;
    int a = u;
    int b = v;

    // Raise b to a's depth, one set bit of the depth gap at a time.
    int diff = deg[v] - deg[u];
    for (int bit = 0; diff != 0; ++bit, diff >>= 1)
    {
        if ((diff & 1) == 0)
            continue;
        total += cost[b][bit];
        b = fa[b][bit];
    }
    if (a == b)
        return total;

    // Raise both while their 2^k-th ancestors differ; they end up as the two
    // children of the LCA.
    for (int k = DEG - 1; k >= 0; --k)
    {
        if (fa[a][k] != fa[b][k])
        {
            total += cost[a][k] + cost[b][k];
            a = fa[a][k];
            b = fa[b][k];
        }
    }
    if (a != b)
        total += cost[a][0] + cost[b][0];
    return total;
}

// Reads n, then n-1 weighted edges "u v w" (ids used as given), roots the
// tree at node 1, then reads q and prints the u-v path length for each query.
// NOTE(review): unlike the other two solutions, ids are not shifted; node 0 is
// simply reached by the BFS like any other node.
int main()
{
    init();
    int n;
    int a, b, weight;
    scanf("%d",&n);
    for (int left = n - 1; left > 0; --left)
    {
        scanf("%d %d %d",&a,&b,&weight);
        addedge(a,b,weight);
        addedge(b,a,weight);
    }
    int q;
    BFS(1);
    scanf("%d",&q);
    while (q--)
    {
        scanf("%d %d",&a,&b);
        printf("%lld\n",LCA(a,b));
    }
    return 0;
}

ST 表在线求 LCA(欧拉序 + RMQ)

#include 
using namespace std;
typedef long long LL;
const int MAXN = 100010;
// rmq[i]: depth of the i-th Euler-tour entry (1-based). The sparse table
// below stores indices into this array.
int rmq[2*MAXN];

// Sparse table answering "index of a minimum of rmq[a..b]" in O(1) after
// O(n log n) preprocessing.
struct ST
{
    int mm[2*MAXN];      // mm[i] = floor(log2(i)); mm[0] = -1
    int dp[2*MAXN][20];  // dp[i][j]: index of a minimum of rmq[i .. i + 2^j - 1]

    // Build the table over rmq[1..n].
    void init(int n)
    {
        mm[0] = -1;
        for(int i = 1; i <= n; i++)
        {
            mm[i] = ((i&(i-1)) == 0)?mm[i-1]+1:mm[i-1];
            dp[i][0] = i;
        }
        // A window of length 2^j starting at i must end inside [1, n].
        for(int j = 1; j <= mm[n]; j++)
            for(int i = 1; i + (1<<j) - 1 <= n; i++)
            {
                const int left = dp[i][j-1];
                const int right = dp[i+(1<<(j-1))][j-1];
                dp[i][j] = rmq[left] < rmq[right] ? left : right;
            }
    }

    // Return an index in [min(a,b), max(a,b)] whose rmq value is minimal,
    // using two overlapping power-of-two windows.
    int query(int a,int b)
    {
        if(a > b) std::swap(a,b);
        int k = mm[b-a+1];
        const int left = dp[a][k];
        const int right = dp[b-(1<<k)+1][k];
        return rmq[left] <= rmq[right] ? left : right;
    }
};

// Forward-star adjacency list: to = endpoint, next = next edge out of the
// same node (-1 ends the list), w = edge weight.
struct Edge
{
    int to,next,w;
};
Edge edge[MAXN*2];   // undirected edges are stored once per direction
int tot,head[MAXN];  // tot: edges used; head[u]: first edge out of u (-1 if none)
LL dist[MAXN];       // dist[u]: weighted distance from the DFS root to u
int F[MAXN*2];       // F[i]: node at position i of the Euler tour (1-based)
int P[MAXN];         // P[u]: first position of u in the Euler tour
int cnt;             // current length of the Euler tour
ST st;               // sparse table over rmq[] (the tour depths)
// Empty the adjacency list: every head becomes -1 and no edges are in use.
void init()
{
    memset(head, -1, sizeof head);
    tot = 0;
}
void addedge(int u, int v, int w)//加边,无向边需要加两次
{
    edge[tot].to = v;
    edge[tot].w = w;
    edge[tot].next = head[u];
    head[u] = tot++;
}
// Euler tour of the subtree rooted at u. pre is u's parent (the root passes
// itself) and dep is the unweighted depth of u.
// u is appended to F[] on entry and again after each child returns; rmq[]
// receives the matching depth, and P[u] records u's first tour position.
// Each child's dist[] is set to dist[u] plus the connecting edge weight.
void dfs(int u,int pre,int dep)
{
    ++cnt;
    F[cnt] = u;
    rmq[cnt] = dep;
    P[u] = cnt;
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child != pre)
        {
            dist[child] = dist[u] + edge[e].w;
            dfs(child, u, dep + 1);
            ++cnt;
            F[cnt] = u;
            rmq[cnt] = dep;
        }
    }
}
// Prepare for LCA queries: build the Euler tour from root, then the sparse
// table over its 2*node_num-1 positions.
void LCA_init(int root,int node_num)
{
    cnt = 0;
    dfs(root, root, 0);
    const int tourLength = 2 * node_num - 1;
    st.init(tourLength);
}
// Return the node id of the LCA of u and v: the shallowest tour entry between
// their first occurrences.
int query_lca(int u,int v)
{
    const int pos = st.query(P[u], P[v]);
    return F[pos];
}

// Reads N, then N-1 weighted edges "u v w", then q queries "u v", and prints
// the weighted path length for each query.
// Ids are shifted by +1 (presumably 0-based input) so that node 1 is the root.
int main()
{
    int N;
    int u,v,w;

    scanf("%d",&N);
    init();
    for(int i = 1; i < N; i++)
    {
        scanf("%d%d%d",&u,&v,&w);
        ++u;
        ++v;
        addedge(u,v,w);
        addedge(v,u,w);
    }
    LCA_init(1,N);
    int q;
    scanf("%d",&q);
    while(q--)
    {
        scanf("%d%d",&u,&v);
        // Shift before using the ids. Writing `dist[++u] + ... query_lca(u,v)`
        // in one expression modifies and reads u/v unsequenced, which is
        // undefined behaviour.
        ++u;
        ++v;
        printf("%lld\n",dist[u]+dist[v]-2*dist[query_lca(u,v)]);
    }
    return 0;
}

你可能感兴趣的:(LCA)