A*算法求解N数码问题的设计与实现
Table of Contents
任务要求:
1. 关于A*算法:
2. 算法复杂度:
3. Solution:
4. CODE
A*算法的核心在于代价函数的设计:cost = g + h,其中g为当前状态的深度(已走的步数);h是关键,代表从当前状态到达目标状态的代价估计值。h必须满足某些条件,其中最重要的是可采纳性:h ≤ h*,h*为当前状态到目标状态的真实最小代价(而非估计值)。
以实例理解上述条件。例如在八数码问题中,h可以被设计为“各个数(不含空白)移动到其目标位置所需步数之和”(曼哈顿距离)。由于每一步只能让一个数移动一格,它显然不会超过真实需要的步数。h也可以被设计为“不在目标位置上的数(不含空白)的个数”。每个错位的数至少要移动一次,因此它同样不超过真实步数;并且它不大于第一个估计值,是更宽松(信息量更少)的下界。以上即为本次设计的两个代价函数。
另一个例子是连续地图中的寻路算法,h一般设计为欧式距离,即直线距离,这显然比真实需要走的路要短(可能有障碍,弯路等等)。
通过以上两个例子,我们可以直观地理解,为什么估计值h不能超过真实值,但又要尽可能大,尽量接近真实值。例如在寻路算法中,你估算的距离越接近真实距离,启发式搜索找到可能路径的效率就越高。理论上可以严格证明:只要h满足可采纳性(不超过真实值),A*必然能找到最优解。
从另一个角度来审视A*算法,它可以视为按估计代价排序的最佳优先搜索:当h恒为0且每步代价相同时,它就退化为广度优先搜索,这一点从代码实现上更容易感受到。A*每次都优先处理全局估计代价最小的状态。如果观察搜索树,会看到它在整棵搜索树的节点之间跳跃式地扩展(总是选全局估计代价最小者),而不是像广度优先那样逐层推进。
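再比如在寻路问题中,A*算法的运行方式类似于我们人类寻找两点之间的最快路径:尽管无法事先确定中间要从何处绕开障碍,我们却知道要往目标方向靠。A*在每个状态都会计算一次到目标的欧式距离,以此约束选择,就仿佛被欧式距离牵引着走向目标一样。下面是本程序在八数码问题上的运行实例。注:这些日志中,代价函数二展开了全部181440个可达状态,并且在第四例中得到了27而非最优的25。这是因为生成日志时的h2统计的是“与目标相同的格子数”(越接近目标值越大),而不是错位数。此外,日志中的“步数”实为路径上的节点数,比实际移动步数多1。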
init_state
[[3 1 7]
[6 8 0]
[4 2 5]]
target
[[7 2 1]
[4 8 6]
[3 0 5]]
----------A star with cost_function 1-----------
27 步之后能到达目标
生成节点数 5383
耗时:0.9075729846954346
----------A star with cost_function 2-----------
27 步之后能到达目标
生成节点数 181440
耗时:29.369495630264282
----------breadth search-----------
27 步之后能到达目标
生成节点数 9698
耗时:19.344292402267456
=========================================================
init_state
[[2 6 7]
[5 0 3]
[4 1 8]]
target
[[5 2 3]
[4 1 0]
[8 7 6]]
----------A star with cost_function 1-----------
16 步之后能到达目标
生成节点数 115
耗时:0.016954660415649414
----------A star with cost_function 2-----------
16 步之后能到达目标
生成节点数 132050
耗时:18.046760320663452
----------breadth search-----------
16 步之后能到达目标
生成节点数 3629
耗时:0.7639575004577637
===================================================================
init_state
[[2 8 7]
[4 5 1]
[0 6 3]]
target
[[5 7 4]
[6 1 8]
[3 2 0]]
----------A star with cost_function 1-----------
无解
耗时:0.0
----------A star with cost_function 2-----------
无解
耗时:0.0
----------breadth search-----------
无解
耗时:0.0009975433349609375
==================================================================
init_state
[[3 2 1]
[6 7 5]
[4 8 0]]
target
[[1 8 0]
[5 2 6]
[3 7 4]]
----------A star with cost_function 1-----------
25 步之后能到达目标
生成节点数 1805
耗时:0.26030421257019043
----------A star with cost_function 2-----------
27 步之后能到达目标
生成节点数 181440
耗时:31.541586875915527
----------breadth search-----------
25 步之后能到达目标
生成节点数 15633
耗时:16.068492650985718
import heapq
import itertools
import time
from collections import deque
from queue import PriorityQueue, Queue

import numpy as np
# Fix NumPy's global RNG seed so the random goal and start boards drawn below
# with np.random.choice are identical on every run.
np.random.seed(0)
# g(n): cost of the path travelled from the start state to n.
def g(cur):
    """Return the number of moves already made to reach ``cur`` (its ``depth``)."""
    moves_made = getattr(cur, "depth")
    return moves_made
# h(n): Manhattan-distance heuristic.
def h(cur, positions=None):
    """Return the Manhattan-distance estimate of the moves left from ``cur``.

    For every non-blank tile, add |row - goal_row| + |col - goal_col|. A move
    shifts exactly one tile by one cell, so the sum never exceeds the true
    number of remaining moves (admissible).

    Args:
        cur: object whose ``data`` is a 2-D integer board (0 = blank).
        positions: mapping tile value -> (row, col) of that tile in the goal
            board. Defaults to the module-level ``n2posi``.

    Returns:
        The heuristic value as an ``int``.
    """
    if positions is None:
        positions = n2posi
    total = 0
    # Walk the board's actual shape instead of assuming a 3x3 grid.
    for (row, col), tile in np.ndenumerate(cur.data):
        if tile == 0:
            continue  # the blank is not a tile and must not be counted
        goal_row, goal_col = positions[tile]
        total += abs(row - goal_row) + abs(col - goal_col)
    return int(total)
def h2(cur, goal=None):
    """Return the misplaced-tiles estimate: non-blank tiles not on their goal cell.

    Every misplaced tile needs at least one move, so this never exceeds the
    true number of remaining moves (admissible). It is never larger than the
    Manhattan distance, so it is a weaker but still valid heuristic.

    Args:
        cur: object whose ``data`` is a 2-D integer board (0 = blank).
        goal: goal board as an array; defaults to ``target_state.data``.

    Returns:
        The number of misplaced tiles as an ``int``.
    """
    goal = target_state.data if goal is None else np.asarray(goal)
    board = np.asarray(cur.data)
    # Count *mismatches* (a larger value must mean "farther from the goal"),
    # and ignore the blank so the estimate stays admissible.
    return int(np.count_nonzero((board != goal) & (board != 0)))
# f(n) = g(n) + h(n), using the Manhattan-distance heuristic h.
def cost(cur):
    """Return the A* priority of ``cur``: moves made so far plus the estimate ``h``."""
    so_far, remaining = g(cur), h(cur)
    return so_far + remaining
def cost2(cur):
    """Return the A* priority of ``cur`` using the second heuristic: g(n) + h2(n)."""
    so_far, remaining = g(cur), h2(cur)
    return so_far + remaining
# Check whether the target board is reachable (i.e. the instance is solvable).
def is_solutable(init_, target):
    """Return True if ``target.data`` can be reached from ``init_.data`` by sliding moves.

    Uses the classic parity invariant. Let I be the number of inversions among
    the non-blank tiles read in row-major order. A horizontal blank move never
    changes I. A vertical move carries one tile past ``width - 1`` others, so:

    * odd board width: the parity of I is invariant;
    * even board width: the parity of I + (row index of the blank) is invariant.

    For rectangular boards of at least 2x2, two configurations are mutually
    reachable exactly when these invariants agree.

    Args:
        init_: object whose ``data`` is the start board (0 = blank).
        target: object whose ``data`` is the goal board.

    Returns:
        ``True`` if solvable, ``False`` otherwise.
    """
    def _parity(board):
        # Invariant parity of one board as described above.
        board = np.asarray(board)
        tiles = board.ravel()
        tiles = tiles[tiles != 0]
        inversions = sum(
            int(np.count_nonzero(tiles[k + 1:] < v)) for k, v in enumerate(tiles)
        )
        if board.shape[1] % 2 == 0:
            inversions += int(np.nonzero(board == 0)[0][0])
        return inversions % 2

    return _parity(init_.data) == _parity(target.data)
# A* search
def a_star(init_state, target_state, cost):
    """Find a shortest move sequence from ``init_state`` to ``target_state`` with A*.

    Frontier nodes are expanded in increasing order of ``cost(node)`` (expected
    to be g + h). A binary heap cannot re-prioritize entries in place. So when a
    cheaper path to an already-known state is found, a fresh entry is pushed,
    and superseded entries are skipped when popped (lazy deletion). The goal
    test happens on expansion, so an admissible heuristic yields an optimal path.

    Args:
        init_state: start ``State``.
        target_state: goal ``State``; compared with ``State.__eq__``.
        cost: callable returning the priority of a ``State``.

    Returns:
        The goal ``State`` node (follow ``.root`` back to the start), or
        ``None`` if the goal is unreachable. Prints the number of moves and the
        number of distinct states generated, or "无解" when unsolvable.
    """
    if not is_solutable(init_state, target_state):
        print("无解")
        return None
    directions = [(-1, 0), (0, -1), (1, 0), (0, 1)]
    rows, cols = init_state.data.shape
    tie = itertools.count()  # FIFO tie-breaker: heap entries never compare States
    best = {init_state: init_state}  # cheapest node found so far for every generated state
    frontier = [(init_state.cost_, next(tie), init_state)]
    goal = None
    while frontier:
        _, _, cur = heapq.heappop(frontier)
        if best[cur] is not cur:
            continue  # stale entry: a cheaper path to this state was found later
        if cur == target_state:
            goal = cur
            break
        x, y = cur.position()
        for dx, dy in directions:
            x_n, y_n = x + dx, y + dy
            if not (0 <= x_n < rows and 0 <= y_n < cols):
                continue  # the blank would leave the board
            # Move the blank to create the child state.
            data = cur.data.copy()
            data[x, y], data[x_n, y_n] = data[x_n, y_n], data[x, y]
            new_state = State(data, cur.depth + 1, x_n, y_n)
            known = best.get(new_state)
            if known is not None and known.depth <= new_state.depth:
                continue  # already reached at least as cheaply
            new_state.root = cur
            new_state.cost_ = cost(new_state)
            best[new_state] = new_state
            heapq.heappush(frontier, (new_state.cost_, next(tie), new_state))
    if goal is None:
        print("无解")
        return None
    # Count edges (moves) on the path, not nodes.
    moves = 0
    node = goal
    while node.root is not None:
        moves += 1
        node = node.root
    print('{} 步之后能到达目标'.format(moves))
    print('生成节点数 {}'.format(len(best)))
    return goal
# Breadth-first search (uninformed baseline)
def bf_search(init_state, target_state):
    """Find a shortest move sequence from ``init_state`` to ``target_state`` with BFS.

    States are explored level by level, so the first time a state is generated
    it is already reached by a shortest path; later duplicates are discarded.

    Args:
        init_state: start ``State``.
        target_state: goal ``State``; compared with ``State.__eq__``.

    Returns:
        The goal ``State`` node (follow ``.root`` back to the start), or
        ``None`` if the goal is unreachable. Prints the number of moves and the
        number of distinct states generated, or "无解" when unsolvable.
    """
    if not is_solutable(init_state, target_state):
        print("无解")
        return None
    directions = [(-1, 0), (0, -1), (1, 0), (0, 1)]
    rows, cols = init_state.data.shape
    seen = {init_state}  # every state ever enqueued
    queue = deque([init_state])
    goal = None
    while queue:
        cur = queue.popleft()
        if cur == target_state:
            goal = cur
            break
        x, y = cur.position()
        for dx, dy in directions:
            x_n, y_n = x + dx, y + dy
            if not (0 <= x_n < rows and 0 <= y_n < cols):
                continue  # the blank would leave the board
            # Move the blank to create the child state.
            data = cur.data.copy()
            data[x, y], data[x_n, y_n] = data[x_n, y_n], data[x, y]
            new_state = State(data, cur.depth + 1, x_n, y_n)
            if new_state in seen:
                continue
            new_state.root = cur
            seen.add(new_state)
            queue.append(new_state)
    if goal is None:
        print("无解")
        return None
    # Count edges (moves) on the path, not nodes.
    moves = 0
    node = goal
    while node.root is not None:
        moves += 1
        node = node.root
    print('{} 步之后能到达目标'.format(moves))
    print('生成节点数 {}'.format(len(seen)))
    return goal
class State:
    """One board configuration (node) in the search tree.

    Attributes:
        data: 2-D integer array holding the board (0 = blank).
        depth: number of moves from the start state (g value).
        x, y: row and column of the blank.
        cost_: priority used by A*; 1000 until assigned.
        root: parent ``State`` on the path from the start (``None`` for the start).
    """

    # Former per-instance hash weights, kept as a shared class attribute for
    # backward compatibility. Hashing no longer depends on it, so boards of any
    # size are supported and no array is allocated per state.
    fac = np.array([10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000])

    def __init__(self, data, depth, blank_x, blank_y, cost_=1000):
        self.data = data
        self.depth = depth
        self.x = blank_x
        self.y = blank_y
        self.cost_ = cost_
        self.root = None

    def position(self):
        """Return the blank's (row, col)."""
        return self.x, self.y

    # Equality, hashing and ordering let States be dict/set keys and heap items.
    def __hash__(self):
        # Hash the cell values: consistent with __eq__, independent of dtype
        # and board size.
        return hash(tuple(self.data.ravel().tolist()))

    def __eq__(self, other):
        if not isinstance(other, State):
            return NotImplemented
        return bool(np.array_equal(self.data, other.data))

    def __str__(self):
        return str(self.data)

    def __lt__(self, other):
        return self.cost_ < other.cost_
# Goal board: a random permutation of 0..8 on a 3x3 grid (0 marks the blank).
target = np.random.choice(range(9), (3, 3), replace=False).reshape(3, 3)
x, y = np.nonzero(target == 0)
target_state = State(target, 0, x[0], y[0], cost_=0)
# Look-up table: tile value -> (row, col) of that tile in the goal board (used by h).
n2posi = {tile: cell for cell, tile in np.ndenumerate(target)}
# ---------------------------------------------------------------------------
# Demo: solve one random start board with A* (both cost functions) and with
# breadth-first search, timing each search.
# ---------------------------------------------------------------------------
def _make_init_state(cost_func=None):
    """Return a fresh start ``State`` for the module-level ``init_`` board.

    When ``cost_func`` is given, the state's ``cost_`` is initialised with it.
    """
    blank_rows, blank_cols = np.where(init_ == 0)
    state = State(init_, 0, blank_rows[0], blank_cols[0])
    if cost_func is not None:
        state.cost_ = cost_func(state)
    return state


def _run_timed(search, *args):
    """Call ``search(*args)``, print the elapsed wall-clock time, and return its result."""
    start_time = time.time()
    result = search(*args)
    print("耗时:{}".format(time.time() - start_time))
    return result


# Start board: another random permutation, drawn after the goal board.
init_ = np.random.choice(range(9), (3, 3), replace=False).reshape(3, 3)

# A* with cost function 1: g + Manhattan distance.
cost_func = cost
init_state = _make_init_state(cost_func)
print("init_state")
print(init_state)
print("target")
print(target)
print('----------A star with cost_function 1-----------')
_run_timed(a_star, init_state, target_state, cost_func)

# A* with cost function 2: g + h2.
print('----------A star with cost_function 2-----------')
cost_func = cost2
init_state = _make_init_state(cost_func)
_run_timed(a_star, init_state, target_state, cost_func)

# Uninformed breadth-first search as a baseline; it never reads cost_.
print('----------breadth search-----------')
init_state = _make_init_state()
_run_timed(bf_search, init_state, target_state)