树上倍增法求最近公共祖先LCA

LCA,最近公共祖先,这个东西有很多作用,因此,如何高效求出LCA就成了一个热点的讨论话题。

树上倍增法求最近公共祖先LCA_第1张图片

下面所有的讨论都以图中这棵树为例子。

先来了解下什么是倍增吧,倍增其实就是二分的逆向,二分是逐渐缩小范围,而倍增是成倍扩大。这里的倍增借用二进制来表达更容易理解;倍增的做法是先求出 20,21,22, ,然后任意一个数字都可以用 20,21,22, 相加来表示,就像给你32个1,你能表示出32-bit 中的任意一个二进制一样。

倍增有什么好处呢,好处就是!倍增是一种优化手段,能提升查找等操作的效率的手段,其提升效率的原因就是二进制思想,提升的幅度为 O(n)O(logn) ,具体的解释可以参照树状数组简单易懂的详解,数组数组的思想就是基于倍增来实现的。

这里为什么要说成树上倍增呢?因为这个算法的操作都是在树上完成的,没错,求LCA的方法还有很多,比如RMQ-ST算法也可以做,这个算法的思想也是倍增,只不过这个倍增体现在区间上,而树上倍增法求最近公共祖先LCA的倍增体现在树的深度上。

先说说朴素的做法,求两个结点的最近公共祖先,我会让一个结点先向上走到根,并记录下它走的路径,然后然让另一个结点也向上往根走,边走边在先前记录的路径中查找是否存在该结点。举个例子,求lca(3, 6),先让3走到1,路径为3, 2, 1,然后让6走到16在序列[3, 2, 1]中查找,没有找到,继续走,走到44在序列[3, 2, 1]中查找,没有找到,继续走,走到22在序列[3, 2, 1]中查找,找到了,那么lca(3, 6) = 2

分析下朴素做法的时间复杂度,算法中需要让两个结点依次走到根,且在一个结点移动的过程中还需要在路径序列中查找;假设树有n个结点,由于树可能退化成链,因此从某一个结点移动到根这个操作的时间复杂度为 O(n) ,而查找这个操作可以使用set这一类容器,故时间复杂度为 O(logn) ,因此朴素算法求一次LCA的时间复杂度为 O(nlogn) ,假设要多次求LCA,这个时间复杂度显然是不能接受的。

然而,受剑指Offer66题之每日6题 - 第六天中第六题:两个链表的第一个公共结点 的启发,在树上求LCA和在两个链表的第一个公共结点是一样的,因此,朴素做法就有三种了,大家可以去剑指Offer66题之每日6题 - 第六天详细了解,这里就不多赘述了。

LCA用得普遍的地方就是求树中两个结点之间的最短路:dis[u, v] = dis[root, u] + dis[root, v] - 2 * dis[root, lca(u, v)]


现在就来好好说下树上倍增法求最近公共祖先LCA的算法了。

思想

算法的思想很简单,同剑指Offer66题之每日6题 - 第六天中第六题两个链表的第一个公共结点中的 O(n) 的做法一样,把两个结点移动到同一高度,然后一起向根走,一边走一边比较两个结点是否相等就行了。

但是这样做,时间复杂度还是 O(n) ,问题的规模较大时,复杂度还是不能接受,因此,树上倍增就是来提升这个效率的,树上倍增把移动这个操作提速了,原来只能一步一步移动,现在可以移动多步了。

具体是怎么移动的呢?请看完预处理,然后接着看LCA就知道了。

预处理

首先,要预处理出数中每一个结点的深度dep以及到根的距离dis,前面也提到了,树上倍增是树深度的倍增,自然需要每一个结点的dep

然后,要预处理出每一个结点的第 2i 个祖先pd[u][i],什么意思呢,举个例子就明白了,例如结点11的第 20=1 个祖先是9,第 21=2 个祖先是8,第 22=4 个祖先是1。这一步就是要为倍增提供”零件“。

第一步可以使用 dfs 预处理出来,第二步,可以使用动态规划处理出来,pd[u][i] = pd[pd[u][i - 1]][i - 1],画个图就理解了。

树上倍增法求最近公共祖先LCA_第2张图片

结点C的第 22=4 个祖先等于结点C的第 21=2 个祖先B的第 21=2 个祖先A。

LCA

预处理完成后,剩下的事情就是向根结点移动了;

第一步求出两个结点之间的高度差,让较深的那个结点移动到另一个结点一样的高度上,如果是朴素算法需要一步一步移动,而树上倍增算法把这个高度差表示成二进制,从而把这个移动转化成二进制的数位上移动,这样子,复杂度一下子就降到了 O(logn) 。举个例子,高度差diff = 6(110),那么较深的结点先移动2,这时高度差变为4,然后较深的结点移动4,这时两个结点的高度一样了。

第二步就是两个结点同时向根移动,先看看两个结点最远的祖先是否相同,如果相同,说明最近的祖先还可能没出现,于是再看看两个结点第二远的祖先是否相同;如果两个结点最远的祖先不相同,说明这两个结点正在接近最近公共祖先,故把这两个结点同时移动到对应的祖先处。以此类推,最终可以得到最近公共祖先。这里距离都是 2i ,原因在第一步中已经说明。

代码

宏,全局变量

/**
 * 直系祖先,pd[u][0]
 */
#define NUM_PARENT 0
/**
 * 树中结点的最大数目
 */
#define MAXSIZE (40000 + 5)

/**
 * 求二进制中最高一位1的index
 */
#define BITOFBINARY(x) ((int)(log((x) * 1.0) / log(2.0)))

/**
 * 求二进制中最低一位1所表示的数值
 */
int lowbit(int x)
{
    return x & -x;
}

/**
 * 树高的最大幂次
 */
const int MAXDEP = BITOFBINARY(MAXSIZE);

/**
 * 每个结点的深度,距根结点的距离
 */
int dep[MAXSIZE], dis[MAXSIZE];

/**
 * 每个结点的不同深度幂次的祖先
 */
int pd[MAXSIZE][MAXDEP + 1];

预处理

/**
 * 求出每个结点的深度,距离根的距离及它们的直系祖先
 */
void init_dfs(int src)
{
    for (int i = head[src]; i + 1; i = edges[i].next) {
        int to = edges[i].to;

        // 领接表建树,避免重复访问
        if (to == pd[src][NUM_PARENT])
            continue;
        dep[to] = dep[src] + 1;
        dis[to] = dis[src] + edges[i].val;
        pd[to][NUM_PARENT] = src;
        init_dfs(to);
    }
}

/**
 * 动态规划求出每个结点不同距离的祖先
 */
void init_redouble()
{
    for (int power = 1; power <= MAXDEP; ++power)
        for (int i = 1; i <= n; i++)
            pd[i][power] = pd[pd[i][power - 1]][power - 1];
}

LCA

int lca(int x, int y)
{
    // 始终保持x结点的深度较深
    if (dep[x] < dep[y])
        swap(x, y);

    // 求出高度差,并使x移动到同y一样的高度
    for (int diff = dep[x] - dep[y]; diff; diff -= lowbit(diff))
        x = pd[x][BITOFBINARY(lowbit(diff))];

    // 处理x和y是同一个结点或y是x的祖先这两种情况
    if (x == y)
        return x;

    // x和y一样的高度,同时移动x, y
    for (int i = MAXDEP; i >= 0; --i)
        if (pd[x][i] != pd[y][i])
            x = pd[x][i],
            y = pd[y][i];
    return pd[x][NUM_PARENT];
}

完整代码

这里结合一个题目背景,HDU2586:How far away?,完整地给出代码。

这个题目的意思是:给你n个点,n - 1条边的最小生成树,然后给你m次询问,每次询问树中任意两个结点之间的最短路。

做法是随便令一个结点为根,然后用树上倍增的方法求lca,然后利用dis[u, v] = dis[root, u] + dis[root, v] - 2 * dis[root, lca(u, v)]可以求得答案。

n达到了40000m达到了200,朴素做法或许行得通,但我没试过。

#include 

using namespace std;

#define MAXSIZE (40000 + 5)
#define NUM_PARENT 0

#define BITOFBINARY(x) ((int)(log((x) * 1.0) / log(2.0)))

typedef struct Edge Edge;

struct Edge {
    int to, val;
    int next;
    Edge() {};
    Edge(int to, int val, int next = -1) :
        to(to), val(val), next(next) {}
};

int n, m;
Edge edges[MAXSIZE * 2];
int head[MAXSIZE];

int lowbit(int x)
{
    return x & -x;
}

void add_edge(int x, int y, int val, int i)
{
    edges[i] = Edge(y, val, head[x]);
    head[x] = i;
}

const int MAXDEP = BITOFBINARY(MAXSIZE);

int dep[MAXSIZE], dis[MAXSIZE];
int pd[MAXSIZE][MAXDEP + 1];

void init_dfs(int src)
{
    for (int i = head[src]; i + 1; i = edges[i].next) {
        int to = edges[i].to;
        if (to == pd[src][NUM_PARENT])
            continue;
        dep[to] = dep[src] + 1;
        dis[to] = dis[src] + edges[i].val;
        pd[to][NUM_PARENT] = src;
        init_dfs(to);
    }
}

void init_redouble()
{
    for (int power = 1; power <= MAXDEP; ++power)
        for (int i = 1; i <= n; i++)
            pd[i][power] = pd[pd[i][power - 1]][power - 1];
}

int lca(int x, int y)
{
    if (dep[x] < dep[y])
        swap(x, y);

    for (int diff = dep[x] - dep[y]; diff; diff -= lowbit(diff))
        x = pd[x][BITOFBINARY(lowbit(diff))];

    if (x == y)
        return x;

    for (int i = MAXDEP; i >= 0; --i)
        if (pd[x][i] != pd[y][i])
            x = pd[x][i],
            y = pd[y][i];
    return pd[x][NUM_PARENT];
}

int main()
{
    int T;
    for (scanf("%d", &T); T--; ) {
        int x, y, val;
        scanf("%d%d", &n, &m);

        int root = 1;

        memset(head, -1, sizeof(head));
        memset(pd, 0, sizeof(pd));
        dis[root] = 0;
        dep[root] = 1;

        for (int i = 0; i < 2 * (n - 1); i += 2) {
            scanf("%d%d%d", &x, &y, &val);
            add_edge(x, y, val, i);
            add_edge(y, x, val, i + 1);
        }

        init_dfs(root);
        init_redouble();

        for (; m--; ) {
            scanf("%d%d", &x, &y);
            printf("%d\n", dis[x] + dis[y] - 2 * dis[lca(x, y)]);
        }
    }
    return 0;
}

复杂度

预处理中,init_dfs的时间复杂度为 O(n) init_redouble的时间复杂度为 O(nlogn) ,所以总的复杂度为 O(nlogn)

由于倍增算法把树上的移动转为在二进制数位上的移动,故单次lca的时间复杂度为 O(logn) ,可以接受;

你可能感兴趣的:(算法)