166 数独(dfs之剪枝)

1. 问题描述:

数独是一种传统益智游戏,你需要把一个 9×9 的数独补充完整,使得图中每行、每列、每个 3×3 的九宫格内数字 1∼9 均恰好出现一次。请编写一个程序填写数独。

输入格式

输入包含多组测试用例。每个测试用例占一行,包含 81 个字符,代表数独的 81 个格内数据(顺序总体由上到下,同行由左到右)。每个字符都是一个数字(1−9)或一个 .(表示尚未填充)。您可以假设输入中的每个谜题都只有一个解决方案。文件结尾处为包含单词 end 的单行,表示输入结束。

输出格式

每个测试用例,输出一行数据,代表填充完全后的数独。

输入样例:

4.....8.5.3..........7......2.....6.....8.4......1.......6.3.7.5..2.....1.4......
......52..8.4......3...9...5.1...6..2..7........3.....6...1..........7.4.......3.
end

输出样例:

417369825632158947958724316825437169791586432346912758289643571573291684164875293
416837529982465371735129468571298643293746185864351297647913852359682714128574936
来源:https://www.acwing.com/problem/content/description/168/

2. 思路分析:

1. 数独问题是一道很经典的dfs搜索问题,对于dfs需要搜索所有方案的问题,首先需要考虑的问题是如何搜索才可以将所有的方案枚举出来,也即需要考虑枚举的顺序将所有方案枚举出来,比较容易想到的是我们可以任意选择一些没有填的格子,然后枚举可以填哪些数字,这样枚举一定可以将所有方案枚举出来,第二个问题是dfs的剪枝优化,对于这道题目其实有三个优化:第一个优化:优化选择顺序,之前是随意选择一些还没有填的格子,其实我们可以选择一些分支数量较少的格子进行枚举,也即先选择一些可以选择填的数字比较少的空格子进行枚举;第二个优化:可行性剪枝,当前选择的数字与当前的行,列,以及对应的九宫格是没有重复的;第三个优化:虽然这道题目使用到的剪枝优化比较少,但是由于每一行,列以及九宫格都对应一个状态,所以我们可以使用二进制状态进行优化,也即可以使用位运算进行优化。可以声明row,col,cell来记录每一行,每一列,每一个九宫格对应的二进制状态,其中cell为3 * 3的二维数组或者二维列表,cell[i][j]表示(x,y)对应在9 * 9的数独中的某一个九宫格状态(i = x / 3,j = y / 3),后面在递归的时候可以通过(x / 3, y / 3) 找到数独中九宫格的位置,一开始的时候他们的初始状态都是111111111(2),第i位为1表示数字i是可以使用的,为0表示不可以使用,所以只有当行,列,九宫格的第i位都是1的时候说明数字i是可以使用的,可以使用位运算中的与运算获取对应的状态,在输入数独的时候初始化对应的行,列,九宫格对应的状态。一开始的时候可以预处理ones和map,ones存储(0, 1<< n)中每一个二进制状态对应的1的个数,map映射1 << i位的数字对应的数字为i。

2. 在dfs搜索的时候首先找到分支数量较少还没有填数字的二维坐标(x,y),分支数量较少表示对应的行,列以及九宫格相与的二进制状态对应的1的数目较少,这样在搜索的时候当前位置可以选择填的数字就越少;找到分支数量较少的二维坐标(x,y)之后计算出行,列,九宫格相与的二进制状态,二进制状态中的1表示可以填当前的第i位数字,这里可以使用lowbit函数进行优化,找到当前二进制状态最低位的1对应的数字(1000100,当前最低位的1对应的数字为100,也即4),然后使用之前预处理的map计算出对应可以填的数字是几,枚举所有可以填的数字那么递归即可,在递归之前需要将行,列以及对应的九宫格状态减去对应的数字,这样对应位的1就会变成0,这样使用完当前的数字之后往下递归的时候当前的数字就不可以填了,可以使用一个draw方法,传入一个flag标记表示当前是填数字还是回溯的过程,如果是回溯那么则需要减去对应的相反数。

3. 代码如下:

python代码(超时):python语言有的测试数据时间长的需要大概计算6s左右

from typing import List


class Solution:
    res = ""

    # flag为是将某个位置填某个数字还是回溯的标记
    def draw(self, x: int, y: int, t: int, row: List[int], col: List[int], cell: List[List[int]], g: List[List[str]],
             flag: int):
        if flag == 1:
            g[x][y] = str(t + 1)
        else:
            g[x][y] = "."
        # 左移对应的位数那么相当于将9位二进制中的某一位置为1
        v = 1 << int(t)
        # flag = 0表示回溯, 所以要取反
        if flag == 0: v = -v
        # 如果某一个位置填当前的数字那么应该将对应的行, 列, 以及所在的九宫格的数字减去v, 回溯的时候相当于是恢复为之前的数字, 相当于加上v, 减去对应的数字表示对应的数字不可用
        row[x] -= v
        col[y] -= v
        # x // 3, y // 3表示二维坐标中对应的是哪一个九宫格
        cell[x // 3][y // 3] -= v

    # 获取第x行, y列, 以及当前行与列对应的九宫格相与的状态, 只有当三种状态对应的位置都是1的时候才可以使用对应位置上的数字
    def get(self, x: int, y: int, row: List[int], col: List[int], cell: List[List[int]]):
        return row[x] & col[y] & cell[x // 3][y // 3]

    # 获取x的最低位1对应的数字, 这样在后面枚举的时候不用每一次循环的时候都枚举9次, 而是有多少个1枚举多少次
    def lowbit(self, x: int):
        return x & -x

    def dfs(self, count: int, n: int, g: List[List[str]], ones: dict, map: dict, row: List[int], col: List[int],
            cell: List[List[int]]):
        if count == 0:
            # 拼接答案返回
            for i in range(n):
                self.res += "".join(g[i])
            return True
        # 枚举整个数独找到分支数量最小的对应的位置, 分支数量较小对应的三种状态相与的数字越小
        minv = 10
        # 找到当前分支数量最少的位置
        x = y = 0
        for i in range(n):
            for j in range(n):
                if g[i][j] == ".":
                    # get方法得到的是三种状态相与的数字对应每一位二进制的情况, 这个时候使用之前预处理的ones计算出相与的状态中1的数目个数, 对应的位就是可以使用的数字
                    t = self.get(i, j, row, col, cell)
                    if ones[t] < minv:
                        minv = ones[t]
                        x, y = i, j
        state = self.get(x, y, row, col, cell)
        # lowbit函数优化, 这样每一次可以不用枚举9次, lowbit函数每一次找到最低位1对应的数字
        while state > 0:
            k = self.lowbit(state)
            state -= k
            # map[k]表示当前获取的最低位1对应的数字是几
            t = map[k]
            self.draw(x, y, t, row, col, cell, g, 1)
            # 只要找到了答案那么直接返回就不用往下继续搜索了
            if self.dfs(count - 1, n, g, ones, map, row, col, cell): return True
            self.draw(x, y, t, row, col, cell, g, 0)
        return False

    # 使用位运算进行优化效率会高一点但是比较麻烦
    def process(self):
        n = 9
        # 先预处理两个字典ones, map, ones表示二进制中1的数目, map将对应二进制中的1映射为整数(声明为字典类型比较好处理)
        ones, map = dict(), dict()
        for i in range(1 << n):
            ones[i] = 0
            for j in range(n):
                # 计算每一个状态中二进制的1的数目
                ones[i] += i >> j & 1
        for i in range(n):
            map[1 << i] = i
        while True:
            s = input()
            if s[0] == "e": break
            # row, col, cell分别表示每一行/列/九宫格的状态
            row, col, cell = [(1 << n) - 1] * n, [(1 << n) - 1] * n, [[(1 << n) - 1] * 3 for i in range(3)]
            # count用来计算数独中未填的数字个数
            k = count = 0
            # 因为最终需要输出答案所以需要搞一个二维列表来存储答案
            g = [["."] * n for i in range(n)]
            for i in range(n):
                for j in range(n):
                    if s[k] != ".":
                        # 将对应的行, 列, 九宫格以及g初始化对应的状态
                        self.draw(i, j, int(s[k]) - 1, row, col, cell, g, 1)
                    else:
                        count += 1
                    k += 1
            self.dfs(count, n, g, ones, map, row, col, cell)
            print(self.res)
            # 重置答案
            self.res = ""


if __name__ == '__main__':
    Solution().process()

c++代码:c++代码确实很快就计算出结果:

#include 
#include 
#include 

using namespace std;

const int N = 9, M = 1 << N;

int ones[M], map[M];
int row[N], col[N], cell[3][3];
char str[100];

void init()
{
    for (int i = 0; i < N; i ++ )
        row[i] = col[i] = (1 << N) - 1;

    for (int i = 0; i < 3; i ++ )
        for (int j = 0; j < 3; j ++ )
            cell[i][j] = (1 << N) - 1;
}

void draw(int x, int y, int t, bool is_set)
{
    if (is_set) str[x * N + y] = '1' + t;
    else str[x * N + y] = '.';

    int v = 1 << t;
    if (!is_set) v = -v;

    row[x] -= v;
    col[y] -= v;
    cell[x / 3][y / 3] -= v;
}

int lowbit(int x)
{
    return x & -x;
}

int get(int x, int y)
{
    return row[x] & col[y] & cell[x / 3][y / 3];
}

bool dfs(int cnt)
{
    if (!cnt) return true;

    int minv = 10;
    int x, y;
    for (int i = 0; i < N; i ++ )
        for (int j = 0; j < N; j ++ )
            if (str[i * N + j] == '.')
            {
                int state = get(i, j);
                if (ones[state] < minv)
                {
                    minv = ones[state];
                    x = i, y = j;
                }
            }

    int state = get(x, y);
    for (int i = state; i; i -= lowbit(i))
    {
        int t = map[lowbit(i)];
        draw(x, y, t, true);
        if (dfs(cnt - 1)) return true;
        draw(x, y, t, false);
    }

    return false;
}

int main()
{
    for (int i = 0; i < N; i ++ ) map[1 << i] = i;
    for (int i = 0; i < 1 << N; i ++ )
        for (int j = 0; j < N; j ++ )
            ones[i] += i >> j & 1;

    while (cin >> str, str[0] != 'e')
    {
        init();

        int cnt = 0;
        for (int i = 0, k = 0; i < N; i ++ )
            for (int j = 0; j < N; j ++, k ++ )
                if (str[k] != '.')
                {
                    int t = str[k] - '1';
                    draw(i, j, t, true);
                }
                else cnt ++ ;

        dfs(cnt);

        puts(str);
    }

    return 0;
}

你可能感兴趣的:(acwing-提高,算法)