AK F.*ing leetcode 流浪计划之最近公共祖先(倍增算法)

欢迎关注更多精彩
关注我,学习常用算法与数据结构,一题多解,降维打击。

本期话题:在树上查找2个结点的最近公共祖先

问题提出

最近公共祖先定义

最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远(深度最深)的那个。

问题

参考地址:https://www.luogu.com.cn/problem/P3379
给定一棵树,询问每两个结点的最近公共祖先,一般会询问多次。

朴素做法

  1. 利用dfs求出所有结点的深度和父亲结点。
  2. 查询时把深度大的结点往上移,直到两个结点深度一样。然后两个结点同时往上移,直到两结点相遇。

复杂度分析

第1步求深度和父亲结点,需要遍历所有结点,复杂度是O(n)。
第2步在极端情况下是O(n) , 在多次查询的情况下,效率很低。

空间换时间

试想一下我们给每1个结点分配1个数组空间来存储往上移n个位置到达的祖先结点。
当我们要查询两个公共祖先时,就可以使用二分查找的方法来加速。


以A, B为例,可以看到后面黄色部分是公共祖先,我们要找的是最左边的10号祖先。只要利用二分查找即可找到。
该方法可以把查询复杂度降低到log(n). 但同时空间复杂度是O(n^2)。

优化空间(倍增算法)

参考资料:https://oi-wiki.org//graph/lca/#%E5%80%8D%E5%A2%9E%E7%AE%97%E6%B3%95
上面的方法的问题是空间分配的太多了,而且仔细观察,空间是冗余的。
比如A往上1个的祖先分配的数组和A的数组是高度重合的,可以看出是有递归或继承关系的。而且我们每次都是取的数组的一半。

那么我们可以存储往上数2^n个的祖先。
即存储往上1个,2个,4个。。。的祖先分别是谁。
查询的时候,由于任意数字都可以用2进制进行组合而成,可以遍历到所有祖先。
具体算法可以类比二分算法。

代码模板

题目链接:https://www.luogu.com.cn/problem/P3379



#include
#include
#include
#include
#include

using namespace std;

const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];

void initPara(int n)
{
	len = 0;
	for (int i = 0; i < n; i++)
	{
		head[i] = -1;
	}
}

void add(int a, int b)
{
	to[len] = b;
	nextEdge[len] = head[a];
	head[a] = len++;
}

void dfs(int x, int fa)
{
	if (fa == -1) h[x] = 0;
	else {
		h[x] = h[fa] + 1;
		father[0][x] = fa;
		// 利用倍增算法初始化father
		for (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {
			father[t][x] = father[t-1][father[t - 1][x]];
		}
	}
	int i;
	for (i = head[x]; i != -1; i = nextEdge[i])
	{
		int j = to[i];
		if (fa==j)continue;
		dfs(j, x);
	}
}

int lca(int a, int b) {
	if (h[a] < h[b]) {
		return lca(b, a);
	}

	// 先将两个结点跳到一样高度
	int gap = h[a] - h[b];
	for (int t = bitL-1; t>=0; t--) {
		if (gap & (1 << t))a = father[t][a];
	}
	if (a == b)return a;
	gap = h[a];
	// 利用二分查找找到深度最低的且不一样的结点。
	for (int t = bitL-1; t >= 0; t--) {
		if (gap <=(1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}
	
	return father[0][a]; // 再往上1个既是公共祖先
}

void solve()
{
	int t;
	int n, m, s;

	scanf("%d%d%d", &n, &m, &s);
	s--;
	initPara(n);
	int a, b;
	for (int i = 0; i < n - 1; ++i) {
		scanf("%d%d", &a, &b);
		a--, b--;
		add(a, b);
		add(b, a);
	}

	dfs(s, -1);

	/*for (int i = 0; i < n; ++i) {
		printf("%d: %d\n", i, h[i]);
	}*/

	while (m--) {
		scanf("%d%d", &a, &b);
		a--, b--;
		printf("%d\n", 1+lca(a, b));
	}
}

void test() {
	int t;
	int n=5000, m=500000, s=1;

	//scanf("%d%d%d", &n, &m, &s);
	s--;
	initPara(n);
	int a, b;
	for (int i = 0; i < n - 1; ++i) {
		a = i, b = i + 1;
		add(a, b);
		add(b, a);
	}

	dfs(s, -1);
	//printf("%d\n", 1 + lca(10, 5000-1));
	while (m--) {
		a = (m+102)%n, b =( 3823+m*2)%n;
		//printf("%d\n", m);
		if(lca(a, b)!=min(a,b))
		printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);
	}
}

int main()
{
	solve();
	//test();
	return 0;
}



/*

5 5 4
3 1
2 4
5 1
1 4
2 4
3 2
3 5
1 2
4 5

12 11 8
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10

1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/

练习一

链接:https://loj.ac/p/10135
注意点:需要对结点进行编号,无公共祖先时返回-1


#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;



const int M = 500000 + 10;
const int N = 500000 + 10;

map<int, int> num2Ind;
int indLen;

void initIndex() {
    num2Ind.clear();
    indLen = 0;
}

int getIndex(int n) {
    if (num2Ind.count(n) == 0)
        return -1;

    return num2Ind[n];
}

int addIndex(int n) {
    if (num2Ind.count(n) == 0)
        num2Ind[n] = indLen++;

    return num2Ind[n];
}

const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2];
int len;
int h[N];
int father[bitL][N];

void initPara(int n) {
    len = 0;

    for (int i = 0; i < n; i++) {
        head[i] = -1;
    }
}

void add(int a, int b) {
    to[len] = b;
    nextEdge[len] = head[a];

    head[a] = len++;
}

void dfs(int x, int fa) {
    if (fa == -1)
        h[x] = 0;
    else {
        h[x] = h[fa] + 1;
        father[0][x] = fa;

        for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
            father[t][x] = father[t - 1][father[t - 1][x]];
        }
    }

    int i;

    for (i = head[x]; i != -1; i = nextEdge[i]) {
        int j = to[i];

        if (fa == j)
            continue;

        dfs(j, x);
    }
}

int lca(int a, int b) {
    if (h[a] < h[b]) {
        return lca(b, a);
    }

    int gap = h[a] - h[b];

    for (int t = bitL - 1; t >= 0; t--) {
        if (gap & (1 << t))
            a = father[t][a];
    }

    if (a == b)
        return a;

    gap = h[a];

    for (int t = bitL - 1; t >= 0; t--) {
        if (gap <= (1 << t))
            continue;

        if (father[t][a] == father[t][b])
            continue;

        a = father[t][a];
        b = father[t][b];
        gap -= 1 << t;
    }

    return father[0][a];
}

void solve() {
    int n, m;
    int a, b, s;
    scanf("%d", &n);
    initPara(n);

    for (int i = 0; i < n; ++i) {
        scanf("%d%d", &a, &b);

        if (b == -1) {
            s = addIndex(a);
            continue;
        }

        a = addIndex(a);
        b = addIndex(b);
        add(a, b);
        add(b, a);
    }

    dfs(s, -1);

    scanf("%d", &m);
    /*for (int i = 0; i < n; ++i) {
        printf("%d: %d\n", i, h[i]);
    }*/

    while (m--) {
        scanf("%d%d", &a, &b);

        a = getIndex(a);
        b = getIndex(b);

        if (a < 0 || b < 0 || a == b) {
            puts("0");
            continue;
        }

        int lcab = lca(a, b);

        if (lcab == a)
            puts("1");
        else if (lcab == b)
            puts("2");
        else
            puts("0");
    }
}


int main() {
    solve();
    return 0;
}



/*
3
2 -1
1 2
3 1
2
1 2
2 3

2
1 -1
1 2
2
1 2
2 1


12
8 -1
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12


10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19
*/

练习二

链接:https://loj.ac/p/2610
算法思路:先用最小生成算法把所有大的边加入到树中。
利用倍增算法建立祖先关系,以及到祖先链路上的最小负载。
查询A,B最小负载为=min(A到公共祖先最小负载,B到公共祖先最小负载)。
具体实现分别从A,B查找最近公共祖先时记录链路上的最小值。
注意点:需要事先判断是否可达。题目中规定A!=B。
利用并查集点击前往判断是否在一棵树中。

#include
#include
#include
#include
#include
#include

using namespace std;


class UnionFindSet {
private:
	vector<int> father; // 父结点定义,father[i]=i时,i为本集合的代表
	vector<int> height; // 代表树高度,初始为1
	int nodeNum; // 集合中的点数

public:
	UnionFindSet(int n); // 初始化
	bool Union(int x, int y); // 合并
	int Find(int x);

	bool UnionV2(int x, int y); // 合并
	int FindV2(int x);
};

UnionFindSet::UnionFindSet(int n) : nodeNum(n + 1) {
	father = vector<int>(nodeNum);
	height = vector<int>(nodeNum);
	for (int i = 0; i < nodeNum; ++i) father[i] = i, height[i] = 1; // 初始为自己
}

int UnionFindSet::Find(int x) {
	while (father[x] != x) x = father[x];
	return x;
}

bool UnionFindSet::Union(int x, int y) {
	x = Find(x);
	y = Find(y);

	if (x == y)return false;
	father[x] = y;
	return true;
}


int UnionFindSet::FindV2(int x) {
	int root = x; // 保存好路径上的头结点
	while (father[root] != root) {
		root = father[root];
	}
	/*
	从头结点开始一直往根上遍历
	把所有结点的father直接指向root。
	*/
	while (father[x] != x) {
		// 一定要先保存好下一个结点,下一步是要对father[x]进行赋值
		int temp = father[x];
		father[x] = root;
		x = temp;
	}

	return root;
}

/*
需要加入height[]属性,初始化为1.
*/
//合并结点
bool UnionFindSet::UnionV2(int x, int y) {
	x = Find(x);
	y = Find(y);
	if (x == y) {
		return false;
	}
	if (height[x] < height[y]) {
		father[x] = y;
	}
	else if (height[x] > height[y]) {
		father[y] = x;
	}
	else {
		father[x] = y;
		height[y]++;
	}
	return true;
}

const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];

void initPara(int n)
{
	len = 0;
	for (int i = 0; i < n; i++)
	{
		head[i] = -1;
		h[i] = -1;
	}
}

void add(int a, int b, int w)
{
	to[len] = b;
	weight[len] = w;
	nextEdge[len] = head[a];

	head[a] = len++;
}

void dfs(int x, int fa, int w)
{
	if (fa == -1) h[x] = 0;
	else {
		h[x] = h[fa] + 1;
		father[0][x] = fa;
		dis[0][x] = w;
		for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
			father[t][x] = father[t - 1][father[t - 1][x]];
			dis[t][x] = min(dis[t - 1][x], dis[t - 1][father[t - 1][x]]);
		}
	}
	int i;
	for (i = head[x]; i != -1; i = nextEdge[i])
	{
		int j = to[i];
		if (fa == j)continue;
		dfs(j, x, weight[i]);
	}
}

int lca(int a, int b) {
	if (h[a] < h[b]) {
		return lca(b, a);
	}

	int gap = h[a] - h[b];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap & (1 << t))a = father[t][a];
	}
	if (a == b)return a;
	gap = h[a];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap <= (1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}

	return father[0][a];
}


int optDis(int a, int b) {
	if (h[a] < h[b]) {
		return optDis(b, a);
	}
	int d = 1e6;
	int gap = h[a] - h[b];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap & (1 << t)) {
			d=min(d, dis[t][a]);
			a = father[t][a];
		}
	}
	if (a == b)return d;
	gap = h[a];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap <= (1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		d = min(d,dis[t][a]);
		d = min(d,dis[t][b]);
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}
	d = min(d, min(dis[0][a], dis[0][b]));
	return d;
}

bool cmp(vector<int> &a, vector<int> &b) {
	return a[2] > b[2];
}

void solve()
{
	int n, m;
	int a, b, w;
	scanf("%d%d", &n, &m);
	initPara(n);
	auto us = UnionFindSet(n);
	vector<vector<int>> eds;
	for (int i = 0; i < m; ++i) {
		scanf("%d%d%d", &a, &b, &w);
		a--, b--;
		eds.push_back({a,b,w});
	}

	sort(eds.begin(), eds.end(), cmp);
	for (auto ed : eds) {
		if (us.UnionV2(ed[0], ed[1])) {
			add(ed[0], ed[1], ed[2]);
			add(ed[1], ed[0], ed[2]);
		}
	}

	for (int i = 0; i < n; ++i) {
		if(h[i]<0)dfs(i, -1, 0);
	}

	scanf("%d", &m);

	while (m--) {
		scanf("%d%d", &a, &b);
		a--, b--;
		if (us.FindV2(a) != us.FindV2(b))puts("-1");
		else printf("%d\n", optDis(a, b));
	}
}

int main()
{
	solve();
	return 0;
}



/*
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 7

11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12


12 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 1

11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12

*/

练习三

https://loj.ac/p/10130
算法思路:
利用倍增算法建立祖先关系,以及到祖先链路上的距离。
查询A,B距离=A到公共祖先距离+B到公共祖先距离)。
具体实现与上一题类似。



#include
#include
#include
#include
#include

using namespace std;

const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2], weight[M * 2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];

void initPara(int n)
{
	len = 0;
	for (int i = 0; i < n; i++)
	{
		head[i] = -1;
	}
}

void add(int a, int b, int w)
{
	to[len] = b;
	weight[len] = w;
	nextEdge[len] = head[a];

	head[a] = len++;
}

void dfs(int x, int fa, int w)
{
	if (fa == -1) h[x] = 0;
	else {
		h[x] = h[fa] + 1;
		father[0][x] = fa;
		dis[0][x] = w;
		for (int t = 1; t < bitL && (1 << t) <= h[x]; t++) {
			father[t][x] = father[t - 1][father[t - 1][x]];
			dis[t][x] = dis[t - 1][x] + dis[t - 1][father[t - 1][x]];
		}
	}
	int i;
	for (i = head[x]; i != -1; i = nextEdge[i])
	{
		int j = to[i];
		if (fa == j)continue;
		dfs(j, x, weight[i]);
	}
}

int lca(int a, int b) {
	if (h[a] < h[b]) {
		return lca(b, a);
	}

	int gap = h[a] - h[b];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap & (1 << t))a = father[t][a];
	}
	if (a == b)return a;
	gap = h[a];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap <= (1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}

	return father[0][a];
}


int optDis(int a, int b) {
	if (h[a] < h[b]) {
		return optDis(b, a);
	}
	int d = 0;
	int gap = h[a] - h[b];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap & (1 << t)) {
			d += dis[t][a];
			a = father[t][a];
		}
	}
	if (a == b)return d;
	gap = h[a];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap <= (1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		d += dis[t][a];
		d += dis[t][b];
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}
	d += dis[0][a] + dis[0][b];
	return d;
}

void solve()
{
	int n, m;
	int a, b;
	scanf("%d", &n);
	initPara(n);
	for (int i = 0; i < n - 1; ++i) {
		scanf("%d%d", &a, &b);
		a--, b--;
		add(a, b, 1);
		add(b, a, 1);
	}

	dfs(0, -1, 0);

	scanf("%d", &m);

	while (m--) {
		scanf("%d%d", &a, &b);
		a--, b--;
		printf("%d\n", optDis(a, b));
	}
}

int main()
{
	solve();
	return 0;
}



/*
6
1 2
1 3
2 4
2 5
3 6
2
2 6
5 6


12
8 1
8 9
8 12
1 5
1 7
7 6
9 4
9 11
9 2
4 3
12 10
11
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12

*/

练习四

https://acm.hdu.edu.cn/showproblem.php?pid=2586
与练习三类似



#include
#include
#include
#include
#include

using namespace std;

const int M = 500000 + 10;
const int N = 500000 + 10;
const int bitL = 22;
int head[N];
int to[M * 2], nextEdge[M * 2],weight[M*2];
int len;
int h[N];
int father[bitL][N];
int dis[bitL][N];

void initPara(int n)
{
	len = 0;
	for (int i = 0; i < n; i++)
	{
		head[i] = -1;
	}
}

void add(int a, int b, int w)
{
	to[len] = b;
	weight[len] = w;
	nextEdge[len] = head[a];

	head[a] = len++;
}

void dfs(int x, int fa, int w)
{
	if (fa == -1) h[x] = 0;
	else {
		h[x] = h[fa] + 1;
		father[0][x] = fa;
		dis[0][x] = w;
		for (int t = 1; t < bitL && (1<<t)<=h[x]; t++) {
			father[t][x] = father[t-1][father[t - 1][x]];
			dis[t][x] = dis[t-1][x]+ dis[t - 1][father[t - 1][x]];
		}
	}
	int i;
	for (i = head[x]; i != -1; i = nextEdge[i])
	{
		int j = to[i];
		if (fa==j)continue;
		dfs(j, x, weight[i]);
	}
}

int lca(int a, int b) {
	if (h[a] < h[b]) {
		return lca(b, a);
	}

	int gap = h[a] - h[b];
	for (int t = bitL-1; t>=0; t--) {
		if (gap & (1 << t))a = father[t][a];
	}
	if (a == b)return a;
	gap = h[a];
	for (int t = bitL-1; t >= 0; t--) {
		if (gap <=(1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}
	
	return father[0][a];
}


int optDis(int a, int b) {
	if (h[a] < h[b]) {
		return optDis(b, a);
	}
	int d = 0;
	int gap = h[a] - h[b];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap & (1 << t)) {
			d += dis[t][a];
			a = father[t][a];
		}
	}
	if (a == b)return d;
	gap = h[a];
	for (int t = bitL - 1; t >= 0; t--) {
		if (gap <= (1 << t))continue;
		if (father[t][a] == father[t][b])continue;
		d += dis[t][a];
		d += dis[t][b];
		a = father[t][a];
		b = father[t][b];
		gap -= 1 << t;
	}
	d += dis[0][a] + dis[0][b];
	return d;
}

void solve()
{
	int t;
	int n, m;
	int a, b, w;
	scanf("%d", &t);
	while (t--) {
		scanf("%d%d", &n, &m);
		initPara(n);
		for (int i = 0; i < n - 1; ++i) {
			scanf("%d%d%d", &a, &b, &w);
			a--, b--;
			add(a, b,w);
			add(b, a,w);
		}

		dfs(0, -1, 0);

		/*for (int i = 0; i < n; ++i) {
			printf("%d: %d\n", i, h[i]);
		}*/

		while (m--) {
			scanf("%d%d", &a, &b);
			a--, b--;
			printf("%d\n", optDis(a,b));
		}
	}
}

void test() {
	int t;
	int n = 5000, m = 500000;

	//scanf("%d%d%d", &n, &m, &s);
	initPara(n);
	int a, b;
	for (int i = 0; i < n - 1; ++i) {
		a = i, b = i + 1;
		add(a, b,1);
		add(b, a,1);
	}

	dfs(0, -1,0);
	//printf("%d\n", 1 + lca(10, 5000-1));
	while (m--) {
		a = (m+102)%n, b =( 3823+m*2)%n;
		//printf("%d\n", m);
		if(lca(a, b)!=min(a,b))
		printf("%d %d %d\n", 1 + lca(a, b), a+1, b+1);
	}
}

int main()
{
	solve();
	//test();
	return 0;
}



/*
2
3 2
1 2 10
3 1 15
1 2
2 3

2 2
1 2 100
1 2
2 1

1
12 11
8 1 4
8 9 3
8 12 6
1 5 5
1 7 1
7 6 2
9 4 2
9 11 10
9 2 9
4 3 2
12 10 7

1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12

1
12 11
8 1 1
8 9 1
8 12 1
1 5 1
1 7 1
7 6 1
9 4 1
9 11 1
9 2 1
4 3 1
12 10 1

1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11
11 12
*/

本人码农,希望通过自己的分享,让大家更容易学懂计算机知识。创作不易,帮忙点击公众号的链接。

你可能感兴趣的:(leetcode,高阶算法,算法,leetcode,最近公共祖先,倍增算法)