回溯算法:以数独求解为例

回溯算法: 以数独求解为例

回溯问题的解决核心是遍历决策树. 在进行决策时, 实际需要考虑的问题是:

  1. 路径: 目前已经做的决策是什么?
  2. 选择列表: 还可以做那些选择?
  3. 结束条件: 何时算作 “到达决策树底端”, 问题解决?

回溯问题的框架为:

result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(新路径)
        return
    
    for 选择 in 选择列表:
        作出新选择      # 在原有决策的基础上作出下一步的选择

        backtrack(新路径, 新选择列表)       #检查能否进一步做选择
        
        撤销新选择      # 如果不能进一步做选择的话则撤销无效的选择

下面以数独求解问题为例:

Q: 给定一个 9 ∗ 9 9*9 99 的数独, 求出它的其中一个解.

程序输出:

  1. 我们需要打印待求解的数独和数独的一个解. 鉴于我们求解过程中修改是在数独的基础上进行的, 因此两个情况的代码可以复用:
    def print_board(self, flag=0):
        # Print the sudoku board

        if flag == 1:
            print("\n= = = = = SOLUTION = = = = = \n")
        for r in range(self.len):
            if r % 3 == 0 and r != 0:
                print("= = = = = = = = = = = =")
            for c in range(len(self.board[0])):
                if c % 3 == 0 and c != 0:
                    print(" | ", end="")
                if c == 8:
                    print(self.board[r][c])
                else:
                    print(str(self.board[r][c]) + " ", end="")
  1. 除此以外, 我们还希望得知程序的运行时间以评价执行效率:
    def stats(self):
        self.end = time.time()
        print("Time elapsed: {}s".format(self.end - self.start))

程序逻辑:

  1. 要求解数独, 我们首先需要找到待填的空位:
    def search_empty(self):
        # Search for empty slots

        for r in range(self.len):
            for c in range(self.len):
                if self.board[r][c] == 0:
                    return r, c   # return row, column

        return None
  1. 要定义回溯函数, 我们还需要确定选择列表:
    def find_valid_nums(self, pos):
        # find all possible valid numbers for position 'pos'

        num_list = list(range(1, 10))
        for i in range(9):
            if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
                num_list.remove(self.board[pos[0]][i])
            if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
                num_list.remove(self.board[i][pos[1]])
        x, y = pos[0] // 3, pos[1] // 3
        for i in range(3*x, 3*x + 3):
            for j in range(3*y, 3*y + 3):
                if self.board[i][j] != 0 and self.board[i][j] in num_list:
                    num_list.remove(self.board[i][j])
        return num_list
  1. 此外, 我们还需要依据数独规则判断作出的选择是否有效:
    def validate(self, num, pos):
        # Validate the selected choice

        for c in range(len(self.board[0])):
            if self.board[pos[0]][c] == num and pos[1] != c:
                return False
        for r in range(len(self.board[0])):
            if self.board[r][pos[1]] == num and pos[0] != r:
                return False
        x = pos[0] // 3
        y = pos[1] // 3
        for i in range(x*3, x*3 + 3):
            for j in range(y*3, y*3 + 3):
                if self.board[i][j] == num and (i,j) != pos:
                    return False
  1. 最后, 我们可以定义回溯主函数:
    def solution(self):
        # Solution, using DFS Algorithm

        find = self.search_empty()
        if not find:
            return True
        else:
            row, col = find
        for i in self.find_valid_nums([row, col]):
            if self.validate(i, (row, col)):
                self.board[row][col] = i

                if self.solution():
                    return True

                self.board[row][col] = 0

        return False

程序整理总结如下:

import time


class Sudoku:
    """Define Sudoku Class, and work out its solution"""
    def __init__(self):
        # Initialize

        self.board = [
            [7, 0, 0, 0, 3, 4, 8, 0, 0],
            [8, 0, 4, 6, 0, 0, 0, 0, 0],
            [0, 3, 9, 0, 5, 0, 0, 0, 0],
            [1, 0, 0, 5, 0, 0, 6, 0, 0],
            [0, 4, 0, 7, 0, 9, 0, 3, 0],
            [0, 0, 3, 0, 0, 8, 0, 0, 9],
            [0, 0, 0, 0, 7, 0, 3, 2, 0],
            [0, 2, 6, 0, 0, 1, 9, 0, 5],
            [0, 0, 7, 9, 2, 0, 0, 0, 4]
        ]
        self.len = len(self.board)
        self.start = time.time()
        self.end = 0

    def main(self):
        # Construct Main Program, maintain outputs

        self.print_board()
        self.solution()
        self.print_board(1)
        self.stats()

    def search_empty(self):
        # Search for empty slots

        for r in range(self.len):
            for c in range(self.len):
                if self.board[r][c] == 0:
                    return r, c   # return row, column

        return None

    def print_board(self, flag=0):
        # Print the sudoku board
        if flag == 1:
            print("\n= = = = = SOLUTION = = = = = \n")
        for r in range(self.len):
            if r % 3 == 0 and r != 0:
                print("= = = = = = = = = = = =")
            for c in range(len(self.board[0])):
                if c % 3 == 0 and c != 0:
                    print(" | ", end="")
                if c == 8:
                    print(self.board[r][c])
                else:
                    print(str(self.board[r][c]) + " ", end="")

    def find_valid_nums(self, pos):
        # find all possible valid numbers for position 'pos'

        num_list = list(range(1, 10))
        for i in range(9):
            if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
                num_list.remove(self.board[pos[0]][i])
            if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
                num_list.remove(self.board[i][pos[1]])
        x, y = pos[0] // 3, pos[1] // 3
        for i in range(3*x, 3*x + 3):
            for j in range(3*y, 3*y + 3):
                if self.board[i][j] != 0 and self.board[i][j] in num_list:
                    num_list.remove(self.board[i][j])
        return num_list

    def validate(self, num, pos):
        # Validate the selected choice

        for c in range(len(self.board[0])):
            if self.board[pos[0]][c] == num and pos[1] != c:
                return False
        for r in range(len(self.board[0])):
            if self.board[r][pos[1]] == num and pos[0] != r:
                return False
        x = pos[0] // 3
        y = pos[1] // 3
        for i in range(x*3, x*3 + 3):
            for j in range(y*3, y*3 + 3):
                if self.board[i][j] == num and (i,j) != pos:
                    return False

        return True

    def solution(self):
        # Solution, using DFS Algorithm

        find = self.search_empty()
        if not find:
            return True
        else:
            row, col = find
        for i in self.find_valid_nums([row, col]):
            if self.validate(i, (row, col)):
                self.board[row][col] = i

                if self.solution():
                    return True

                self.board[row][col] = 0

        return False

    def stats(self):
        self.end = time.time()
        print("Time elapsed: {}s".format(self.end - self.start))


if __name__ == "__main__":
    sud = Sudoku()
    sud.main()

你可能感兴趣的:(算法)