递归进化,请叫【回溯】算法!
递归:关注代码实现
回溯:关注问题解决
「回溯是递归的副产品,只要有递归就会有回溯」
回溯法就是暴力搜索,并不是什么高效的算法,最多再剪枝一下。
回溯法------》深度优先搜索
回溯算法能解决如下问题:
- 组合问题:N个数里面按一定规则找出k个数的集合
- 排列问题:N个数按一定规则全排列,有几种排列方式
- 切割问题:一个字符串按一定规则有几种切割方式
- 子集问题:一个N个数的集合里有多少符合条件的子集
- 棋盘问题:N皇后,解数独等等
回溯算法设计心法:
1. 脑中浮现:问题状态搜索树
2. 切勿妄想一步到位:先实现,再优化
3. 搜索剪枝记心尖:无招胜有招
回溯法模板:
1. 递归函数(参数和返回值) 2. 确定终止条件 3. 单层递归逻辑 //一定要分成横纵两个方面思考回溯 void backtracking(参数) { if (终止条件) { 存放结果; return; } for (选择:本层集合中元素(树中节点孩子的数量就是集合的大小)) { 处理节点; backtracking(路径,选择列表); //递归 回溯,撤销处理结果 } }
/*********************************************************
从1——n这n个整数中随机选取m个,每种方案里的数从小到大排列,
按字典序输出所有可能的选择方案。
输入:
输入两个整数n,m(1 <= m <= n <= 10)
输出:
每行一组方案,每组方案中两个数之间用空格分隔。
注意每行最后一个数后没有空格。
样例输入:
3 2
样例输出:
1 2
1 3
2 3
*********************************************************/
// 1)第一种写法
#include
int a[10];
int ind = 0;
void print_one_result(int n) {
for (int i = 0; i < n; i++) {
if (i) {
putchar(' ');
}
printf("%d", a[i]);
}
printf("\n");
}
//j:当前这个位置可以选取的最小数字
void func(int j, int n, int m) {
if (ind == m) { //边界
print_one_result(ind);
return ;
}
//剪枝优化:m-ind表示枚举到m还需要多少个,n+1-k表示后面还剩下的数字数量
for (int k = j; m - i <= n + 1 - k; k++) {
a[ind++] = k;
func(k + 1, n, m);
ind--;
}
return ;
}
int main(int argc, char *argv[]) {
int n, m; //n:最大可以选取的数字 m:随机选取多少个
scanf("%d%d", &n, &m);
func(1, n, m);
return 0;
}
//2) 第二种写法
#include
int a[10];
int ind = 0;
void print_one_result(int n) {
for (int i = 0; i < n; i++) {
if (i) {
putchar(' ');
}
printf("%d", a[i]);
}
printf("\n");
}
void func(int k, int n, int m)
{
//失败出口
if (m - ind > n + 1 - k) {
return ;
}
//成功出口
if (ind == m) {
print_one_result(ind);
return ;
}
//把起始位置的数据k放到数组里面, ind:表示已经配好的数据个数
a[ind++] = k;
//继续向后分析
func(k + 1, n, m);
//把起始位置的数据k从数组里面拿出来
ind--;
//继续向后分析
func(k + 1, n, m);
return ;
}
int main(int argc, char *argv[]) {
int n, m; //n:最大可以选取的数字 m:随机选取多少个
scanf("%d%d", &n, &m);
func(1, n, m);
return 0;
}
/*********************************************************
组合总和:
[1-9] 选取k个数,和为n
从数字1-9中选取k个数,输出所有和为n的组合
例如:
k=3,n=9 [1、2、6] [1、3、5] [2、3、4]
k=2,n=4; [1、3]
*********************************************************/
// 1)第一种写法
#include
int a[9];
int ind = 0;
void print_one_result(int n) {
for (int i = 0; i < n; i++) {
if (i) {
putchar(' ');
}
printf("%d", a[i]);
}
printf("\n");
}
void backtrack(int startindex, int sum, int k, int n) {
if (ind == k) { //边界
if (sum == n)
print_one_result(ind);
return ;
}
//剪枝优化:n-ind表示枚举到n还需要多少个,10-i表示后面还剩下的数字数量
for (int i = startindex; k - ind <= 10 - i; i++) {
sum += i;
a[ind++] = i;
backtrack(i + 1, sum, k, n);
ind--;
sum -= i;
}
return ;
}
int main(int argc, char *argv[]) {
int k, n;
scanf("%d%d", &k, &n); //选取k个数,和为n
backtrack(1, 0, k, n);
return 0;
}
// 2)第二种写法
#include
int a[9];
int ind = 0;
void print_one_result(int n) {
for (int i = 0; i < n; i++) {
if (i) {
putchar(' ');
}
printf("%d", a[i]);
}
printf("\n");
}
void backtrack(int startindex, int sum, int k, int n) {
//失败出口
if (k - ind > 10 - startindex) //剪枝优化:n-ind表示枚举到n还需要多少个,10-i表示后面还剩下的数字数量
return ;
if (ind == k) { //边界
if (sum == n)
print_one_result(ind);
return ;
}
sum += startindex;
a[ind++] = startindex;
backtrack(startindex + 1, sum, k, n);
//回溯
ind--;
sum -= startindex;
backtrack(startindex + 1, sum, k, n);
return ;
}
int main(int argc, char *argv[]) {
int k, n;
scanf("%d%d", &k, &n); //选取k个数,和为n
backtrack(1, 0, k, n);
return 0;
}
代码实现:
#define NUM_MAX 256 //4*4*4*4=256 //转化电话按键 char word[10][5] = { "\0", "\0", "abc", "def", "ghi", "jkl", "mno", "pqrs", "tuv", "wxyz" }; void backtrack(int len, int step, char *digits, char **p, char *a, int *returnSize) { //如果满足条件,添加字母组合 if (step == len) { p[*returnSize] = (char *)malloc(len + 1); //注意:len必须+1 strcpy(p[*returnSize], a); //难点:a先出现 字符数组内存从小到大连续存储(节约内存) (*returnSize)++; return; } //可选择的列表 for (int i = 0; i < strlen(word[digits[step] - '0']); i++) { //做选择 a[step] = word[digits[step] - '0'][i]; //回溯 backtrack(len, step + 1, digits, p, a, returnSize); //取消选择 //注意:此时不需要数组下标--;因为回溯之后已经减了 } return; } /** * Note: The returned array must be malloced, assume caller calls free(). */ char **letterCombinations(char *digits, int *returnSize) { //特殊用例1 if (digits == NULL) { *returnSize = 0; return NULL; } //特殊用例2 int str_len = strlen(digits); if (str_len == 0) { *returnSize = 0; return NULL; } //存储结果分配内存 char **p = (char **)malloc(NUM_MAX * sizeof(char *)); memset(p, 0, NUM_MAX * sizeof(char *)); //临时字母组合 char *a = (char *)malloc((str_len + 1) * sizeof(char)); //注意:str_len必须+1 memset(a, 0, (str_len + 1) * sizeof(char)); *returnSize = 0; backtrack(str_len, 0, digits, p, a, returnSize); return p; }
#include
int path[10];
int ind = 0;
void print_one_result(int n) {
for (int i = 0; i < n; i++) {
if (i) {
putchar(' ');
}
printf("%d", a[i]);
}
printf("\n");
}
void backtrack(int *candidate, int target, int n, int sum, int startindex)
{
//失败出口
if (sum > target)
return ;
if (sum == target)
{
print_one_result(ind);
return ;
}
for (int i = startindex; i < n; i++)
{
path[ind++] = candidate[i];
sum += candidate[i];
backtrack(candidate, target, n, sum, i);
// 回溯
ind--;
sum -= candidate[i];
}
return ;
}
int main(int argc, char *argv[])
{
int n, target;
scanf("%d%d", &n, &target);
int candidate[n];
for (int i = 0; i < n; i++)
scanf("%d", &candidate[i]);
backtrack(candidate, target, n, 0, 0);
return 0;
}
代码实现:
/* arr[10, 1, 2, 7, 6, 1, 5] targetSum = 8 [1、1、6] [1、2、5] [1、7] [2、6] 1、将数组排序 2、树层去重 和 树枝去重 */ #include
#include int path[10]; int ind = 0; void print_one_result(int n) { for (int i = 0; i < n; i++) { if (i) { putchar(' '); } printf("%d", a[i]); } printf("\n"); } void backtrack(int *arr, int n, int targetSum, int sum, int startindex, int *used) { if (sum > targetSum) return ; if (targetSum == sum) { print_one_result(ind); return ; } for (int i = startindex; i < n; i++) { if (i > 0 && arr[i] == arr[i - 1] && used[i - 1] == 0) //去重 continue; path[ind++] = arr[i]; sum += arr[i]; used[i] = 1; //用过 backtrack(arr, n, targetSum, sum, i + 1, used); ind--; sum -= arr[i]; used[i] = 0; } return ; } //冒泡排序 void bubble_sort(int *arr, int l, int n) //l:已排好序的元素个数 n:元素总个数 { int flag = 1; //标记位优化 当序列在找到所有的最大值之前就已经将序列排好序了,直接结束循环 for (int i = n - 1; i >= l + 1 && flag; i--) //i决定了哪一轮冒泡 { flag = 0; for (int j = l; j < i; j++) //j决定哪两个元素进行比较 { if (arr[j] > arr[j + 1]) { flag = 1; int temp = arr[j]; arr[j] = arr[j + 1]; arr[j + 1] = temp; } } } } int main(int argc, char *argv[]) { int n, targetSum; scanf("%d%d", &n, &targetSum); int arr[n]; for (int i = 0; i < n; i++) scanf("%d", &arr[i]); // 排序 bubble_sort(arr, 0, n); int used[n]; memset(used, 0, sizeof(used)); backtrack(arr, n, targetSum, 0, 0, used); return 0; }
- 切割问题其实类似组合问题
- 如何模拟那些切割线
- 切割问题中递归如何终止
- 在递归循环中如何截取子串
- 如何判断回文
重点:每个子串范围 [startindex, i]
代码实现:
#include
#include #include char path[10]; int ind = 0; //判断是否为回文串 bool judge(char *str, int start, int end) { for (int i = start, j = end; i < j; i++, j--) { if (str[i] != str[j]) return false; } return true; } void print_one_result(int n) { for (int i = 0; i < n; i++) { if (i) putchar(' '); printf("%c", path[i]); } putchar('\n'); return ; } void func(char *path, char *str, int *ind, int startindex, int n, int str_len) { for (int i = startindex; i <= n; i++) path[(*ind)++] = str[i]; //*ind++ :错误(++优先级高于*) if (str_len != n + 1) //n代表数组下标,须+1 path[(*ind)++] = '|'; //分隔符 return ; } void backtracking(char *str, int startindex, int str_len) { if (startindex == str_len) { print_one_result(ind); return ; } for (int i = startindex; i < str_len; i++) { // 重点:每个子串范围 [startindex, i] if (judge(str, startindex, i)) func(path, str, &ind, startindex, i, str_len); else continue; backtracking(str, i + 1, str_len); ind -= i - startindex + 1; } return ; } int main(int argc, char *argv[]) { char str[10]; gets(str); int str_len = strlen(str); backtracking(str, 0, str_len); return 0; }
代码实现:
(边递归,边存值)
「在树形结构中子集问题是要收集所有节点的结果,而组合问题是收集叶子节点的结果」
代码实现:
#include
int path[10]; int ind = 0; void print_one_result(int n) { printf("["); for (int i = 0; i < n; i++) { if (i) putchar(' '); printf("%d", path[i]); } printf("]\n"); return ; } void backtracking(int *arr, int n, int startindex) { // if (startindex == n) //可以删除 // return ; for (int i = startindex; i < n; i++) { path[ind++] = arr[i]; print_one_result(ind); backtracking(arr, n, i + 1); ind--; } return ; } int main(int argc, char *argv[]) { int n; scanf("%d", &n); int arr[n]; for (int i = 0; i < n; i++) scanf("%d", arr + i); backtracking(arr, n, 0); // 输出最后一个空集 printf("[]\n"); return 0; }
代码实现:
/* arr[1、2、2] [1] [1、2] [1、2、2] [2] [2、2] [] 1、将数组排序 2、树层去重 和 树枝去重 */ #include
#include int path[10]; int ind = 0; void print_one_result(int n) { printf("["); for (int i = 0; i < n; i++) { if (i) putchar(' '); printf("%d", path[i]); } printf("]\n"); return ; } void backtracking(int *arr, int *used, int n, int startindex) { // if (startindex == n) //可以删除 // return ; for (int i = startindex; i < n; i++) { if (i > 0 && arr[i] == arr[i - 1] && used[i - 1] == 0) //去重 continue; path[ind++] = arr[i]; used[i] = 1; //用过 print_one_result(ind); backtracking(arr, used, n, i + 1); // 回溯 ind--; used[i] = 0; } return ; } // 冒泡排序 void bubble_sort(int *arr, int l, int n) //l:已排好序的元素个数 n:元素总个数 { int flag = 1; //标记位优化 当序列在找到所有的最大值之前就已经将序列排好序了,直接结束循环 for (int i = n - 1; i >= l + 1 && flag; i--) //i决定了哪一轮冒泡 { flag = 0; for (int j = l; j < i; j++) //j决定哪两个元素进行比较 { if (arr[j] > arr[j + 1]) { flag = 1; int temp = arr[j]; arr[j] = arr[j + 1]; arr[j + 1] = temp; } } } } int main(int argc, char *argv[]) { int n; scanf("%d", &n); int arr[n]; for (int i = 0; i < n; i++) scanf("%d", arr + i); // 排序 bubble_sort(arr, 0, n); int used[n]; //标记:用于去重 memset(used, 0, sizeof(used)); backtracking(arr, used, n, 0); // 输出最后一个空集 printf("[]\n"); return 0; }
代码实现:
- 每层都是从0开始搜索而不是startIndex
- 需要used数组记录path里都放了哪些元素了
代码实现:
#include
int path[10]; int ind = 0; int used[10] = {0}; //未用:0 用过:1 void print_one_result(int n) { for (int i = 0; i < n; i++) { if (i) { putchar(' '); } printf("%d", path[i]); } printf("\n"); } void backtracking(int *arr, int n) { if (ind == n) { print_one_result(ind); return ; } for (int i = 0; i < n; i++) { if (used[i]) continue; used[i] = 1; path[ind++] = arr[i]; //回溯 backtracking(arr, n); used[i] = 0; ind--; } return ; } int main(int argc, char *argv[]) { int n; scanf("%d", &n); int arr[n]; //注意:41行、42行位置不能调换 for (int i = 0; i < n; i++) scanf("%d", arr + i); backtracking(arr, n); return 0; }
/************************************************************************
请设计一个函数,用来判断在一个矩阵中是否存在一条包含某字符串所有字符的路径。
路径可以从矩阵中任意一格开始,每一步可以在矩阵中向左、右、上、下移到一格。如果
一条路径经过了矩阵的某一格,那么该路径不能再次进入该格子。例如在下面的3×4的矩
阵中包含一条字符串“bfce”的路径(路径中的字母用下划线标出)。但矩阵中不包含字符
串“abfb”的路径,因为字符串的第一个字符b占据了矩阵中的第一行第二格之后,路径不能
再次进入这个格子。
A B T G
C F C S
J D E H
************************************************************************/
#include
#include
#include
#define ROWS 3
#define COLS 4
char map[ROWS][COLS] = {
'A', 'B', 'T', 'G',
'C', 'F', 'C', 'S',
'J', 'D', 'E', 'H'
};
/*
0 1 2 3
4 5 6 7
8 9 10 11
*/
const char *str = "SCFD"; //目标路径
int path[ROWS * COLS]; // 7 6 5 9 存放路径
int pathLength = 0;
bool hasPath(int row, int col, bool visited[ROWS][COLS])
{
//合法边界
if (pathLength == strlen(str))
return true;
bool ret = false;
if (row >= 0 && row < ROWS && col >= 0 && col < COLS &&
map[row][col] == str[pathLength] && visited[row][col] == false)
{
path[pathLength++] = row * COLS + col;
visited[row][col] = true;
ret = hasPath(row - 1, col, visited) || //上
hasPath(row + 1, col, visited) || //下
hasPath(row, col - 1, visited) || //左
hasPath(row, col + 1, visited); //右
if (ret == false)
{
pathLength--;
path[pathLength] = -1; //删除刚才的结点
visited[row][col] = false;
}
}
return ret;
}
bool findPath()
{
// 0: 没有走过 1: 走过
bool visited[ROWS][COLS]; //表示各个位置,是否已经走过
memset(path, -1, sizeof(path));
memset(visited, 0, sizeof(visited));
for (int row = 0; row < ROWS; row++)
{
for (int col = 0; col < COLS; col++)
{
if (hasPath(row, col, visited))
return true;
}
}
return false;
}
int main(int argc, char *argv[])
{
if (findPath())
{
printf("查找成功\n");
for (int i = 0; path[i] >= 0; i++)
printf("%d ", path[i]);
putchar('\n');
}
else
printf("查找失败\n");
return 0;
}
- 不能同行
- 不能同列
- 不能同斜线 (45度和135度角)
代码实现:
bool isValid(char **res, int row, int col, int n) { //Valid:有效的,合理的 //检查列 for (int i = 0; i < row; i++) { // 这是一个剪枝 if (res[i][col] == 'Q') { return false; } } //检查 45度角是否有皇后 for (int i = row - 1, j = col - 1; i >= 0 && j >= 0; i--, j--) { if (res[i][j] == 'Q') { return false; } } //检查 135度角是否有皇后 for(int i = row - 1, j = col + 1; i >= 0 && j < n; i--, j++) { if (res[i][j] == 'Q') { return false; } } return true; } //index:行 void dfs(char ***ans, int *returnSize, int **returnColumnSizes, char **res, int index, int n) { //枚举到最后一行,保存当前棋盘,因为如果不能成为皇后就不能进入下一行 //所以能到最后一行肯定是有效的 if (index == n) { ans[(*returnSize)] = (char **)malloc(sizeof(char *) * n); for(int i = 0; i < n; i++) { ans[(*returnSize)][i] = (char *)malloc(sizeof(char) * (n + 1)); memcpy(ans[(*returnSize)][i], res[i], sizeof(char) * (n + 1)); } (*returnColumnSizes)[(*returnSize)++] = n; return; } //枚举当前行中每一个元素 for (int i = 0; i < n; i++) { //检查当前位置能不能放皇后 if (isValid(res, index, i, n) == true) { res[index][i] = 'Q'; //可以直接放皇后 dfs(ans, returnSize, returnColumnSizes, res, index + 1, n); //进入下一行,重复当前步骤 res[index][i] = '.'; //回溯,进行判断下一个元素 } } return ; } //returnSize:这是一个指向整数的指针,用于输出 N 皇后问题的解的数目。 //returnColumnSizes:这是一个指向整数数组的指针,指向一个大小为 returnSize 的数组,用于存储每个解的列数。 char ***solveNQueens(int n, int *returnSize, int **returnColumnSizes) { char ***ans = (char ***)malloc(sizeof(char **) * 1000); *returnColumnSizes = (int *)malloc(sizeof(int) * 1000); //初始化变量 *returnSize = 0; char **res = (char **)malloc(sizeof(char *) * n); //定义棋盘并初始化 for (int i = 0; i < n; i++) { res[i] = (char *)malloc(sizeof(char) * (n + 1)); memset(res[i], '.', sizeof(char) * n); res[i][n] = '\0'; } //递归枚举每一个位置 dfs(ans, returnSize, returnColumnSizes, res, 0, n); return ans; }
代码实现:
#include
#define MAX_N 200 #define MAX 100000 int to[MAX_N + 1]; int dis[MAX_N + 1]; void dfs(int k, int a, int n) { if (dis[a] <= k) return ; dis[a] = k; if (a + to[a] <= n) { dfs(k + 1, a + to[a], n); } if (a - to[a] >= 1) { dfs(k + 1, a - to[a], n); } return ; } int main(int argc, char *argv[]) { int n, a, b; scanf("%d%d%d", &n, &a, &b); for (int i = 1; i <= n; i++) { scanf("%d", to + i); } for (int i = 1; i <= n; i++) { dis[i] = MAX; } dfs(0, a, n); printf("%d\n", dis[b] == MAX ? -1 : dis[b]); return 0; }
代码实现:
#include
#include #define MAX_N 20 int val[MAX_N + 1]; int ans = 0; bool is_prime(int x) { for (int i = 2; i <= x / 2; i++) { if (x % i == 0) { return false; } } return true; } void dfs(int ind, int startindex, int n, int k, int sum) { if (ind == k) { if (is_prime(sum)) { ans += 1; } return ; } for (int i = startindex; i <= n; i++) { dfs(ind + 1, i + 1, n, k, sum + val[i]); } return ; } int main(int argc, char *argv[]) { int n, k; scanf("%d%d", &n, &k); for (int i = 1; i <= n; i++) { scanf("%d", val + i); } dfs(0, 1, n, k, 0); printf("%d\n", ans); return 0; }
代码实现:
#include
int path[10]; int ind = 0; void print_one_result(int n) { if (n <= 1) { return ; } for (int i = 0; i < n; i++) { if (i) { putchar('+'); } printf("%d", path[i]); } printf("\n"); } void dfs(int startindex, int n) { if (n == 0) { print_one_result(ind); return ; } for (int i = startindex; i <= n; i++) { path[ind++] = i; n -= i; dfs(i, n); ind--; n += i; } } int main(int argc, char *argv[]) { int n; scanf("%d", &n); dfs(1, n); return 0; }
代码实现: