【专题讲解】搜索与图论

引言

acwing基础课程第三章的内容整理,主要是对于图的搜索有关的问题。关于dfs和bfs的方法会简略的写一些,主要是最短路和最小生成树的问题。

文章目录

  • 引言
  • 1. DFS与BFS
    • 无向图的存储
    • 例题
  • 2. 有向树与图的遍历:拓扑排序
  • 3. 最短路
    • (1). 单源最短路
      • 所有边的权重为正数(Dijkstra)
        • 朴素Dijkstra算法
        • 堆优化Dijkstra算法
      • 存在负权边
        • Bellman-Ford算法
        • SPFA算法
    • (2). 多源汇最短路
      • Folyd算法
  • 4. 最小生成树
    • Prim算法
    • Kruskal算法
  • 5. 二分图:染色法、匈牙利算法
    • 如何判别是不是二分图
    • 匈牙利算法

1. DFS与BFS

无向图的存储

首先是需要解决图的存储问题。我一般习惯用python中的字典解决,也可以用数组的方式。

#--------------------------字典存储的方法---------------------------------
dic = collections.defaultdict(list)
for a, b in map:
	dic[a].append(b)
	dic[b].append(a)
#----------------------------数组存储---------------------------------------
# 需要注意这里不可以,dic = [[]*(N+1)]
dic = [[] for _ in range(N+1)]
for a,b,w in map:
	dic[a].append([b,w])
	
#--------------------------链表存储的方法---------------------------------
# 对于每个点k,开一个单链表,存储k所有可以走到的点。h[k]存储这个单链表的头结点
N = 1000000
idx = 0
# 注意h要开点的个数+1,e,ne,w需要边的个数
h = [-1]*N
e = [0]*M, ne = [0]*M, w = [0]*M

# 添加一条边a->b,权重为x
def add(a, b, x)
    e[idx] = b
    ne[idx] = h[a]
    w[idx] = x
    h[a] = idx
	idx += 1
	
for a, b in map:
	add(a,b)
	add(b,a)

# 进行遍历
index = h[cur]
while index != -1:
	next = e[index] # 得到
	wight = w[index]
	# 操作,
	index = ne[index]

例题

  • 树的重心
    【专题讲解】搜索与图论_第1张图片

可以用dfs也可以用bfs。首先给出dfs的算法。


import collections
N = int(input())
## 存储图
dic = collections.defaultdict(list)
for i in range(N-1):
    a,b = map(int, input().split())
    dic[a].append(b)
    dic[b].append(a)

ans = float('inf')
visit = set()
### ---------------dfs------------------
res = 0
def dfs(x):
    global res
    visit.add(x)
    sum = 1
    for next in dic[x]:
        if next not in visit:
            s = dfs(next)
            res = max(res, s)
            sum += s
    return sum

for i in range(1, N+1):
	res = 0
    dfs(i)
    visit.clear() # 及时清空访问过的位置
    ans = min(ans, res)
print(ans)

#--------------------------------bfs-------------------
def bfs(x):
    res = 0
    queue = collections.deque()
    queue.append(x)
    visit.add(x)
    while queue:
        n = len(queue)
        res += 1
        for i in range(n):
            cur = queue.popleft()
            for next in dic[cur]:
                if next not in visit:
                    queue.append(next)
                    visit.add(next)
    return res

for i in range(1, N+1):
    visit.clear()
    res = bfs(i)
    ans = min(ans, res)
print(ans)
  • 图中点的层次
    【专题讲解】搜索与图论_第2张图片

这类题目最好使用bfs的方法

import collections
N, M= map(int, input().split())
dic = collections.defaultdict(list)
for i in range(M):
    a,b = map(int, input().split())
    dic[a].append(b)
    dic[b].append(a)

def bfs(x):
	visit = set()
    res = 0
    queue = collections.deque()
    queue.append(x)
    visit.add(x)
    while queue:
        n = len(queue)
        res += 1
        for i in range(n):
            cur = queue.popleft()
            for next in dic[cur]:
                if next not in visit:
                    if next == N:
                        ans = res
                        return res
                    queue.append(next)
                    visit.add(next)
    return -1

print(bfs(1))

2. 有向树与图的遍历:拓扑排序

有向的路径,特点在于存在依赖关系。基本思路是采用bfs,维护每一个节点的入度,当入读为0时,加入队列。

  • 课程表问题

【专题讲解】搜索与图论_第3张图片

class Solution:
    def findOrder(self, n: int, prerequisites: List[List[int]]) -> List[int]:
    ## 这里没有考虑存在环的情况,否则应该用一个visit
        dic = collections.defaultdict(list)
        num = [0]*n 
        queue = collections.deque()
        ans = []        

        for cur, pre in prerequisites:
            dic[pre].append(cur)   # bfs的方法,前驱是key,后继是键值
            num[cur] += 1  ## 计算了改节点的入度
		# 初始化,把所有初始入度为0的节点加入队列
        for i in range(n):
            if num[i] == 0:
                queue.append(i)
                ans.append(i)
        while queue:
            top = queue.popleft()
            for item in dic[top]:
            	# 依次减小节点的入度
                num[item] -= 1
                if num[item] == 0:
					# 每当入度为0就加入。
                    queue.append(item)
                    ans.append(item)
        if len(ans) == n: # 不等于,说明存在某个点没有进入,入度存在问题
            return ans
        else:
            return []

3. 最短路

(1). 单源最短路

解释:从一号点到n号点的最短路径。
点的个数是n,边的个数是m

所有边的权重为正数(Dijkstra)

朴素Dijkstra算法

  • 朴素Dijkstra算法: O ( n 2 ) O(n^2) O(n2)。适合稠密图, m m m接近于 n 2 n^2 n2,与边的数量无关。

因为相对稠密,采用邻接矩阵来写。例题

# n<500, m<10^5
# 朴素的dijkstra算法采用邻接矩阵来存储
# 注意这里都是考虑的有向边
N, M = map(int, input().split())
g = [[float('inf')]*(N+1) for _ in range(N+1)]
for i in range(M):
    a, b, c = map(int, input().split())
    g[a][b] = min(g[a][b], c)

def dijkstra():
    # 第一步,初始化距离矩阵,visit集合
    dis = [float('inf')]*(N+1)
    dis[1] = 0
    visit = set()
    # 第二步,进行N次循环,每次选取距离最小的点
    for _ in range(N):
        t = -1
        # 利用暴力的方法寻找未被保存的最短路径点
        for j in range(1,N+1):
            if j not in visit and (t == -1 or dis[t]>dis[j]):
                t = j 
        # 把最短路径点加入visit
        visit.add(t)
        # 更新其余点的最短距离
        for j in range(1, N+1):
            dis[j] = min(dis[j], dis[t]+g[t][j])
    if dis[N] == float('inf'):
        return -1
    return dis[N]

ans = dijkstra()
print(ans)

堆优化Dijkstra算法

  • 堆优化: O ( m l o g n ) O(mlogn) O(mlogn) 适合稀疏图。

一般可以直接采用python中基于堆的优先队列。复杂度实际上为 m l o g m mlogm mlogm。例题
另外还可以如上面的链接那样,建立邻接表。

import heapq
# 基于堆的dijkstra算法采用邻接表来存储
# 注意这里都是考虑的有向边
N, M = map(int, input().split())
h = [[] for _ in range((N+1))] # 构建一个邻接表
    
for i in range(M):
    a, b, c = map(int, input().split())
    h[a].append([b,c]) # 依次建立邻接表

def dijkstra():
    # 第一步,初始化距离矩阵,visit集合,堆
    dis = [float('inf')]*(N+1)
    min_ = [[0,1]]
    heapq.heapify(min_)
    visit = set()
    # 第二步,循环,每次选取距离最小的点
    while min_:
        # 利用堆的方法寻找未被保存的最短路径点
        distance, cur = heapq.heappop(min_)
        if cur in visit:
            continue
        visit.add(cur)
        dis[cur] = distance
        for next, w in h[cur]:
            if distance+w<dis[next]:
                heapq.heappush(min_, [distance+w, next])
                dis[next] = distance+w
    if dis[N] == float('inf'):
        return -1
    return dis[N]

ans = dijkstra()
print(ans)    

存在负权边

Bellman-Ford算法

  • Bellman-Ford算法: O ( n m ) O(nm) O(nm)

Bellman-Ford算法是可以解决限制了最多经过 K 条边到达 n 的最短路径问题的。

外层循环遍历所有的点n,内层循环遍历所有边m,维护最短的路径dis[b] = min(dis[b], dis[b]+w)需要注意,存在负权边时候,如果存在负权重环,可能无最短距离

如果第n次迭代,依然有更新最短边,说明存在一个至少为n+1的最短路径。存在环。

## bellman算法
# 注意这里考虑的是有向边
N,M = map(int, input().split())
g = [[0,0]]
for i in range(M):
    a,b,c = map(int, input().split())
    g.append([a,b,c])

def bellman():
    dis = [float('inf')]*(N+1)
    dis[1] = 0
    
    for i in range(N):
        backup = dis.copy() # 这里需要注意,进行了复制,防止迭代出现混乱
        for j in range(M):
            a,b,c = g[j]
            dis[b] = min(backup[a]+c, dis[b])
            
    if dis[N] == float('inf'):
        return 'impossible'
    return dis[N]
ans = bellman()
print(ans)

SPFA算法

  • SPFA算法:一般是 O ( m ) O(m) O(m),最坏时间复杂度 O ( n m ) O(nm) O(nm)

核心特点是,堆中不存储路径,只存储提升了的节点。并且为了防止重复入堆,维护了一个标记数组控制入堆的元素

一般情况下,spfa算法的速度比Dijkstra还要快。但是可以卡时间,所以如果被卡了,就改成dijkstra算法。

## SPFA
# 注意这里考虑的都是有向边
import heapq
N,M = map(int, input().split())
# 注意,这里不可以是g = [[]*(N+1)]
g = [[] for _ in range(N+1)]

for i in range(M):
    a,b,c = map(int, input().split())
    g[a].append([b,c])

def SPFA():
    dis = [float('inf')]*(N+1)
    dis[1] = 0
    queue = [1]
    heapq.heapify(queue)
    # 需要维护一个数组判断某个点在不在当前队列当中
    has = [0]*(N+1)
    has[1] = 1
    
    while queue:
        cur = heapq.heappop(queue)
        has[cur] = 0
        for next, w in g[cur]:
            if dis[cur]+w < dis[next]:
                dis[next] = dis[cur]+w
                if has[next] == 0: # 没有必要重复加入
                    heapq.heappush(queue, next)
                    has[next] = 1
            
    if dis[N] == float('inf'):
        return 'impossible'
    return dis[N]
ans = SPFA()
print(ans)

SPFA算法还可以检测负环。除了维护dis以外,还需要维护一个cnt,每次进行状态转移时候,cnt(next) = cnt(cur)+1,如果cnt>N表示存在负环。

def SPFA():
    # 需要维护一个数组判断某个点在不在当前队列当中
    #----------不同点1:初始全部放入队列,因为是检测整个图是否存在负环-------
    dis = [0]*(N+1) # 可以直接设置为0,反正有负环也会更新
    queue = [i for i in range(1,N+1)]
    heapq.heapify(queue)
    has = [1]*(N+1) 
    cnt = [0]*(N+1)
    
    while queue:
        cur = heapq.heappop(queue)
        has[cur] = 0
        for next, w in g[cur]:
            if dis[cur]+w < dis[next]:
                dis[next] = dis[cur]+w
                cnt[next] = cnt[cur]+1
                if cur[next] >= N: # 一共N个点,所以最多N-1条边。
                	return True
                if has[next] == 0: # 没有必要重复加入
                    heapq.heappush(queue, next)
                    has[next] = 1
    return False
ans = SPFA()
print(ans)

(2). 多源汇最短路

解释:多个起点,多个终点。从x号点,到y号点的最短距离。是可以处理重边,自环和负权边的。但是因为研究的是最短路问题,因此不能出现负环。

Folyd算法

  • Folyd算法: O ( n 3 ) O(n^3) O(n3) 基于动态规划,因此已经要牢记枚举,k i j的枚举顺序。
  • 存储方法采用邻接矩阵。
# N,M,Q分别为点的个数,边的个数,和查询的个数
# 注意这里考虑的是有向边
N,M,Q = map(int, input().split())
# 采用邻接矩阵进行存储
dis = [[float('inf')]*(N+1) for _ in range(N+1)]

for i in range(1,N+1):
    dis[i][i] = 0   
for i in range(M):
    a,b,c = map(int, input().split())
    dis[a][b] = min(dis[a][b], c)

def Foldy():
    for k in range(1, N+1):
        for i in range(1,N+1):
            for j in range(1, N+1):
                dis[i][j] = min(dis[i][j], dis[i][k]+dis[k][j])
    return 
    
Foldy()
for i in range(Q):
    a,b = map(int, input().split())
    if dis[a][b] == float('inf'):
        print(-1)
    else:
        print(dis[a][b])

4. 最小生成树

最小生成树的问题一般都是无向图,这个与最短路不太一样,最短路一般是有向图

Prim算法

基本思路:首先选择出来一个点,作为起点,然后依次更新每个点到集合距离,选择距离最小的那个点,将该条边加入最小生成树,然后以这个点更新各个未加入树的点到集合的距离。重复N次则有N个点加入。

  • 时间复杂度: O ( n 2 ) O(n^2) O(n2),适用于稠密图
N, M = map(int, input().split())

g = [[float('inf')]*(N+1) for _ in range(N+1)]
for i in range(M):
    a,b,c = map(int, input().split())
    g[a][b] = g[b][a] = min(g[a][b], c)
    
def Prim():
    res = 0
    dis = [float('inf')]*(N+1)
    has = [0]*(N+1)
    for i in range(N):
        t = -1
        for j in range(1, N+1):
            # 寻找到距离集合距离最短的边
            if has[j] == 0 and (t == -1 or dis[t]>dis[j]):
                t = j
        # 选出来的最小值是无穷,说明有一条边是到不了的。
        if i != 0 and dis[t] == float('inf'):
            return 'impossible'
        #  这里一定要注意,先累加,再更新,否则会错在自环上
        if i != 0: # 第一次只是选点,树不需要添加边
            res += dis[t]
        # 更新每个点到集合的距离 
        for j in range(1,N+1):
            dis[j] = min(dis[j], g[t][j])
        has[t] = 1
    return res
    
res = Prim()
print(res)
        
    

另外存在堆优化Prim算法,但是很少用。

  • 时间复杂度: O ( m l o g n ) O(mlogn) O(mlogn),适用于稀疏图,用的少

Kruskal算法

思路:1. 将所有边按照权重从小到大进行排序。2. 枚举每条边a, b, 权重c。只要不连通,就连接两个点,并加入边。

  • 时间复杂度: O ( m l o g m ) O(mlogm) O(mlogm),也是适用于稀疏图,首选。

图的构建可以随意,因为会遍历所有的边。

N,M = map(int, input().split())
g = []
for i in range(M):
    a,b,c = map(int, input().split())
    g.append([a,b,c])

fa = {}
def new(x):
    if x in fa:
        return 
    fa[x] = x
    
def find(x):
    if x == fa[x]:
        return x
    fa[x] = find(fa[x])
    return fa[x]

def union(a,b):
    pa = fa[a]
    pb = fa[b]
    fa[pa] = pb
    
def Kruskal():
    res = 0
    g.sort(key = lambda x:x[2])
    cnt = 0
    for i in range(M):
        a,b,c = g[i]
        new(a)
        new(b)
        pa = find(a)
        pb = find(b)
        if pa != pb:
            res += c
            union(a,b)
            cnt += 1
    if cnt < N-1:
        return 'impossible'
    return res
        
ans = Kruskal()
print(ans)

5. 二分图:染色法、匈牙利算法

如何判别是不是二分图

二分图的判断:当且仅当图中不存在奇数环。
方法:染色法。

  • 时间复杂度: O ( m + n ) O(m+n) O(m+n)
## 染色法判断二分图
N,M = map(int, input().split())
g = [[] for _ in range(N+1)]
color = [-1]*(N+1)

for i in range(M):
    a,b = map(int, input().split())
    g[a].append(b)
    g[b].append(a)
# --------------------------dfs思路-----------------------------------
def dfs(x,c):
    color[x] = c
    for ne in g[x]:
        if color[ne] == -1:
            if dfs(ne, 1-c) == False:
                return False
        elif color[ne] == c:
            return False
    return True
flag = 1
for i in range(1,N+1):
    if color[i] == -1:
        if dfs(i,0) == False:
            flag = 0
            break
if flag == 0:
    print('No')
else:
    print('yes')
#----------------------------bfs思路----------------------------------
   color = [-1]*(n)
   queue = collections.deque()
   queue.append(0)
   color[0] = 0
   for i in range(n):
       if color[i] == -1:
           queue.append(i)
           color[i] = 0
       while queue:
           cur = queue.popleft()
           now = color[cur]
           for i in g[cur]:
               if color[i] == -1:
                   color[i] = 1-now
                   queue.append(i)
               elif color[i] == now:
                   return False        

匈牙利算法

算法目的,在两个集合中,寻找到数量最多的一一匹配。
算法思路:考虑男生与女生配对的问题,依次考虑每个男生,去匹配每个女生,并考虑冲突的女生配对的男生是否存在别的可能。

  • 时间复杂度: O ( n m ) O(nm) O(nm),实际运行时间一般远小于 O ( n m ) O(nm) O(nm)
# 匈牙利算法,求解左右两个图的最大匹配度
# 输入a,b,c,分别是左半侧,右半侧的点和边的数量
n1, n2, m = map(int, input().split())
g = [[]for _ in range(n1+1)] # 只需要存储左边指向右边的边的个数
for _ in range(m):
    a,b = map(int, input().split())
    g[a].append(b)
has = set()  
# 存储当前girl已经匹配的对象
atch = [-1]*(n2+1)

def find(x):
    for c in g[x]: # 枚举目前的男生可以选择的全部女生
        if c not in has: # 每个女生只考虑一次,防止嵌套
            has.add(c)
            if match[c] == -1 or find(match[c]): # 如果当前女生还未被匹配,或匹配的男生可以修改
                match[c] = x
                return True
    return False # 只有一切可能都不行才返回False
            
res = 0
for i in range(1,n1+1):
    # -------注意:每次新的循环需要初始化girls的序列---
    has.clear()
    # 匹配成功就+1
    if find(i):
        res += 1
print(res)

你可能感兴趣的:(专题讲解)