回溯算法分析

1.回溯算法介绍

解决一个回溯问题,实际上就是一个决策树的遍历过程。你只需要思考 3 个问题:
1、路径:也就是已经做出的选择。
2、选择列表:也就是你当前可以做的选择。
3、结束条件:也就是到达决策树底层,无法再做选择的条件。
如果你不理解这三个词语的解释,没关系,我们后面会用「全排列」和「N 皇后问题」这两个经典的回溯算法问题来帮你理解这些词语是什么意思,现在你先留着印象。
代码方面,回溯算法的框架:

result = []
def backtrack(路径, 选择列表):
    if 满足结束条件:
        result.add(路径)
        return
    
    for 选择 in 选择列表:
        做选择
        backtrack(路径, 选择列表)
        撤销选择

其核心就是 for 循环里面的递归,在递归调用之前「做选择」,在递归调用之后「撤销选择」,特别简单。
什么叫做选择和撤销选择呢,这个框架的底层原理是什么呢?下面我们就通过「全排列」这个问题来解开之前的疑惑,详细探究一下其中的奥妙!

2.相关题目

(1)全排列问题

我们在高中的时候就做过排列组合的数学题,我们也知道 n 个不重复的数,全排列共有 n! 个。
PS:为了简单清晰起见,我们这次讨论的全排列问题不包含重复的数字。
那么我们当时是怎么穷举全排列的呢?比方说给三个数 [1,2,3],你肯定不会无规律地乱穷举,一般是这样:
先固定第一位为 1,然后第二位可以是 2,那么第三位只能是 3;然后可以把第二位变成 3,第三位就只能是 2 了;然后就只能变化第一位,变成 2,然后再穷举后两位……
其实这就是回溯算法,我们高中无师自通就会用,或者有的同学直接画出如下这棵回溯树:
回溯算法分析_第1张图片
只要从根遍历这棵树,记录路径上的数字,其实就是所有的全排列。我们不妨把这棵树称为回溯算法的「决策树」。

为啥说这是决策树呢,因为你在每个节点上其实都在做决策。比如说你站在下图的红色节点上:
回溯算法分析_第2张图片
你现在就在做决策,可以选择 1 那条树枝,也可以选择 3 那条树枝。为啥只能在 1 和 3 之中选择呢?因为 2 这个树枝在你身后,这个选择你之前做过了,而全排列是不允许重复使用数字的。

现在可以解答开头的几个名词:[2] 就是「路径」,记录你已经做过的选择;[1,3] 就是「选择列表」,表示你当前可以做出的选择;「结束条件」就是遍历到树的底层,在这里就是选择列表为空的时候。

如果明白了这几个名词,可以把「路径」和「选择」列表作为决策树上每个节点的属性,比如下图列出了几个节点的属性:

回溯算法分析_第3张图片
代码:

package backtrace;

import java.util.LinkedList;
import java.util.List;

/**
 * @author chengzhengda
 * @version 1.0
 * @date 2020-04-14 12:13
 * @desc 回溯算法解决全排列问题
 */
public class premute {
    List<List<Integer>> res = new LinkedList<>();

    List<List<Integer>> permute(int[] nums) {
        LinkedList<Integer> track = new LinkedList<>();
        backtrace(nums, track);
        return res;
    }

    void backtrace(int[] nums, LinkedList<Integer> track) {
        if (track.size() == nums.length) {
            res.add(new LinkedList<>(track));
            return;
        }
        for (int i = 0; i < nums.length; i++) {
            // 剪枝
            if (track.contains(nums[i])) {
                continue;
            }
            // 添加节点
            track.add(nums[i]);
            // 回溯
            backtrace(nums, track);
            // 撤销节点
            track.removeLast();
        }

    }

    public static void main(String[] args) {
        premute pp = new premute();
        int[] nums = {1, 2, 3, 4};
        pp.permute(nums);
        for (List<Integer> list : pp.res) {
            for (Integer i : list) {
                System.out.print(i + " ");
            }
            System.out.println();
        }
    }
}

(2)N皇后问题

这个问题很经典了,简单解释一下:给你一个 N×N 的棋盘,让你放置 N 个皇后,使得它们不能互相攻击。
PS:皇后可以攻击同一行、同一列、左上左下右上右下四个方向的任意单位。
这个问题本质上跟全排列问题差不多,决策树的每一层表示棋盘上的每一行;每个节点可以做出的选择是,在该行的任意一列放置一个皇后。

直接套用框架:

package backtrace;

import java.util.ArrayList;
import java.util.List;

/**
 * @author chengzhengda
 * @version 1.0
 * @date 2020-04-14 12:38
 * @desc
 */
public class nQueen {
    List<List<String>> res = new ArrayList<>();

    public List<List<String>> solveNQueens(int n) {
        //棋盘,默认为0表示空,1表示皇后
        int[][] board = new int[n][n];
        //row当前填写得的行号
        backtrack(board, 0);
        return res;
    }

    /**
     * 回溯
     *
     * @param borad
     * @param row
     */
    void backtrack(int[][] borad, int row) {
        int n = borad.length;
        if (row == n) {
            res.add(tranfer(borad));
        }
        for (int col = 0; col < n; col++) {
            if (!isValid(borad, row, col)) {
                continue;
            }
            borad[row][col] = 1;
            backtrack(borad, row + 1);
            borad[row][col] = 0;
        }
    }

    /**
     * 判断(row,col)位置上是否可以放置皇后
     *
     * @param board
     * @param row
     * @param col
     * @return
     */
    private boolean isValid(int[][] board, int row, int col) {
        //检查列上有无皇后
        for (int i = 0; i < row; i++) {
            if (board[i][col] == 1) return false;
        }
        //检查左上至右下对角线有无皇后
        for (int i = col - 1; i >= 0; i--) {
            if (i + row - col < 0) break;
            if (board[i + row - col][i] == 1) return false;
        }
        //检查右上至左下对角线有无皇后
        for (int i = col + 1; i < board.length; i++) {
            if (row + col - i < 0) break;
            if (board[row + col - i][i] == 1) return false;
        }
        return true;
    }

    /**
     * 结果转换
     *
     * @param board
     * @return
     */
    private List<String> tranfer(int[][] board) {
        List<String> list = new ArrayList<>();
        for (int i = 0; i < board.length; i++) {
            StringBuilder str = new StringBuilder();
            for (int j = 0; j < board.length; j++) {
                str.append(board[i][j] == 1 ? 'Q' : ".");
            }
            list.add(str.toString());
        }
        return list;
    }


    public static void main(String[] args) {
        nQueen t = new nQueen();
        List<List<String>> res = t.solveNQueens(4);
        for (List<String> list : res) {
            for (String str : list) {
                System.out.print(str + " ");
            }
            System.out.println();
        }
    }
}

(3)三数之和

给你一个包含 n 个整数的数组 nums,判断 nums 中是否存在三个元素 a,b,c ,使得 a + b + c = 0 ?请你找出所有满足条件且不重复的三元组。
注意:答案中不可以包含重复的三元组。
示例:
给定数组 nums = [-1, 0, 1, 2, -1, -4],
满足要求的三元组集合为:
[
[-1, 0, 1],
[-1, -1, 2]
][]
本题的最优解不是回溯算法,但是所有的树形结构的查找都可以用回溯算法来解决,因此本题套用回溯算法的公式依然可以解决。

package summary;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * @author chengzhengda
 * @version 1.0
 * @date 2020-04-19 14:30
 * @desc
 */
public class t9 {
    List<List<Integer>> res = new LinkedList<>();

    public static void main(String[] args) {
        int[] nums = {-1, 0, 1, 2, -1, -4, -2};
        t9 tt = new t9();
        List<List<Integer>> lists = tt.threeSum(nums);
        for (List<Integer> list : lists) {
            for (Integer i : list) {
                System.out.print(i + " ");
            }
            System.out.println();
        }
    }


    /**
     * 回溯法
     *
     * @param nums
     * @return
     */
    public List<List<Integer>> threeSum(int[] nums) {
        LinkedList<Integer> trace = new LinkedList<>();
        backtrack(nums, trace, 0);
        return res;

    }

    public void backtrack(int[] nums, LinkedList<Integer> track, int n) {
        if (track.size() == 3) {
            int sum = 0;
            for (int i : track) {
                sum += i;
            }

            if (sum == 0) {
                LinkedList<Integer> temp = new LinkedList<>(track);
                Collections.sort(temp);
                if (!res.contains(temp)) {
                    res.add(temp);
                }
            }
            return;
        }

        for (int i = n; i < nums.length; i++) {

            track.add(nums[i]);
            backtrack(nums, track, i + 1);
            track.removeLast();
        }
    }
}

你可能感兴趣的:(算法)