The concept of the minimum spanning tree (Minimal Spanning Tree, MST) is defined for connected graphs.


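  • Connected graph: an undirected graph is connected if every pair of vertices is joined by a path.
  • Strongly connected graph: a directed graph is strongly connected if, for every pair of vertices u and v, there is a directed path from u to v and one from v to u.
  • Spanning tree: a spanning tree of a connected graph is a connected subgraph that contains all n vertices of the graph and only the n-1 edges needed to form a tree. Adding any further edge of the graph to this tree necessarily creates a cycle.
  • Minimum spanning tree: if every edge of a connected graph carries a weight, a spanning tree whose total edge weight is the smallest among all spanning trees of the graph is called a minimum spanning tree.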
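The spanning-tree definition can be checked mechanically. The sketch below is illustrative only: the helper names `is_connected` and `is_spanning_tree` are hypothetical, vertices are assumed to be numbered 0..n-1, and the candidate edges are assumed to come from the graph itself. It relies on the fact that a connected subgraph with n vertices and exactly n-1 edges cannot contain a cycle.

```python
from collections import defaultdict, deque


def is_connected(n, edges):
    """Return True if the undirected graph on vertices 0..n-1 is connected (BFS from vertex 0)."""
    if n == 0:
        return True
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    seen = {0}
    queue = deque([0])
    while queue:
        x = queue.popleft()
        for y in adj[x]:
            if y not in seen:
                seen.add(y)
                queue.append(y)
    return len(seen) == n


def is_spanning_tree(n, tree_edges):
    """Hypothetical helper: exactly n-1 edges that connect all n vertices.
    A connected graph with n vertices and n-1 edges has no cycle."""
    return len(tree_edges) == n - 1 and is_connected(n, tree_edges)


# A 4-cycle 0-1-2-3-0: dropping any one of its edges leaves a spanning tree.
cycle = [(0, 1), (1, 2), (2, 3), (3, 0)]
print(is_connected(4, cycle))                         # True
print(is_spanning_tree(4, cycle))                     # False: one edge too many, so a cycle
print(is_spanning_tree(4, cycle[:3]))                 # True
print(is_spanning_tree(4, [(0, 1), (2, 3), (0, 1)]))  # False: vertices {0, 1} and {2, 3} are not connected
```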

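An MST of an undirected graph can be found with a greedy method; both algorithms introduced later in this post are based on this approach. Before presenting them, we first state an important fact about MSTs.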

Let G be a weighted connected graph, and let V1 and V2 be a
partition of the vertices of G into two disjoint nonempty sets. Furthermore, let e be an edge in G with minimum weight from among those with one endpoint in V1 and the other in V2. There is a minimum spanning tree T that has e as one of its edges.



Let T be a minimum spanning tree of G. If T contains e, we are done. Otherwise, the addition of e to T must create a cycle. Since this cycle crosses from V1 to V2 through e, it must cross back, so there is some edge f ≠ e of this cycle that has one endpoint in V1 and the other in V2. Moreover, by the choice of e, $w(e) \le w(f)$, where $w(\cdot)$ denotes edge weight. If we remove f from $T \cup \{e\}$, we obtain a spanning tree whose total weight is no more than before. Since T was a minimum spanning tree, this new tree must also be a minimum spanning tree.
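The cut property can be sanity-checked by brute force on a small graph. This sketch uses hypothetical helper names, runs in exponential time, and is meant only for tiny inputs: it enumerates every spanning tree, then tests whether the lightest edge crossing a chosen partition (V1, V2) appears in some minimum-weight spanning tree. The edge list is the example graph used later in this post, with vertices shifted to 0-based indices.

```python
from itertools import combinations


def forms_spanning_tree(n, edges):
    """True if the (u, v, w) edges number n-1 and contain no cycle, i.e. form a spanning tree."""
    if len(edges) != n - 1:
        return False
    parent = list(range(n))

    def root(x):
        # Walk up to the representative of x's component (no compression; tiny inputs only)
        while parent[x] != x:
            x = parent[x]
        return x

    for u, v, _ in edges:
        ru, rv = root(u), root(v)
        if ru == rv:  # u and v already connected: this edge would close a cycle
            return False
        parent[ru] = rv
    return True


def total_weight(tree):
    return sum(w for _, _, w in tree)


def check_cut_property(n, edges, v1):
    """Return (e, ok): e is a lightest edge crossing (v1, V - v1);
    ok is True if some minimum spanning tree contains e."""
    trees = [t for t in combinations(edges, n - 1) if forms_spanning_tree(n, t)]
    best = min(total_weight(t) for t in trees)
    crossing = [edge for edge in edges if (edge[0] in v1) != (edge[1] in v1)]
    e = min(crossing, key=lambda edge: edge[2])
    ok = any(e in t and total_weight(t) == best for t in trees)
    return e, ok


# Example graph from this post, vertices shifted from 1..6 to 0..5
graph_edges = [(0, 1, 10), (2, 4, 19), (3, 5, 7),
               (0, 4, 3), (1, 5, 31), (2, 5, 20)]
print(check_cut_property(6, graph_edges, {0, 1}))     # ((0, 4, 3), True)
print(check_cut_property(6, graph_edges, {0, 1, 4}))  # ((2, 4, 19), True)
```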




Complexity analysis

The running time of Kruskal's algorithm has two parts: sorting the edges, and testing, for each edge, whether its two endpoints lie in two distinct clusters.
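A rough cost estimate, assuming m edges, n vertices, and a disjoint set with union by size and path compression like the `DisjointSet` class in this post: sorting costs O(m log m); the cluster tests perform two finds per edge and at most n-1 merges, which together cost O(m α(n)), where α is the very slowly growing inverse Ackermann function. Sorting therefore dominates:

$$
\begin{aligned}
T_{\text{sort}} &= O(m \log m) \\
T_{\text{union-find}} &= O(m\,\alpha(n)) \\
T_{\text{Kruskal}} &= O(m \log m) + O(m\,\alpha(n)) = O(m \log m) = O(m \log n),
\quad \text{since } m \le \tfrac{n(n-1)}{2} \Rightarrow \log m < 2 \log n
\end{aligned}
$$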





import logging

class Edge:
    """Edge structure: holds the two endpoints, u and v, and the weight."""
    def __init__(self, u, v, weight):
        self.u = u
        self.v = v
        self.weight = weight

class DisjointSet:
    def __init__(self, n):
        """n: number of vertices"""
        self.parent = [None] * n  # parent of node i
        self.size = [1] * n  # size of the set rooted at i, used for union by size
        for i in range(n):
            self.parent[i] = i  # each vertex starts as the root of its own set

    def merge_set(self, a, b):
        """Merge the sets (trees) that vertices a and b belong to.

          a - index of a
          b - index of b
          If a and b do not share the same root, merge the smaller tree into the larger one.
        """
        a = self.find_set(a)
        b = self.find_set(b)
        if a == b:
            return  # already in the same set, nothing to merge
        logging.info(f"""merging: vertex {a+1} and {b+1}""")

        if self.size[a] < self.size[b]:
            logging.info(f"""merging: size.{a+1} < size.{b+1}""")
            self.parent[a] = b  # merge a into b
            self.size[b] += self.size[a]  # add size of old set(a) to set(b)
        else:
            logging.info(f"""merging: size.{a+1} >= size.{b+1}""")
            self.parent[b] = a  # merge b into a
            self.size[a] += self.size[b]  # add size of old set(b) to set(a)

    def find_set(self, a):
        """Find the root of node a, i.e. the set that is named after its root."""
        if self.parent[a] != a:
            # find the root recursively, compressing the path along the way
            self.parent[a] = self.find_set(self.parent[a])
        # return root
        return self.parent[a]

def k_algo(n, edges, ds):
    """Kruskal's algorithm:
      Sort all edges by weight in ascending order and pick the first n-1 edges that form no cycle.
      Take the next edge from the sorted list; if its two ends are not yet in the same set,
      add the edge to the MST. Otherwise, adding the edge would create a cycle. The final result
      contains n-1 edges.

      n - int, #vertices
      edges - list, graph edges
      ds - DisjointSet
      returns sum(weights) - total weight of the MST
    """
    edges.sort(key=lambda x: x.weight)  # sort edges by weight
    MST = []
    for edge in edges:
        set_u = ds.find_set(edge.u)  # root of u
        set_v = ds.find_set(edge.v)  # root of v
        if set_u != set_v:
            logging.info(f"""Vertices {edge.u+1} and {edge.v+1} have different roots.""")
            # u and v do not share a root, so the edge closes no cycle: merge and keep it
            ds.merge_set(set_u, set_v)
            MST.append(edge)
            if len(MST) == n-1:
                # an MST has exactly n-1 edges, so stop early
                logging.info("MST is complete.")
                break
    edges_selected = []
    for e in MST:
        edges_selected.append((e.u+1, e.v+1, e.weight))
    logging.info(f"""MST edges: {edges_selected}""")
    return sum([edge.weight for edge in MST])

if __name__ == "__main__":
    log_format = "%(asctime)s: %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)
    n = 6  # number of vertices in the weighted undirected graph
    # assume the vertex indices are consecutive integers starting from 1
    e = [(1, 2, 10), (3, 5, 19), (4, 6, 7),
         (1, 5, 3),  (2, 6, 31), (3, 6, 20)]
    m = len(e)  # number of edges
    ds = DisjointSet(n)  # one set per vertex, not per edge
    edges = [None] * m

    # allocate edge info
    for i in range(m):
        u, v, weight = e[i][0], e[i][1], e[i][2]
        u -= 1  # transform index, [1, n] -> [0, n-1]
        v -= 1
        edges[i] = Edge(u, v, weight)

    print("MST weights sum:", k_algo(n, edges, ds))  # MST weights sum: 59


