# Reposted from (转自): http://blog.csdn.net/rav009/article/details/12781899
# -*- coding: UTF-8 -*-


class unionfind:
    """Disjoint-set (union-find) structure built from a list of groups.

    Every element that appears in any group becomes a node. After
    ``createtree()`` has run, elements that appear together in at least one
    group share the same root, directly or through other groups.

    Trees are merged by size (the smaller tree goes under the larger root),
    and ``findroot`` compresses the paths it walks.
    """

    def __init__(self, groups):
        """Create one singleton set per distinct element of ``groups``.

        Args:
            groups: re-iterable collection of sequences of hashable items.
                It is stored and iterated again by ``createtree``.
        """
        self.groups = groups
        self.items = set()
        for g in groups:
            self.items.update(g)
        # Parent pointer of every node; a root points to itself.
        self.parent = {item: item for item in self.items}
        # Maps each current root to the number of nodes in its tree.
        # A key is present here if and only if that node is a root.
        self.rootdict = {item: 1 for item in self.items}

    def union(self, r1, r2):
        """Merge the sets that contain ``r1`` and ``r2``.

        The root of the smaller tree is attached under the root of the
        larger one; on a tie, ``r1``'s root stays the root. If both items
        already share a root, nothing changes.

        Raises:
            KeyError: if either item is not a known element.
        """
        rr1 = self.findroot(r1)
        rr2 = self.findroot(r2)
        if rr1 == rr2:
            # Already in the same set. Continuing would pop this root from
            # rootdict and then re-insert it with a doubled size.
            return
        cr1 = self.rootdict[rr1]
        cr2 = self.rootdict[rr2]
        if cr1 < cr2:
            # Make rr1 the root of the larger (or equal) tree.
            rr1, rr2 = rr2, rr1
        self.parent[rr2] = rr1
        del self.rootdict[rr2]
        self.rootdict[rr1] = cr1 + cr2

    def findroot(self, r):
        """Return the root of the set containing ``r``.

        Path compression: every node visited on the way up is re-pointed
        directly at the root, so later lookups are shorter. This is done
        iteratively, so deep trees cannot hit the recursion limit.

        Raises:
            KeyError: if ``r`` is not a known element.
        """
        root = r
        while root not in self.rootdict:
            root = self.parent[root]
        # Second pass: point every node on the walked path at the root.
        while r != root:
            nxt = self.parent[r]
            self.parent[r] = root
            r = nxt
        return root

    def createtree(self):
        """Join the members of every group into a single set.

        Consecutive members are unioned pairwise, which connects the whole
        group. Groups with fewer than two members produce no pairs.
        """
        for g in self.groups:
            for a, b in zip(g, g[1:]):
                self.union(a, b)

    def gettree(self):
        """Return the current partition as a list of lists of items.

        The order of the groups, and of the items inside each group,
        follows set/dict iteration order and is not guaranteed.
        """
        rs = {}
        for item in self.items:
            rs.setdefault(self.findroot(item), []).append(item)
        return list(rs.values())

    def printree(self):
        """Print every group of the current partition on one line.

        Groups are separated by spaces. A single-argument print() call is
        used so this works on both Python 2 and Python 3.
        """
        print(" ".join(str(group) for group in self.gettree()))


if __name__ == '__main__':
    u = unionfind([['a', 'b', 'c'], ['b', 'd'], ['e', 'f'], ['g'], ['d', 'h', 'i']])
    u.createtree()
    u.printree()