回溯框架总结

什么是回溯算法

其实回溯算法和我们常说的 DFS 算法非常类似,本质上就是一种暴力穷举算法。回溯算法和 DFS 算法的细微差别是:回溯算法是在遍历「树枝」,DFS 算法是在遍历「节点」,本文就是简单提一下,等你看到后文图论算法基础 时就能深刻理解这句话的含义了。
废话不多说,直接上回溯算法框架,解决一个回溯问题,实际上就是一个决策树的遍历过程,站在回溯树的一个节点上,你只需要思考 3 个问题:
1、路径:也就是已经做出的选择。
2、选择列表:也就是你当前可以做的选择。
3、结束条件:也就是到达决策树底层,无法再做选择的条件。

代码方面,回溯算法的框架:

result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(路径)
        return
    for 选择 in 选择列表:
    	
    	if(illegal()):(剪枝函数,如果当前选择列表当中的元素不合法,就跳过本轮循环)
    		continue;
        做选择
        backtrack(路径, 选择列表)
        撤销选择

其核心就是 for 循环里面的递归,在递归调用之前「做选择」,在递归调用之后「撤销选择」,特别简单。

什么问题可以用回溯算法

回溯算法多用于解决排列组合问题,遇到问题可以先尝试画递归树,如果问题能够用递归树的形式进行描述,那么就可以用回溯法来解决。
值得一提的是,回溯算法还经常被用来遍历二维数组(岛屿问题、矩阵最长递增路径),如果你把二维矩阵中的每一个位置看做一个节点,这个节点的上下左右四个位置就是相邻节点,那么整个矩阵就可以抽象成一个树结构,注意,这里的上是原路返回到父节点,这与普通的树是不同的,二维数组的遍历框架如下:

// 方向数组,分别代表上、下、左、右
int[][] dirs = new int[][]{{-1,0}, {1,0}, {0,-1}, {0,1}};

void dfs(int[][] grid, int i, int j, boolean[][] visited) {
    int m = grid.length, n = grid[0].length;
    if (i < 0 || j < 0 || i >= m || j >= n) {
        // 超出索引边界
        return;
    }
    if (visited[i][j]) {
        // 已遍历过 (i, j)
        return;
    }

    // 进入节点 (i, j)
    visited[i][j] = true;
    // 递归遍历上下左右的节点
    for (int[] d : dirs) {
        int next_i = i + d[0];
        int next_j = j + d[1];
        dfs(grid, next_i, next_j, visited);
    }
    // 离开节点 (i, j)
}

总的来说,回溯算法能够解决的问题主要有两类,能够转换成树结构的问题和二维数组的遍历问题。
代表问题有以下几种:

全排列问题

无重复数字的全排列

class Solution:
    def permute(self , num: List[int]) -> List[List[int]]:
        # write code here
        used=[False]*len(num)
        res=[]
        temp=[]
        def backtrack():
            if len(temp)==len(num):
                res.append(temp[:])
                return
            for i in range(len(num)):
                if used[i]:
                    continue
                temp.append(num[i])
                used[i]=True
                backtrack()
                temp.pop(-1)
                used[i]=False
        backtrack()
        return res

有重复数字的全排列

class Solution:
    def permuteUnique(self , num: List[int]) -> List[List[int]]:
        # write code here
        num.sort()
        used=[False]*len(num)
        res=[]
        temp=[]
        def backtrack():
            if len(temp)==len(num):
                res.append(temp[:])
                return
            for i in range(len(num)):
                if used[i]:
                    continue
                if i>0 and num[i]==num[i-1] and used[i-1]:
                    continue
                temp.append(num[i])
                used[i]=True
                backtrack()
                temp.pop(-1)
                used[i]=False
        backtrack()
        return res

括号生成问题

class Solution:
    def generateParenthesis(self , n: int) -> List[str]:
        # write code here
        left=n
        right=n
        res=[]
        temp=[]
        def backtrack(left,right):
            if right<left:
                return
            if right<0 or left<0:
                return
            if left==0 and right==0:
                s="".join(temp)
                res.append(s)
            temp.append("(")
            backtrack(left-1,right)
            temp.pop(-1)

            temp.append(")")
            backtrack(left,right-1)
            temp.pop(-1)
        backtrack(left,right)
        return res

N皇后问题

class Solution:
    def Nqueen(self , n: int) -> int:
        # write code here
        matrix = [["." for _ in range(n)] for _ in range(n)]
        res = 0
        def check(r, c, matrix):
            for i in range(r):
                if matrix[i][c] == "Q":
                    return False
            i, j = r, c
            while i > 0 and j > 0:
                if matrix[i - 1][j - 1] == "Q":
                    return False
                i -= 1
                j -= 1
            i, j = r, c
            while i > 0 and j < n - 1:
                if matrix[i - 1][j + 1] == "Q":
                    return False
                i -= 1
                j += 1
            return True
        def dfs(r):
            nonlocal res, matrix
            if r == n:
                res += 1
                return
            for i in range(n):
                if check(r, i, matrix):
                    matrix[r][i] = "Q"
                    dfs(r + 1)
                    matrix[r][i] = "."
        dfs(0)
        return res

矩阵路径问题

矩阵最长递增路径

class Solution:
    global dirs
    #记录四个方向
    dirs = [[-1, 0], [1, 0], [0, -1], [0, 1]] 
    global n, m
    #深度优先搜索,返回最大单元格数
    def dfs(self, matrix:List[List[int]], dp: List[List[int]], i:int, j:int) :
        if dp[i][j] != 0:
            return dp[i][j]
        dp[i][j] += 1
        for k in range(4):
            nexti = i + dirs[k][0]
            nextj = j + dirs[k][1]
            #判断条件
            if  nexti >= 0 and nexti < n and nextj >= 0 and nextj < m and matrix[nexti][nextj] > matrix[i][j]:
                dp[i][j] = max(dp[i][j], self.dfs(matrix, dp, nexti, nextj) + 1)
        return dp[i][j]
    
    def solve(self , matrix: List[List[int]]) -> int:
        global n,m
        #矩阵不为空
        if len(matrix) == 0 or len(matrix[0]) == 0:
            return 0
        res = 0
        n = len(matrix)
        m = len(matrix[0])
        #i,j处的单元格拥有的最长递增路径
        dp = [[0 for col in range(m)] for row in range(n)]  
        for i in range(n):
            for j in range(m):
                #更新最大值
                res = max(res, self.dfs(matrix, dp, i, j)) 
        return res

岛屿问题

class Solution:
    def solve(self , grid: List[List[str]]) -> int:
        # write code here
        def dfs(i,j):
            if i<0 or j<0 or i>len(grid)-1 or j >len(grid[0])-1:
                return
            if grid[i][j]=="0":
                return
            grid[i][j]="0"
            dfs(i+1,j)
            dfs(i,j+1)
            dfs(i-1,j)
            dfs(i,j-1)
        res=0
        for i in range(len(grid)):
            for j in range(len(grid[0])):
                if grid[i][j]=="1":
                    res+=1
                    dfs(i,j)
        return res

你可能感兴趣的:(力扣,算法)