# python3 数据结构和算法(4) 并查集 (Python 3 data structures and algorithms, part 4: disjoint set / union-find)

import random
import time
import sys
from functools import wraps

def timethis(func):
    """Decorator that reports how long each call to ``func`` takes.

    After the wrapped call returns, prints one line containing the function
    name and the elapsed seconds, then passes the return value through
    unchanged. If ``func`` raises, the exception propagates and nothing is
    printed.

    Elapsed time is measured with ``time.perf_counter`` (monotonic, highest
    available resolution) rather than ``time.time``, whose wall-clock value
    can jump backwards or forwards if the system clock is adjusted.
    """
    @wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print('%-30s %10s %f' % (func.__name__, ' cost time:', elapsed))
        return result
    return wrapper

class DisjointSet():
    """Disjoint-set forest (union-find) over the integers ``0 .. n-1``.

    Combines union by rank with full path compression.

    Attributes:
        p: parent pointers; ``p[x] == x`` marks ``x`` as the root of its set.
        rank: for each root, an upper bound on the height of its tree.
    """

    def __init__(self, n):
        """Create ``n`` singleton sets ``{0}, {1}, ..., {n-1}``."""
        self.rank = [0] * n
        self.p = [0] * n
        for i in range(n):
            self.make_set(i)

    def make_set(self, x):
        """(Re)initialise ``x`` as a singleton set that is its own root."""
        self.p[x] = x
        self.rank[x] = 0

    def union(self, x, y):
        """Merge the sets containing ``x`` and ``y``.

        Does nothing if they are already in the same set.
        """
        self.link(self.find_set(x), self.find_set(y))

    def link(self, x, y):
        """Join the trees whose roots are ``x`` and ``y``, using union by rank.

        The root with the smaller rank is attached under the other. On a tie,
        ``x`` goes under ``y`` and ``y``'s rank grows by one.
        """
        # Linking a root with itself must not bump its rank: doing so would
        # inflate ranks on redundant unions and weaken the height bound.
        if x == y:
            return
        if self.rank[x] > self.rank[y]:
            self.p[y] = x
        else:
            self.p[x] = y
            # Equal ranks: the merged tree may now be one level taller.
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1

    def find_set(self, x):
        """Return the root of the set containing ``x``, compressing its path.

        Implemented iteratively so that deep trees cannot exceed Python's
        recursion limit.
        """
        # First pass: walk up to the root.
        root = x
        while root != self.p[root]:
            root = self.p[root]
        # Second pass: point every node on the path directly at the root.
        while x != root:
            parent = self.p[x]
            self.p[x] = root
            x = parent
        return root


@timethis
def test():
    """Answer union-find queries read from standard input.

    Input (whitespace-separated integers, one record per line):
        first line: ``n q`` -- number of elements and number of queries;
        each of the next ``q`` lines: ``t a b`` where
            ``t == 0`` merges the sets containing ``a`` and ``b``;
            ``t == 1`` outputs ``1`` if ``a`` and ``b`` share a set, else ``0``.
    Records with any other ``t`` are ignored.

    Answers are collected and written with a single call at the end instead
    of one ``print`` per query, which is much faster for large ``q``; the
    bytes written to stdout are the same.
    """
    readline = sys.stdin.readline
    n, q = map(int, readline().split())
    ds = DisjointSet(n)

    answers = []
    for _ in range(q):
        t, a, b = map(int, readline().split())
        if t == 0:
            ds.union(a, b)
        elif t == 1:
            answers.append('1' if ds.find_set(a) == ds.find_set(b) else '0')

    # Only write when there is something to report, so an input with no
    # type-1 queries still produces no output (matching per-query printing).
    if answers:
        sys.stdout.write('\n'.join(answers) + '\n')

def main():
    """Script entry point: run the timed union-find query driver."""
    return test()


if __name__ == '__main__':
    main()

# 你可能感兴趣的 (you may also be interested in): python3, 算法 (algorithms)