针对问题:路径搜索
问题描述:有A,B,C,...N个节点,每个节点之间已定义能否达到以及路径代价,目标为搜索到最佳路径
八数码问题:在3*3九宫格中有1-8八个数,剩下一个为空格,每次只能移动空格一次,给定初始状态和目标状态,求得最佳移动方法和最短移动距离
A*算法和A算法的区别在于对路径计算代价的公式的要求
A算法:代价=历史代价+未来代价(启发函数),其中历史代价指的是走到当前状态的已知代价,未来代价是指当前节点到目标状态的预估代价
A*算法:要求预估的代价必定大于等于真实的未来代价
引入逆序数概念:在一个排列中,如果一对数的前后位置与大小顺序相反,即前面的数大于后面的数,那么它们就称为一个逆序。一个排列中逆序的总数就称为这个排列的逆序数。
则计算出初始状态和目标状态各自的逆序数后,若两者奇偶性一致,则可以到达,反之,则不能
def judge(number):
total = 0
data = [9, 9, 9, 9, 9, 9, 9, 9, 9]
for i in range(9):
if number[i] != '0':
data[i] = int(number[i])
else:
data[i] = 0
for i in range(9):
for j in range(i):
if data[i] * data[j] != 0:
if data[j] > data[i]:
total = total + 1
return total
def addOpen(node):
if len(open) == 0 or node[2] >= open[-1][2]:
open.append(node)
else:
for i in range(len(open)):
if node[2] < open[i][2]:
open.insert(i, node)
break
def manhattan(src, dst):
total = 0
pos_src = [[0 for x in range(2)] for y in range(9)]
pos_dst = [[0 for x in range(2)] for y in range(9)]
for i in range(9):
x = i // 3
y = i % 3
pos_src[int(dst[i])][0] = x
pos_src[int(dst[i])][1] = y
pos_dst[int(src[i])][0] = x
pos_dst[int(src[i])][1] = y
for i in range(9):
total = total + abs(pos_src[i][0]-pos_dst[i][0]) + abs(pos_src[i][1]-pos_dst[i][1])
return total
def position(src):
flag = src.index('0')
row = int(flag // 3)
col = int(flag % 3)
return [flag, row, col]
def exchange(src, x, y, x2, y2):
flag = x * 3 + y
flag2 = x2 * 3 + y2
tmp1 = src[flag]
tmp2 = src[flag2]
dst = copy.copy(src)
dst[flag] = tmp2
dst[flag2] = tmp1
return dst
5.1.6. 打印节点状态
def prtNum(src):
for x in range(3):
for y in range(3):
print(str(src[x * 3 + y] + ' '), end='')
print()
在表中每个节点所携带的信息:
[此节点描述,父节点代号,此节点代价,此节点代号,是否为目标状态]
def expand(src, side):
global crt
global nodeid
pos = position(src)
x = pos[1]
y = pos[2]
rtResult = []
# 计算节点的历史代价(通过查找其父节点直到初始节点)
depth = 0
nodePrt = open[0][4]
if nodePrt == 0:
depth = 0
else:
while True:
for i in range(len(closed)):
if nodePrt == closed[i][3]:
depth = depth + 1
nodePrt = closed[i][4]
if nodePrt == 0:
break
# 向左扩展
if side == 'left' or side == '':
if y > 0:
if_final = 0
crtLeft = exchange(src, x, y, x, y - 1)
leftCost = manhattan(numberFinal, crtLeft) + depth
nodeid = nodeid + 1
if manhattan(numberFinal, crtLeft) == 0:
if_final = 1
rtResult.append([crtLeft, src, leftCost, nodeid, if_final])
# 向右扩展
if side == 'right' or side == '':
if y < 2:
if_final = 0
crtRight = exchange(src, x, y, x, y + 1)
rightCost = manhattan(numberFinal, crtRight) + depth
nodeid = nodeid + 1
if manhattan(numberFinal, crtRight) == 0:
if_final = 1
rtResult.append([crtRight, src, rightCost, nodeid, if_final])
# 向上扩展
if side == 'up' or side == '':
if x > 0:
if_final = 0
crtUp = exchange(src, x, y, x - 1, y)
upCost = manhattan(numberFinal, crtUp) + depth
nodeid = nodeid + 1
if manhattan(numberFinal, crtUp) == 0:
if_final = 1
rtResult.append([crtUp, src, upCost, nodeid,if_final])
# 向下扩展
if side == 'down' or side == '':
if x < 2:
if_final = 0
crtDown = exchange(src, x, y, x + 1, y)
depth = depth + 1
downCost = manhattan(numberFinal, crtDown) + depth
nodeid = nodeid + 1
if manhattan(numberFinal, crtDown) == 0:
if_final = 1
rtResult.append([crtDown, src, downCost, nodeid, if_final])
return rtResult
def handleOpen():
global nodeid # 记录扩展节点数目
global open
while True:
if len(open) == 0:
break
x = 0
tmpOpen = open[0]
tmp = expand(open[0][0], '')
for y in range(len(tmp)):
flag = False
for j in range(len(open)):
if tmp[y][0] == open[j][0]:
flag = True
for k in range(len(closed)):
if tmp[y][0] == closed[k][0]:
flag = True
if not flag:
addOpen([tmp[y][0], tmp[y][1], tmp[y][2], tmp[y][3], open[x][3]])
if tmp[y][4] == 1: # 判断是否到达最终节点
closed.append(tmpOpen)
closed.append(open[0])
open.remove(open[0])
print('Totally', nodeid, 'nodes ayalyzed,find the result.')
prtResult()
print('Success!')
exit("We find it!")
closed.append(tmpOpen)
open.remove(tmpOpen)
# 从close表最后一条开始,查找其前一个节点,直到前一节点为0,并将所有查到的序列写入step,打印出step
def prtResult():
step = [closed[-1]]
nodePrt = closed[-1][4]
while True:
for x in range(len(closed)):
if nodePrt == closed[x][3]:
step.insert(0, closed[x])
nodePrt = closed[x][4]
if nodePrt == 0:
break
for x in range(len(step)):
print('Step', x, ':')
prtNum(step[x][0])
print('Finished!')
time.sleep(10)
if __name__ == '__main__':
# 初始化
open = []
closed = []
nodeid = 1
# 输入初始和目标序列,并打印出来供确认
while True:
print('Please input Original state:', end='\t')
tmp = input()
numberOrig = [tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7], tmp[8]]
print('Please input Final state:', end='\t')
tmp = input()
numberFinal = [tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7], tmp[8]]
print('Orig is')
prtNum(numberOrig)
print('Final is')
prtNum(numberFinal)
if (judge(numberOrig) + judge(numberFinal)) % 2 == 0:
print('Have answer! Orig is ', judge(numberOrig), ', Final is', judge(numberFinal))
# 提取目标节点的数码位置
# 将初始节点加入open表,开始处理。
open.append([numberOrig, 'NULL', manhattan(numberOrig, numberFinal), 1, 0, 0])
handleOpen()
else:
print('No answer! Orig is ', judge(numberOrig), ', Final is', judge(numberFinal))
Please input Original state: 012345678
Please input Final state: 123450678
Orig is
0 1 2
3 4 5
6 7 8
Final is
1 2 3
4 5 0
6 7 8
Have answer! Orig is 0 , Final is 0
Totally 474 nodes ayalyzed,find the result.
Step 0 :
0 1 2
3 4 5
6 7 8
Step 1 :
1 0 2
3 4 5
6 7 8
Step 2 :
1 2 0
3 4 5
6 7 8
Step 3 :
1 2 5
3 4 0
6 7 8
Step 4 :
1 2 5
3 0 4
6 7 8
Step 5 :
1 2 5
0 3 4
6 7 8
Step 6 :
0 2 5
1 3 4
6 7 8
Step 7 :
2 0 5
1 3 4
6 7 8
Step 8 :
2 5 0
1 3 4
6 7 8
Step 9 :
2 3 5
1 0 4
6 7 8
Step 10 :
2 3 5
1 4 0
6 7 8
Step 11 :
2 3 0
1 4 5
6 7 8
Step 12 :
2 0 3
1 4 5
6 7 8
Step 13 :
0 2 3
1 4 5
6 7 8
Step 14 :
1 2 3
0 4 5
6 7 8
Step 15 :
1 2 3
4 0 5
6 7 8
Step 16 :
1 2 3
4 5 0
6 7 8
Finished!
Success!
We find it!