Union Find 和一道谷歌面试题

复习经历

继续按照主题刷题。最近一次的中心内容是 union-find 这种数据结构。选择它的原因是我在一次谷歌面试中被问到了这个题,而且当时没有回答出来,这就刺激了我一定要把 union-find 搞懂的愿望。我并不是完全不懂这个数据结构,在斯坦福在 Coursera 中的 Algorithm Specialization 里 Tim Roughgarden 教授讲过的,我也动手实现过。只是这段经历已经很遥远,而且在那之后再也没有用到过这种数据结构,所以没有机会复习。

把 union-find 放入正题对待以后,我首先做的就是回去把 Algorithm Specialization 里的讲课视频再看了一遍。这个课是要付费的,但是有免费的 preview, 用谷歌搜索视频“stanford union find” 就能够找到。课堂里是在讲述如何实现 Krugal’s Algorithm, 而 union-find 恰好是一种很契合的实现方式。Tim 老师简单介绍了一下这种数据结构,不过这些就够用了。不光够实现 Krugal’s Algorithm, 攻破谷歌的面试题也可以的。

在正式刷题之前,我又看了几篇教程文章,其中 有一篇 通过一系列数组的图片把概念解释得很清楚。union find 也叫 disjoint set union,一个整体被分成了若干不相交的部分。总体上来说,这个数据结构就是在维护一个元素间有从属关系的数组,它有两个基本方法:

  1. find(x), 找到元素 x 的根节点,通常在 union find 维护的数组中,元素是作为索引的形式存在,索引对应的值是它父亲的索引;这个 find() 方法就是递归地寻找父亲,直到某个节点的父亲是自己,就返回这个节点
  2. union(x,y), 联合元素 x 和 元素 y,即把 x 和 y 所属的部分并起来,实现方法就是找到 x 和 y 各自的根节点,让其中一个根节点依附于另一个

Union-Find 的样例 Python 实现

class UnionFind:
    def __init__(self, n):
        self.arr = [i for i in range(n)]
        self.n = n
        self.size = n

    def find(self, x):
        if x >= self.n:
            return -1
        if self.arr[x] != x:
            return self.find(self.arr[x])
        else:
            return x

    def union(self, x, y):
        if x >= self.n or y >= self.n:
            return False
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.arr[root_x] = root_y
            self.size -= 1
        return True

if __name__ == '__main__':
    uf = UnionFind(5)
    print(uf.find(3))
    uf.union(3, 4)
    print(uf.find(3))

通过代码读者可以发现 union find 这种数据结构有一个副产品 size. 这是随着成功的 union 操作而递减的,这个副产品在求一个图上的不连通组件的数量的时候非常好用。程序员只要按照 union find 的套路记录节点的变化,结果就维护在这个数据结构里。

谷歌面试题讲解

leetcode 947. 没错,我的电面就是一道 leetcode 原题。要求能够拿走的石头的最大数量。面试官后来都提示我了,就是求这个图上面不连通组件的数量。解决办法就是使用 union find, 将坐标作为 key, 互相映射。

def to_tuple(li):
    return (li[0], li[1])

class UnionFind:
    def __init__(self):
        self.size = 0
        self.map = {}
        
    def add(self, element):
        self.map[to_tuple(element)] = element
        self.size += 1
        
    def find(self, element):
        father = self.map[to_tuple(element)]
        if father != element:
            return self.find(father)
        else:
            return father
        
    def union(self, e1, e2):
        root1 = self.find(e1)
        root2 = self.find(e2)
        if root1 != root2:
            self.map[to_tuple(root1)] = root2
            self.size -= 1
        

class Solution:
    def removeStones(self, stones):
        """
        :type stones: List[List[int]]
        :rtype: int
        """
        uf = UnionFind()
        for stone in stones:
            uf.add(stone)
        length = len(stones)
        for i in range(length - 1):
            for j in range(i + 1, length):
                if stones[i][0] == stones[j][0] or stones[i][1] == stones[j][1]:
                    uf.union(stones[i], stones[j])
        return length - uf.size

最后最大的石头数量就是总数量 - 不连通组件的数量。

你可能感兴趣的:(算法,python,union-find)