并查集(Disjoint Set Union)是一种常用的处理不相交集合间的合并与查找功能的树形结构,配合与之对应的联合-搜索算法(Union Find Algorithm),可以将不相交集合间的合并与查找功能的时间复杂度大幅缩减至 O ( l o g N ) O(logN) O(logN)乃至 O ( 1 ) O(1) O(1)的量级。
并查集的核心思想在于将所有相互关联的集合都用一个代表元素进行表征,这样,只要找到两个元素对应的代表元素,判断其是否一致即可判断他们是否属于同一个集合;同样的,如果需要联合两个原本不相交的集合,只要将其中一个代表元素指向另一个代表元素,使他们采用同一个代表元素即可。
更详细的原理说明可以参考下面的参考链接3中知乎专栏里的讲解,他对并查集原理的说明讲的非常的详细,我主要就是通过的这篇专栏学习的并查集相关内容。
下面,我们来给出一般的并查集的简单代码实现。
class DSU:
def __init__(self, N):
self.root = [i for i in range(N)]
def find(self, k):
if self.root[k] == k:
return k
return self.find(self.root[k])
def union(self, a, b):
x = self.find(a)
y = self.find(b)
if x != y:
self.root[y] = x
return
上述代码即为最为一般性的并查集结构。
不过,通常而言,为了更好地提高算法效率,我们有时会给其增加一些小的trick。比如:
为了更好地优化算法的效率,我们可以控制树形结构,使其尽可能地扁平化,避免出现链型结构导致深度过深,我们经常会通过记录树的深度的方式优化树形结构,从而优化算法效率。
class DSU:
def __init__(self, N):
self.root = [i for i in range(N)]
self.depth = [1 for i in range(N)]
def find(self, k):
if self.root[k] == k:
return k
return self.find(self.root[k])
def union(self, a, b):
x = self.find(a)
y = self.find(b)
xh = self.depth[x]
yh = self.depth[y]
if x == y:
return
if xh >= yh:
self.root[y] = x
self.depth[x] = max(self.depth[x], self.depth[y]+1)
else:
self.root[x] = y
另一种常用的trick为:
class DSU:
def __init__(self, N):
self.root = [i for i in range(N)]
def find(self, k):
if self.root[k] == k:
return k
self.root[k] = self.find(self.root[k])
return self.root[k]
def union(self, a, b):
x = self.find(a)
y = self.find(b)
if x != y:
self.root[y] = x
return
下面,我们来通过一些leetcode中的例题来考察并查集结构的实际用法。
这一题是最为典型的并查集使用场景,我们直接套用并查集结构就能解答这道题。
直接给出代码实现如下:
class DSU:
def __init__(self, N):
self.root = [i for i in range(N)]
def find(self, k):
if self.root[k] == k:
return k
return self.find(self.root[k])
def union(self, a, b):
x = self.find(a)
y = self.find(b)
if x != y:
self.root[y] = x
return
class Solution:
def findCircleNum(self, M: List[List[int]]) -> int:
n = len(M)
dsu = DSU(n)
for i in range(n):
for j in range(i+1, n):
if M[i][j] == 1:
dsu.union(i, j)
group = set()
for i in range(n):
group.add(dsu.find(i))
return len(group)
这一题较之上一题会显得多少复杂一点,但是还是比较明显的DSU结构,我们不是根据数字,而是根据account建立dsu关系,而后根据不同的代表account返回去找到对应的账户所有人即可。
给出代码实现如下:
class DSU:
def __init__(self):
self.dsu = {
}
def find(self, account):
if account not in self.dsu:
self.dsu[account] = account
return account
if account == self.dsu[account]:
return account
self.dsu[account] = self.find(self.dsu[account])
return self.dsu[account]
def union(self, x, y):
a1 = self.find(x)
a2 = self.find(y)
self.dsu[a2] = a1
return
class Solution:
def accountsMerge(self, accounts: List[List[str]]) -> List[List[str]]:
mapping = {
}
dsu = DSU()
for it in accounts:
name = it[0]
key_account = it[1]
mapping[key_account] = name
for account in it[2:]:
mapping[account] = name
dsu.union(key_account, account)
res = defaultdict(list)
for account in mapping:
key_account = dsu.find(account)
res[key_account].append(account)
ans = [[mapping[k]] + sorted(v) for k, v in res.items()]
return ans
这一题的DSU结构还是比较明显的,就是针对数组中的每一个元素n,查看n-1以及n+1是否也出现在dsu当中,如果在的话就连结这几个元素,反之就不连接。
唯一需要注意的是,由于这一题对于执行效率有较高的要求,因此,我们需要对dsu的树状结构进行优化,使其尽可能地扁平化。
给出一种代码实现如下:
class DSU:
def __init__(self):
self.dsu = {
}
def find(self, n):
if n not in self.dsu:
return None
if n == self.dsu[n]:
return n
self.dsu[n] = self.find(self.dsu[n])
return self.dsu[n]
def union(self, x, y):
xr = self.find(x)
yr = self.find(y)
if xr is None or yr is None:
return
self.dsu[yr] = xr
return
def add(self, n):
if n not in self.dsu:
self.dsu[n] = n
return
class Solution:
def longestConsecutive(self, nums: List[int]) -> int:
if nums == []:
return 0
dsu = DSU()
nums = list(set(nums))
for n in nums:
dsu.add(n)
dsu.union(n, n-1)
dsu.union(n, n+1)
counter = defaultdict(int)
for n in nums:
counter[dsu.find(n)] += 1
return max(counter.values())
这一题是leetcode Weekly Contest 205的最后一题,当时没能做出来,现在,在大致学会了DSU结构之后,我们重新来考察这道题。
主体的思想还是和当时保持一致,当某条边连接的两点已经属于同一个集合时,我们就舍弃掉这条边,反之将这条边保留,最后看是否能够构成一个全连接图,如果能的话,一共删除了几条边。
通过DSU结构,我们很快地搞定了这道题,给出我们自己实现的python代码如下:
from copy import deepcopy
class DSU:
def __init__(self, n):
self.dsu = [i for i in range(n+1)]
def find(self, x):
if x == self.dsu[x]:
return x
self.dsu[x] = self.find(self.dsu[x])
return self.dsu[x]
def union(self, x, y):
xr = self.find(x)
yr = self.find(y)
self.dsu[yr] = xr
return
class Solution:
def maxNumEdgesToRemove(self, n: int, edges: List[List[int]]) -> int:
alice = []
bob = []
both = []
for t, x, y in edges:
if t == 1:
alice.append((x, y))
elif t == 2:
bob.append((x, y))
else:
both.append((x, y))
dsu = DSU(n)
counter3 = 0
for x, y in both:
if dsu.find(x) == dsu.find(y):
continue
dsu.union(x, y)
counter3 += 1
dsu1 = deepcopy(dsu)
counter1 = 0
for x, y in alice:
if dsu1.find(x) == dsu1.find(y):
continue
dsu1.union(x, y)
counter1 += 1
dsu2 = deepcopy(dsu)
counter2 = 0
for x, y in bob:
if dsu2.find(x) == dsu2.find(y):
continue
dsu2.union(x, y)
counter2 += 1
if counter1 + counter3 != n-1 or counter2 + counter3 != n-1:
return -1
else:
return len(edges) + counter3 - 2*n +2