回溯问题的解决核心是遍历决策树. 在进行决策时, 实际需要考虑的问题是:
1. 路径: 已经作出的选择;
2. 选择列表: 当前还可以作出的选择;
3. 结束条件: 到达决策树底层, 无法再继续作出选择的条件.
回溯问题的框架为:
result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.append(路径)
        return
    for 选择 in 选择列表:
        作出新选择  # 在原有决策的基础上作出下一步的选择
        backtrack(新路径, 新选择列表)  # 检查能否进一步做选择
        撤销新选择  # 如果不能进一步做选择的话则撤销无效的选择
下面以数独求解问题为例:
Q: 给定一个 $9 \times 9$ 的数独, 求出它的其中一个解.
程序输出:
def print_board(self, flag=0):
    """Write the board to stdout as a grid of 3x3 groups.

    Cells in a row are separated by a space, every third column is
    preceded by " | ", and every third row is preceded by a line of "=".
    The cell in column index 8 ends its line.

    Args:
        flag: when equal to 1, a "SOLUTION" banner is printed first.
    """
    if flag == 1:
        print("\n= = = = = SOLUTION = = = = = \n")
    for r in range(self.len):
        # Horizontal rule between bands of three rows (never before row 0).
        if r and r % 3 == 0:
            print("= = = = = = = = = = = =")
        for c in range(len(self.board[0])):
            # Vertical bar between stacks of three columns (never before column 0).
            if c and c % 3 == 0:
                print(" | ", end="")
            # Column 8 terminates the line; every other cell is followed by a space.
            print(self.board[r][c], end="\n" if c == 8 else " ")
def stats(self):
    """Store the current time in ``self.end`` and print the time since ``self.start``."""
    self.end = time.time()
    elapsed = self.end - self.start
    print("Time elapsed: {}s".format(elapsed))
程序逻辑:
def search_empty(self):
    """Locate the first empty cell, scanning row by row.

    Returns:
        A ``(row, column)`` tuple for the first cell holding 0, or ``None``
        when no cell within the ``self.len`` x ``self.len`` area is 0.
    """
    blanks = (
        (r, c)
        for r in range(self.len)
        for c in range(self.len)
        if self.board[r][c] == 0
    )
    # next() with a default yields None once the scan finds nothing.
    return next(blanks, None)
def find_valid_nums(self, pos):
# find all possible valid numbers for position 'pos'
num_list = list(range(1, 10))
for i in range(9):
if self.board[pos[0]][i] != 0 and self.board[pos[0]][i] in num_list:
num_list.remove(self.board[pos[0]][i])
if self.board[i][pos[1]] != 0 and self.board[i][pos[1]] in num_list:
num_list.remove(self.board[i][pos[1]])
x, y = pos[0] // 3, pos[1] // 3
for i in range(3*x, 3*x + 3):
for j in range(3*y, 3*y + 3):
if self.board[i][j] != 0 and self.board[i][j] in num_list:
num_list.remove(self.board[i][j])
return num_list
def validate(self, num, pos):
    """Check whether ``num`` can sit at ``pos`` without a Sudoku conflict.

    The cell at ``pos`` itself is ignored, so the check works both before
    and after ``num`` has been written to the board.

    Args:
        num: the digit being tested.
        pos: ``(row, column)`` pair; a tuple or a list is accepted.

    Returns:
        True when no other cell in the same row, column or 3x3 box holds
        ``num``; False otherwise.
    """
    # Unpack once so the comparisons below do not depend on pos being a tuple.
    row, col = pos

    # Same row, any other column.
    for c in range(len(self.board[row])):
        if c != col and self.board[row][c] == num:
            return False

    # Same column, any other row.
    for r in range(len(self.board)):
        if r != row and self.board[r][col] == num:
            return False

    # Same 3x3 box, any other cell.
    top, left = 3 * (row // 3), 3 * (col // 3)
    for i in range(top, top + 3):
        for j in range(left, left + 3):
            if (i, j) != (row, col) and self.board[i][j] == num:
                return False

    # No conflict found: an explicit True is required, because callers use
    # the result as a condition and a missing return would yield None.
    return True
def solution(self):
    """Fill the board in place by depth-first search with backtracking.

    Returns:
        True once ``search_empty`` reports no empty cell; False when none
        of the candidates for the current empty cell leads to a solution.
    """
    cell = self.search_empty()
    if not cell:
        # Nothing left to fill.
        return True
    row, col = cell
    for candidate in self.find_valid_nums([row, col]):
        if not self.validate(candidate, (row, col)):
            continue
        self.board[row][col] = candidate
        if self.solution():
            return True
        # Dead end further down: clear the cell and try the next candidate.
        self.board[row][col] = 0
    return False
程序整理总结如下:
import time
class Sudoku:
    """Backtracking (depth-first) solver for a 9x9 Sudoku board.

    The board is a list of nine lists of nine integers; 0 marks an empty
    cell. ``solution`` fills the board in place.
    """

    # Puzzle solved when no board is supplied; 0 marks an empty cell.
    DEFAULT_BOARD = (
        (7, 0, 0, 0, 3, 4, 8, 0, 0),
        (8, 0, 4, 6, 0, 0, 0, 0, 0),
        (0, 3, 9, 0, 5, 0, 0, 0, 0),
        (1, 0, 0, 5, 0, 0, 6, 0, 0),
        (0, 4, 0, 7, 0, 9, 0, 3, 0),
        (0, 0, 3, 0, 0, 8, 0, 0, 9),
        (0, 0, 0, 0, 7, 0, 3, 2, 0),
        (0, 2, 6, 0, 0, 1, 9, 0, 5),
        (0, 0, 7, 9, 2, 0, 0, 0, 4),
    )

    def __init__(self, board=None):
        """Set up the board and start the timer.

        Args:
            board: optional 9x9 grid (sequence of nine sequences of nine
                ints, 0 for empty). When omitted, ``DEFAULT_BOARD`` is used.

        Raises:
            ValueError: if the grid is not 9 rows of 9 cells.
        """
        source = self.DEFAULT_BOARD if board is None else board
        # Copy row by row so solving never mutates the caller's grid.
        self.board = [list(row) for row in source]
        self.len = len(self.board)
        # The candidate search below assumes 9 digits and 3x3 boxes.
        if self.len != 9 or any(len(row) != 9 for row in self.board):
            raise ValueError("board must be a 9x9 grid")
        self.start = time.time()
        self.end = 0

    def main(self):
        """Print the puzzle, solve it, then print the result and the timing."""
        self.print_board()
        if self.solution():
            self.print_board(1)
        else:
            # Do not present an unsolved board under a "SOLUTION" banner.
            print("No solution exists for this board.")
        self.stats()

    def search_empty(self):
        """Return ``(row, column)`` of the first empty cell, or None if full."""
        for r in range(self.len):
            for c in range(self.len):
                if self.board[r][c] == 0:
                    return r, c
        return None

    def print_board(self, flag=0):
        """Write the board to stdout as a grid of 3x3 groups.

        Args:
            flag: when equal to 1, a "SOLUTION" banner is printed first.
        """
        if flag == 1:
            print("\n= = = = = SOLUTION = = = = = \n")
        for r, row in enumerate(self.board):
            # Horizontal rule between bands of three rows.
            if r and r % 3 == 0:
                print("= = = = = = = = = = = =")
            last = len(row) - 1
            for c, value in enumerate(row):
                # Vertical bar between stacks of three columns.
                if c and c % 3 == 0:
                    print(" | ", end="")
                print(value, end="\n" if c == last else " ")

    def find_valid_nums(self, pos):
        """List the digits 1-9 unused in the row, column and box of ``pos``.

        Args:
            pos: indexable pair; ``pos[0]`` is the row, ``pos[1]`` the column.

        Returns:
            The remaining candidate digits in ascending order.
        """
        row, col = pos[0], pos[1]
        taken = set(self.board[row])
        taken.update(self.board[r][col] for r in range(self.len))
        top, left = 3 * (row // 3), 3 * (col // 3)
        for r in range(top, top + 3):
            taken.update(self.board[r][left:left + 3])
        return [digit for digit in range(1, 10) if digit not in taken]

    def validate(self, num, pos):
        """Check whether ``num`` can sit at ``pos`` without a conflict.

        The cell at ``pos`` itself is ignored.

        Args:
            num: the digit being tested.
            pos: ``(row, column)`` pair; a tuple or a list is accepted.

        Returns:
            True when no other cell in the same row, column or 3x3 box
            holds ``num``; False otherwise.
        """
        # Unpack so the box comparison does not depend on pos being a tuple.
        row, col = pos
        for c in range(self.len):
            if c != col and self.board[row][c] == num:
                return False
        for r in range(self.len):
            if r != row and self.board[r][col] == num:
                return False
        top, left = 3 * (row // 3), 3 * (col // 3)
        for i in range(top, top + 3):
            for j in range(left, left + 3):
                if (i, j) != (row, col) and self.board[i][j] == num:
                    return False
        return True

    def solution(self):
        """Fill the board in place by depth-first search with backtracking.

        Returns:
            True when every cell is filled; False when the board cannot be
            completed (the cells tried are reset to 0).
        """
        cell = self.search_empty()
        if not cell:
            return True
        row, col = cell
        for candidate in self.find_valid_nums([row, col]):
            if not self.validate(candidate, (row, col)):
                continue
            self.board[row][col] = candidate
            if self.solution():
                return True
            # Dead end further down: clear the cell and try the next digit.
            self.board[row][col] = 0
        return False

    def stats(self):
        """Store the current time in ``self.end`` and print the elapsed time."""
        self.end = time.time()
        print("Time elapsed: {}s".format(self.end - self.start))
# Script entry point: solve the built-in puzzle and print the result.
if __name__ == "__main__":
    sud = Sudoku()
    sud.main()