python 回溯算法总结

python 回溯算法

  • 回溯算法理论基础
  • 组合
  • 组合总数III
  • 电话号码的字母组合
  • 组合总和
  • 组合总和ii
  • 分割回文串
  • 复原IP地址
  • 子集问题
  • 子集问题II
  • 递增序列
  • 全排列
  • 全排列II
  • 重新安排行程
  • N皇后
  • 解数独


回溯算法理论基础

回溯算法解决的问题都可以抽象为树形结构(N叉树),用树形结构来理解回溯会容易很多。

回溯法一般可以解决如下几种问题:

  • 组合问题:N个数里面按一定规则找出K个数的集合
  • 切割问题:一个字符串按一定的规则有几种切割方式
  • 子集问题:一个N个数的集合里有多少符合条件的子集
  • 排列问题:N个数按一定规则全排列,有几种排列方式
  • 棋盘问题:N皇后,解数独等

回溯算法模板:

  • 回溯算法的返回值和参数
    回溯算法中返回值一般为空,参数更具后续的逻辑来,先写逻辑,再写参数,需要什么参数就填什么参数

  • 回溯终止条件:存放结果

  • 回溯的遍历过程:处理节点, for(调用自己, 回溯(递归),撤销处理结果)。for为横向遍历,递归为纵向遍历

组合

# 不剪枝
class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        self.res = []  # 存放符合条件结果的集合
        self.combine = []  # 用来存放符合条件结果
        self.backtrack(n, k , 1)
        return self.res
    
    def backtrack(self, n, k, start_index):
        if len(self.combine) == k:
            # 这里要append的是self.combine里面的值而不是self.combine,因为self.combine在循环里被pop()最后都是空的
            self.res.append(self.combine[:])
            return
        
        for i in range(start_index, n + 1):
            self.combine.append(i)  # 处理节点
            self.backtrack(n, k , i + 1)  # 递归
            self.combine.pop()  # 回溯,撤销处理的节点
# 剪枝
class Solution:
    def combine(self, n: int, k: int) -> List[List[int]]:
        self.res = []
        self.combine = []
        self.backtrack(n, k , 1)
        return self.res
    
    def backtrack(self, n, k, start_index):
        if len(self.combine) == k:
            # 这里要append的是self.combine里面的值而不是self.combine,因为self.combine在循环里被pop()最后都是空的
            self.res.append(self.combine[:])
            return 
        
        for i in range(start_index, n - (k - len(self.combine)) + 2):  # 剪枝
            self.combine.append(i)
            self.backtrack(n, k , i + 1)
            self.combine.pop()

组合总数III

class Solution:
    def combinationSum3(self, k: int, n: int) -> List[List[int]]:
        self.res = []
        self.combine = []
        self.backtrack(n, k, 1)
        return self.res

    def backtrack(self, n, k, start_index):
        if sum(self.combine) > n:
            return
        # 这里把计算sum放前面可能会超时,因为每个都要计算,所以先计算长度
        if len(self.combine) == k and sum(self.combine) == n:
            self.res.append(self.combine[:])
            return
        
        for i in range(start_index, 9 - (k - len(self.combine)) + 1 + 1):
            self.combine.append(i)
            self.backtrack(n, k, i + 1)
            self.combine.pop()

电话号码的字母组合

class Solution:
    def letterCombinations(self, digits: str) -> List[str]:
        if not digits: return []
        self.res = []
        self.answer = ''
        self.number = list(digits)
        self.map = {'2': 'abc',
                    '3': 'def',
                    '4': 'ghi',
                    '5': 'jkl',
                    '6': 'mno',
                    '7': 'pqrs',
                    '8': 'tuv',
                    '9': 'wxyz'}
        self.backtrack(0)
        return self.res
        
    def backtrack(self, index):
        if index == len(self.number):
            self.res.append(self.answer[:])
            return
        
        letters = self.map[self.number[index]]
        for letter in letters:
            self.answer += letter
            self.backtrack(index + 1)
            self.answer = self.answer[:-1]

组合总和

class Solution:
    def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:
        if target < all(_ for _ in candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        self.backtrack(candidates, target, 0)
        return self.result

    
    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        
        for i in range(start_index, self.len):
            self.answer.append(candidates[i])
            self.sum += candidates[i]
            # 这里你可以从该点选无数次,但是之后要从后面选,不能选了后面的之后又从前面选,这样就会重复
            # 如果不从i开始,会有重复的answer出现
            self.backtrack(candidates, target, i)
            self.sum -= candidates[i]
            self.answer.pop()

组合总和ii

class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        if target < all(_ for _ in candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        # 提前排序
        candidates.sort()
        self.backtrack(candidates, target, 0)
        return self.result

    
    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        
        for i in range(start_index, self.len):
            # 不能有相同的组合,意味着同一层不能使用相同的元素
            # 跳过同一层中已经使用过的元素
            if i > start_index and candidates[i] == candidates[i - 1]:
                continue

            self.answer.append(candidates[i])
            self.sum += candidates[i]
            self.backtrack(candidates, target, i + 1)
            self.sum -= candidates[i]
            self.answer.pop()
# 用used去重
class Solution:
    def combinationSum2(self, candidates: List[int], target: int) -> List[List[int]]:
        if target < all(_ for _ in candidates):
            return []
        self.len = len(candidates)
        self.result = []
        self.answer = []
        self.sum = 0
        self.used = [False] * len(candidates)
        candidates.sort()
        self.backtrack(candidates, target, 0)
        return self.result

    
    def backtrack(self, candidates, target, start_index):
        if self.sum == target:
            self.result.append(self.answer[:])
            return
        if self.sum > target:
            return
        
        for i in range(start_index, self.len):
            # 检查同一树层是否出现曾经使用过的相同元素
            # 若数组中前后元素值相同,但前者却未被使用(used == False),说明是for loop中的同一树层的相同元素情况
            if i > 0 and candidates[i] == candidates[i - 1] and self.used[i - 1] == False:
                continue

            self.answer.append(candidates[i])
            self.sum += candidates[i]
            self.used[i] = True
            self.backtrack(candidates, target, i + 1)
            self.used[i] = False
            self.sum -= candidates[i]
            self.answer.pop()

分割回文串

class Solution:
    def partition(self, s: str) -> List[List[str]]:
        self.res = []
        self.answer = []
        self.backtrack(s, 0)
        return self.res
    
    def backtrack(self, s, start_index):
    # 这里只answer只存回文,如果不是回文就continue,所有当start_index走到最后,answer里的都是回文,而且已经全部切割完成了
        if start_index >= len(s):
            self.res.append(self.answer[:])
            return 
        
        for i in range(start_index, len(s)):
            if self.is_palindrome(s, start_index, i):
                self.answer.append(s[start_index: i + 1])
                self.backtrack(s, i + 1)
                self.answer.pop()
            else:
                continue


    def is_palindrome(self, s, start, end):
        left = start
        right = end
        while left < right:
            if s[left] != s[right]:
                return False
            left += 1
            right -= 1
        return True

复原IP地址

class Solution:
    def restoreIpAddresses(self, s):
        if len(s) > 16:
            return []
        self.res = []
        self.answer = ''
        self.backtrack(s, 0, 0)
        return self.res

    def backtrack(self, s, start_index, times):
        if times == 3:
            if len(s[start_index:]) > 3:
                return
            self.answer += s[start_index:]
            if self.is_vaild_ip(self.answer):
                self.res.append(self.answer[:])
            return

        for i in range(start_index, min(len(s), start_index+3)):
            self.answer += s[start_index: i + 1] + '.'
            self.backtrack(s, i + 1, times + 1)
            self.answer = self.answer[: start_index + times]


    def is_vaild_ip(self, s):
        ips = list(s.split('.'))
        for ip in ips:
            if ip:
                if int(ip) > 255:
                    return False
                if len(ip) != 1 and ip[0] == '0':
                    return False
            else:
                return False
        return True

子集问题

class Solution:
    def subsets(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums, 0)
        return self.res
    
    def backtrack(self, nums, start_index):
        self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        
        for i in range(start_index, len(nums)):
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()

子集问题II

class Solution:
    def subsetsWithDup(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        nums.sort()
        self.backtrack(nums, 0)
        return self.res

    def backtrack(self, nums, start_index):
        self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        
        for i in range(start_index, len(nums)):
            if i > start_index and nums[i] == nums[i - 1]:
                continue
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()

递增序列

用useg去重

class Solution:
    def findSubsequences(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums, 0)
        return self.res

    def backtrack(self, nums, start_index):
        if len(self.answer) > 1:
            self.res.append(self.answer[:])
        if start_index == len(nums):
            return
        
        useg = set()
        
        for i in range(start_index, len(nums)):
            if self.answer and nums[i] < self.answer[-1] or nums[i] in useg:
                continue
            useg.add(nums[i])
            self.answer.append(nums[i])
            self.backtrack(nums, i + 1)
            self.answer.pop()

全排列

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if not nums:
            self.res.append(self.answer[:])
        
        for i in range(len(nums)):
            self.answer.append(nums[i])
            _pop = nums.pop(i)
            self.backtrack(nums)
            nums.insert(i, _pop)
            self.answer.pop()

最优:

class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])
        
        for i in range(len(nums)):
            if nums[i] in self.answer:
                continue
            self.answer.append(nums[i])
            self.backtrack(nums)
            self.answer.pop()
class Solution:
    def permute(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        self.useg = []
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])

        for i in range(len(nums)):
            if nums[i] in self.useg:
                continue
            self.useg.append(nums[i])
            self.answer.append(nums[i])
            self.backtrack(nums)
            self.answer.pop()
            self.useg.pop()

全排列II

class Solution:
    def permuteUnique(self, nums: List[int]) -> List[List[int]]:
        self.res = []
        self.answer = []
        nums.sort()
        self.used = [0] * len(nums)
        self.backtrack(nums)
        return self.res

    def backtrack(self, nums):
        if len(nums) == len(self.answer):
            self.res.append(self.answer[:])
            return

        for i in range(len(nums)):
            if not self.used[i]:
            	# 这里的意思是每种值按顺序选取,但是选取值的顺序随机
                if i > 0 and nums[i] == nums[i - 1] and not self.used[i - 1]:
                    continue
                self.used[i] = 1
                self.answer.append(nums[i])
                self.backtrack(nums)
                self.answer.pop()
                self.used[i] = 0

重新安排行程

from collections import defaultdict

class Solution:
    def findItinerary(self, tickets: List[List[str]]) -> List[str]:
        self.path = ['JFK']
        self.len_tickets = len(tickets)
        self.ticket_dict = defaultdict(list)
        for _, ticket in enumerate(tickets):
            self.ticket_dict[ticket[0]].append(ticket[1])
        self.Travel('JFK')
        return self.path
    
    def Travel(self, start):
        if len(self.path) == self.len_tickets + 1:
            return True
        
        self.ticket_dict[start].sort()
        for _ in self.ticket_dict[start]:
            end = self.ticket_dict[start].pop(0)
            self.path.append(end)
            if self.Travel(end):  # 这里找到一个就返回 path, 设计的就很巧妙
                return True
            self.path.pop()
            self.ticket_dict[start].append(end)

N皇后

class Solution:
    def solveNQueens(self, n: int) -> List[List[str]]:
        self.n = n
        self.result = []
        # 这里有个问题:不能写成[['.'] * n] * n, 不然后面赋值的时候会出问题
        self.chess_table = [['.'] * n for _ in range(n)]
        self.backtrack(0)
        return self.result

    def backtrack(self, row):
        if row == self.n:
            answer = []
            for item in self.chess_table:
                answer.append(''.join(item))
            self.result.append(answer[:])
            return

        for col in range(self.n):
            if not self.is_vaild(row, col):
                continue
            self.chess_table[row][col] = 'Q'
            self.backtrack(row + 1)
            self.chess_table[row][col] ='.'
                
                
    def is_vaild(self, row, col):
        # 列
        for i in range(self.n):
            if self.chess_table[i][col] == 'Q':
                return False
            
        # 左上角
        i, j = row - 1, col - 1
        while i >=0 and j >=0 :
            if self.chess_table[i][j] == 'Q':
                return False
            i -= 1
            j -= 1
        # 右上角
        i, j = row - 1, col + 1
        while i >=0 and j < self.n:
            if self.chess_table[i][j] == 'Q':
                return False
            i -= 1
            j += 1

        return True

解数独

class Solution:
    def solveSudoku(self, board: List[List[str]]) -> None:
        """
        Do not return anything, modify board in-place instead.
        """
        self.backtrack(board)
  
    def backtrack(self, board):
        for i in range(9):
            for j in range(9):
                if board[i][j] != '.':
                    continue
                for num in range(1, 10):
                    if self.is_vaild(i, j, num, board):
                        board[i][j] = str(num)
                        if self.backtrack(board): return True
                        board[i][j] = '.'
                # 若1到9填入都没有用,则无解
                return False
        # 如果走到了最后一个就返回 True
        return True
        
    def is_vaild(self, row, col, num, board):
        for i in range(9):
            if board[row][i] == str(num):
                return False
        for i in range(9): 
            if board[i][col] == str(num):
                return False

        space_row = row // 3
        space_col = col // 3
        for i in range(space_row * 3, (space_row + 1) * 3):
            for j in range(space_col * 3, (space_col + 1) * 3):
                if board[i][j] == str(num):
                    return False  
        return True

你可能感兴趣的:(数据结构)