LCA(倍增,RMQ,Tarjan)

LCA(Least Common Ancestors):最近公共祖先

题目:https://www.luogu.org/problemnew/show/P3379

倍增

先预处理出每个结点向上跳2^x层的祖先和每个结点的深度

类似快速幂,拆分deep[u] - deep[v](假设deep[u] > deep[v]),每次使u向上跳2^x步,使deep[u] = deep[v]

然后再一起往上跳,dis = deep[lca(u,v)] - deep[u],由于dis我们不知道具体多少并且lca后面的结点都是u,v的公共祖先,所以不能像上面的一样直接拆分,于是换个思路,我们去找的最远的非公共祖先,最后肯定是到lca的子节点

比如dis = 10100(2)

从大到小枚举2^n次方,2^4前面的明显就是公共祖先

2^4时发现这时候不是公共祖先,于是先跳到这里,dis = 100(2)

同理下一步之后dis = 11(2),最后dis = 1(2)循环结束

然后lca就等于这时候的v(u = v)的直接祖先

#include

using namespace std;

const int maxn = 5e5 + 50;
const int maxlog = 20;

int n, q, root;
int par[maxlog + 5][maxn];
int dep[maxn];
int head[maxn], edgecnt;
struct P
{
    int to, next;
}edge[2 * maxn];

void dfs(int u, int fa)
{
    par[0][u] = fa;
    dep[u] = dep[fa] + 1;
    for(int i = head[u]; ~i; i = edge[i].next)
    {
        if(edge[i].to == fa) continue;
        dfs(edge[i].to, u);
    }
}

void init()
{
    dfs(root, -1);
    for(int k = 0; k + 1 < maxlog; k++)
    {
        for(int u = 1; u <= n; u++)
        {
            if(par[k][u] == -1)
                par[k + 1][u] = -1;
            else
                par[k + 1][u] = par[k][par[k][u]];
        }
    }
}

int lca(int u, int v)
{
    if(dep[u] > dep[v])
    {
        swap(u, v);
    }
    for(int i = 20; ~i; i--)
    {
        if(dep[par[i][v]] >= dep[u]) v = par[i][v];
    }
    if(v == u) return v;

    for(int i = maxlog; ~i; i--)
    {
        if(par[i][u] != par[i][v])
        {
            u = par[i][u];
            v = par[i][v];
        }
    }
    return par[0][u];
}

void add(int u, int v)
{
    edge[edgecnt].to = v;
    edge[edgecnt].next = head[u];
    head[u] = edgecnt++;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    cin >> n >> q >> root;
    memset(head, -1, sizeof(head));

    int u, v;
    for(int i = 0; i < n - 1; i++)
    {
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }

    init();
    while(q--)
    {
        cin >> u >> v;
        printf("%d\n", lca(u, v));
    }
    return 0;
}

基于RMQ(ST表)的算法(对于LCA ±1RMQ也可以做)

vis[i]为第i次dfs所到达的结点

dep[i]为vis[i]的深度

id[u]为u第一次被dfs时的vis编号

在dfs序中易知lca(u, v) = vis[{i, id[u] <= i <= id[v]}中dep最小的i]

上面放的模板题有点卡常数,所以放了快速read

#include

using namespace std;

const int maxn = 10e5 + 50;
const int maxlog = 20;

int n, q, root, cnt;
int head[maxn], edgecnt;

int dp[maxn][maxlog];
int rec[maxn][maxlog];
int dep[maxn];
int vis[maxn];
int id[maxn];
struct P
{
    int to, next;
}edge[maxn];

void add(int u, int v)
{
    edge[edgecnt].to = v;
    edge[edgecnt].next = head[u];
    head[u] = edgecnt++;
}

void DFS(int u, int fa, int d, int &cnt)
{
    dep[cnt] = d;
    vis[cnt] = u;
    id[u] = cnt++;
    for(int i = head[u]; ~i; i = edge[i].next)
    {
        int v = edge[i].to;
        if(v != fa)
        {
            DFS(v, u, d + 1, cnt);
            dep[cnt] = d;
            vis[cnt++] = u;
        }
    }
}

void ST()
{
    for(int i = 0; i < cnt; i++)
        dp[i][0] = dep[i], rec[i][0] = vis[i];
    for(int j = 1; (1 << j) <= cnt; j++)
    {
        for(int i = 0; i + (1 << j) - 1 < cnt; i++)
        {
            if(dp[i][j - 1] > dp[i + (1 << (j - 1))][j - 1])
                dp[i][j] = dp[i + (1 << (j - 1))][j - 1], rec[i][j] = rec[i + (1 << (j - 1))][j - 1];
            else
                dp[i][j] = dp[i][j - 1], rec[i][j] = rec[i][j - 1];
        }
    }
}

void init()
{
    cnt = 0;
    DFS(root, -1, 0, cnt);
    ST();
}

inline int RMQ(int l, int r)
{
    int k = 0;
    while(1 << (k + 1) <= r - l + 1) k++;
    if(dp[l][k] > dp[r - (1 << k) + 1][k])
            return rec[r - (1 << k) + 1][k];
    else
            return rec[l][k];
}

inline int read()
{
    int x=0,flag=0;
    char ch=getchar();
    if(ch=='-') flag=1;
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9')x*=10,x+=ch-'0',ch=getchar();
    if(flag) return -x;
    return x;
}

int main()
{
    //ios_base::sync_with_stdio(0);
    //cin.tie(0); cout.tie(0);

    n=read(),q=read(),root=read();
    //cin >> n >> q >> root;
    memset(head, -1, sizeof(head));

    int u, v;
    for(int i = 0; i < n - 1; i++)
    {
        u = read(), v = read();
     //   cin >> u >> v;
        add(u, v);
        add(v, u);
    }

    init();
    int l, r;
    while(q--)
    {
        l = read(), r = read();
     //   cin >> l >> r;
        l = id[l], r = id[r];
        if(l > r) swap(l, r);
        printf("%d\n", RMQ(l, r));
        //cout << RMQ(l, r) << endl;
    }
    return 0;
}

Tarjan

离线算法

在dfs时, 对于查询u,v来说,若遍历到u时发现v已经遍历过了,这时lca(u,v)就是用并查集维护的集合的root

注意合并的时候将父节点作为并查集的上级即u->v合并时par[v] = u

#include

using namespace std;

const int maxn = 1e6 + 50;

int n, q, root;
int head[maxn], edgecnt;
int qhead[maxn], quecnt;
int par[maxn];
int ans[maxn];
int vis[maxn];

struct P
{
    int to, next;
}edge[maxn];
struct P1
{
    int to, num, lca, next;
}que[maxn];

int Find(int u)
{
    if(u == par[u])
        return u;
    return par[u] = Find(par[u]);
}
void Merge(int x, int y)
{
    x = Find(x);
    y = Find(y);
    if(x != y)
    {
        par[x] = y;
    }
}

void add(int u, int v)
{
    edge[edgecnt].to = v;
    edge[edgecnt].next = head[u];
    head[u] = edgecnt++;
}

void addque(int u, int v, int i)
{
    que[quecnt].to = v;
    que[quecnt].num = i;
    que[quecnt].next = qhead[u];
    qhead[u] = quecnt++;
}

inline int read()
{
    int x=0,flag=0;
    char ch=getchar();
    if(ch=='-') flag=1;
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9')x*=10,x+=ch-'0',ch=getchar();
    if(flag) return -x;
    return x;
}

void Tarjan(int u, int fa)
{
    for(int i = head[u]; ~i; i = edge[i].next)
    {
        int v = edge[i].to;
        if(v != fa)
        {
            Tarjan(v, u);
            Merge(v, u);
            vis[v] = 1;
        }
    }
    vis[u] = 1; 
    for(int i = qhead[u]; ~i; i = que[i].next)
    {
        int v = que[i].to;
        if(vis[v])
        {
            que[i].lca = Find(v);
            que[i^1].lca = que[i].lca;
            ans[que[i].num] = que[i].lca;
        }
    }
}

int main()
{
    //ios_base::sync_with_stdio(0);
    //cin.tie(0); cout.tie(0);

    n=read(),q=read(),root=read();
    //cin >> n >> q >> root;
    memset(head, -1, sizeof(head));
    memset(qhead, -1, sizeof(qhead));
    for(int i = 0; i <= n; i++) par[i] = i;

    int u, v;
    for(int i = 0; i < n - 1; i++)
    {
        u = read(), v = read();
        //cin >> u >> v;
        add(u, v);
        add(v, u);
    }

    for(int i = 0; i < q; i++)
    {
        u = read(), v = read();
        //cin >> u >> v;
        addque(u, v, i);
        addque(v, u, i);
    }
    Tarjan(root, -1);
    for(int i = 0; i < q; i++)
    {
        printf("%d\n", ans[i]);
        //cout << ans[i] << endl;
    }
    return 0;
}

树链剖分

..正在学习中

你可能感兴趣的:(图论)