搜索问题之状态空间搜索(状态压缩+记忆化搜索+ BFS)

文章目录

        • 1. 前言
        • 2. 问题举例(九宫格问题)
        • 3. 问题分析
          • 3.1 状态编码与解码
          • 3.2 哈希映射
          • 3.3 集合判重
        • 4. 问题实现
        • 推荐阅读

1. 前言

之前介绍的回溯法常用于 解空间的搜索 问题,即 找到一个或者所有满足约束条件的解,它通常是将解空间组织成树或者图,然后进行DFS(深度优先遍历)并注意在搜索的时候进行 剪枝操作。

但是状态空间搜索则是需要 找到一条从起始状态到终止状态的路径,其一般需要考虑一下问题:

  • 状态的表示,即我们怎样表示一个状态。
  • 状态的转移,即通过研究初始状态和目标状态的 差别 ,我们定义怎样的操作来进行状态的转移。
  • 状态的压缩和记忆化搜索,即我们如何压缩一个状态的表示,使得 我们能够存储已经搜索过的状态的结果。这样能够避免大量重复状态的搜索。

2. 问题举例(九宫格问题)

下面用经典的搜索问题 八数码(九宫格问题) 举例,参考算法竞赛经典入门 第二版。
搜索问题之状态空间搜索(状态压缩+记忆化搜索+ BFS)_第1张图片
OJ例题:

  • POJ 1077 Eight
  • HDUOJ 1043 需要离线打表

3. 问题分析

题中所说要找到移动步数最少的路径,即我们需要从起始状态到目标状态进行BDS(广度优先搜索),而我们需要考虑以下问题:

如何表示状态
显然,最直接的方式是直接用一个 3X3 的二维矩阵,简化为一个 1X9 的一维数组。

怎么进行状态压缩和记忆化搜索?
我们可以直接申请一个9维数组vis,然后根据vis[s0][s1][s2][s3][s4][s5][s6][s7][s8] 是否等于1来判重,需要的数组大小为 9 9 = 387 , 420 , 489 9 ^9 = 387,420,489 99=387,420,489 项,太多了,而且实际上最多的结点数也只有 0~8 的全排列 9 ! = 362 , 880 9! = 362,880 9!=362,880 项而已。

所以,如何进行状态压缩,常见右3种思路:

3.1 状态编码与解码

将每一种状态与一个整数编码一一对应起来,然后只开一个一维数组来判重。
而本题就是将 0 ~ 8 的排列数与 0 ~ 362879 对应起来,常见的方式是康拓展开,即 我们将一个排列数与其在所有排列中的字典序一一对应起来,例如 012345678 <–> 0 , 876543210 <–> 362879。

这种方法时间效率高,但是当状态空间的结点总数非常大时,编码也会很大,因为是一一对应的。

3.2 哈希映射

这种方法也是将状态映射成整数,但是不必是一一对应的。他可以映射到一个 [ 0 , M − 1 ] [0,M-1] [0,M1] 范围内的整数,然后开一个 M 大小的数组来存储,相同哈希值的存放在一起,例如使用 链表 连在一起,称为一个 桶(bucket)。这种方法注意三点:

  • 哈希表的大小M设置为多少,一般M越大,冲突的概率会比较小。
  • 哈希函数怎么设置,即如何将状态映射成整数。在 哈希表中,哈希函数的作用很关键,一个设置良好的哈希函数应该保证哈希值的冲突尽可能少。这里,我们的哈希函数可以直接将状态映射成一个9位的排列数,然后对M取余。
  • 冲突怎么解决,我们可以将相同哈希值的元素用链表连起来,也可以设置一个规则,如果冲突了,则向后或者向前移动几位等等。
3.3 集合判重

我们可以用一个STL中的集合来存储访问过的排列数来进行判重,但是,STL底层是基于红黑树的,其插入和查找的复杂度都在 O ( l o g n ) O(logn) O(logn) 而编码和哈希在最好情况下(哈希的冲突为0)是数组的直接索引,复杂度在 O ( 1 ) O(1) O(1)。当然,使用STL的代码比较简洁,我们可以先用它来实现判重,然后再验证程序其他部分的正确性,然后转化为编码或者哈希表。

4. 问题实现

先给出这个问题的BFS的大致框架:

typedef int State[9]; // 定义状态,九宫格
const int maxn = 0x7fffff;  // 最多的可能状态

State st[maxn]; // 存储状态
State goal; // 目标状态

int fa[maxn]; // 存储状态的前一状态
char pre[maxn]; // 存储前一状态变化到当前状态所用的操作
const char op[5] = "udlr";

const int dx[] = { -1,1,0,0 };
const int dy[] = { 0,0,-1,1 }; // 四个方向,上,下,左,右

void init_lookup_table();
int try_to_insert(int s);

void printState(State& s);

int bfs() {
	// 若成功,则返回目标状态在状态数组中的位置
	init_lookup_table(); // 初始化查找表
	int front = 1, rear = 2;
	while (front < rear) {
		State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作
		//printState(s);
		if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功
		int z; // 0 的位置,即空格
		for (z = 0; s[z] != 0; z++);
		int x = z / 3, y = z % 3;
		for (int i = 0; i < 4; i++) {
			int newx = x + dx[i];
			int newy = y + dy[i];
			int newz = newx * 3 + newy;
			if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) {
				State& t = st[rear]; // 新状态
				memcpy(t, s, sizeof(s));
				t[newz] = s[z];
				t[z] = s[newz];
				fa[rear] = front;
				pre[rear] = op[i];
				if (try_to_insert(rear))
					rear++; // 此状态没有出现过
			}// if
		}// for
		front++;
	}
	return 0;
}

其中,init_look_table() 和 try_insert() 就是我们的判重操作,即初始化查找表和判断该状态是否已经搜索过。也就是我们上面所说的3种判重方式:

集合判重

set<int> vis;
void init_lookup_table() {
	vis.clear();
}
int try_to_insert(int s) {
	// 试图插入一个状态
	State& ma = st[s];
	int num = 0; // 转换为一个9位数
	for (int i = 0; i < 9; i++) num = num * 10 + ma[i];
	if (vis.count(num)) return 0;
	else {
		vis.insert(num);
		return 1;
	}
}

简单但是效率低。

哈希表

const int hashsize = 1e+6 + 3; // 哈希表的大小
int head[hashsize], Next[maxn]; // 哈希链表

void init_lookup_table() {
	memset(head, 0, sizeof(head));
}
int hashfunc(State& s) {
	// 一个状态的哈希函数
	int num = 0;
	for (int i = 0; i < 9; i++) num = num * 10 + s[i];
	return num % hashsize;
}
int try_to_insert(int s) {
	// 试图插入一个状态
	State& ma = st[s];
	int h = hashfunc(ma);
	int u = head[h];
	// 查找状态
	while (u) {
		if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了
		u = Next[u];
	}
	// 头插法插入结点
	Next[s] = head[h];
	head[h] = s;
	return 1;
}
void printState(State& s) {
	for (int i = 0; i < 9; i++) {
		if(s[i]) printf("%d", s[i]);
		else printf("X");
		if ((i + 1) % 3 == 0) printf("\n");
		else printf(" ");
	}
}

编码解码

int vis[362880], fact[9]; // 判重数组和阶乘
void init_lookup_table() {
	memset(vis, 0, sizeof(vis));
	fact[0] = 1;
	for (int i = 1; i < 9; i++) fact[i] = fact[i - 1] * i;
}
int canto(State& s) {
	// 将一个状态转成康拓编码
	int code = 0;
	for (int i = 0; i < 9; i++) {
		int cnt = 0; // 计算逆序数
		for (int j = i + 1; j < 9; j++) if (s[j] < s[i]) cnt++;
		code += cnt * fact[8 - i];
	}
	return code;
}
int try_to_insert(int s) {
	// 试图插入一个状态
	int code = canto(st[s]);
	if (vis[code]) return 0;
	else return vis[code] = 1;
}

示例AC代码

/* 八数码问题 BFS中状态空间搜索 */
#include
#include
#include
#include
#include
using namespace std;

typedef int State[9]; // 定义状态,九宫格
const int maxn = 0x7fffff;  // 最多的可能状态
const int hashsize = 1e+6 + 3; // 哈希表的大小

int head[hashsize], Next[maxn]; // 哈希链表
State st[maxn]; // 存储状态
State goal; // 目标状态
int fa[maxn]; // 存储状态的前一状态
char pre[maxn]; // 存储前一状态变化到当前状态所用的操作
const char op[5] = "udlr";

const int dx[] = { -1,1,0,0 };
const int dy[] = { 0,0,-1,1 }; // 四个方向,上,下,左,右

void init_lookup_table() {
	memset(head, 0, sizeof(head));
}
int hashfunc(State& s) {
	// 一个状态的哈希函数
	int num = 0;
	for (int i = 0; i < 9; i++) num = num * 10 + s[i];
	return num % hashsize;
}
int try_to_insert(int s) {
	// 试图插入一个状态
	State& ma = st[s];
	int h = hashfunc(ma);
	int u = head[h];
	// 查找状态
	while (u) {
		if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了
		u = Next[u];
	}
	// 头插法插入结点
	Next[s] = head[h];
	head[h] = s;
	return 1;
}
void printState(State& s) {
	for (int i = 0; i < 9; i++) {
		if(s[i]) printf("%d", s[i]);
		else printf("X");
		if ((i + 1) % 3 == 0) printf("\n");
		else printf(" ");
	}
}
int bfs() {
	// 若成功,则返回目标状态在状态数组中的位置
	init_lookup_table(); // 初始化查找表
	int front = 1, rear = 2;
	while (front < rear) {
		State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作
		//printState(s);
		if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功
		int z; // 0 的位置,即空格
		for (z = 0; s[z] != 0; z++);
		int x = z / 3, y = z % 3;
		for (int i = 0; i < 4; i++) {
			int newx = x + dx[i];
			int newy = y + dy[i];
			int newz = newx * 3 + newy;
			if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) {
				State& t = st[rear]; // 新状态
				memcpy(t, s, sizeof(s));
				t[newz] = s[z];
				t[z] = s[newz];
				fa[rear] = front;
				pre[rear] = op[i];
				if (try_to_insert(rear)) 
					rear++; // 此状态没有出现过
			}// if
		}// for
		front++;
	}
	return 0;
}
void printPath(int s) {
	// 打印路径
	if (s == 1) return;
	printPath(fa[s]);
	printf("%c", pre[s]);
}
int main() {
	char c;
	for (int i = 0; i < 9; i++) {// 初始状态
		cin >> c;
		if (c == 'x') st[1][i] = 0;
		else st[1][i] = c - '0';
	}
	for (int i = 0; i < 9; i++) goal[i] = i + 1;
	goal[8] = 0; // 目标状态
	int ans = bfs();
	if (ans == 0) printf("unsolvable\n");
	else {
		printPath(ans);
		printf("\n");
	}
	return 0;
}
/*2 3 4 1 5 x 7 6 8*/

推荐阅读

  • 《算法竞赛入门经典 》7.5 路径搜索问题
  • 八数码的八种境界

你可能感兴趣的:(#,搜索,状态空间搜索)