Python实验--手写A*算法实现八数码问题

1. 问题描述

针对问题:路径搜索

问题描述:有A,B,C,...N个节点,每个节点之间已定义能否达到以及路径代价,目标为搜索到最佳路径

八数码问题:在3*3九宫格中有1-8八个数,剩下一个为空格,每次只能移动空格一次,给定初始状态和目标状态,求得最佳移动方法和最短移动距离

2. A算法原理

  1.  首先将初始节点放入open表
  2. 将初始节点放入closed表,并从初始节点向各个方向扩展节点,将新扩展节点放入open表
  3. 根据新扩展节点所对应的路径计算各自的代价,并以此进行排序
  4. 从open表中选取代价最小的节点作为当前节点放入closed表
  5. 重复2,3.4步,直到寻找到目标节点,返回最佳路径

3. A*算法要求

A*算法和A算法的区别在于对路径计算代价的公式的要求

A算法:代价=历史代价+未来代价(启发函数),其中历史代价指的是走到当前状态的已知代价,未来代价是指当前节点到目标状态的预估代价

A*算法:要求预估的代价必定大于等于真实的未来代价

4. 如何判断给定状态之间是否存在可到达路径

引入逆序数概念:在一个排列中,如果一对数的前后位置与大小顺序相反,即前面的数大于后面的数,那么它们就称为一个逆序。一个排列中逆序的总数就称为这个排列的逆序数

则计算出初始状态和目标状态各自的逆序数后,若两者奇偶性一致,则可以到达,反之,则不能

5. 具体代码实现

5.1. 子函数定义

5.1.1. 逆序数计算

def judge(number):
    total = 0
    data = [9, 9, 9, 9, 9, 9, 9, 9, 9]
    for i in range(9):
        if number[i] != '0':
            data[i] = int(number[i])
        else:
            data[i] = 0
    for i in range(9):
        for j in range(i):
            if data[i] * data[j] != 0:
                if data[j] > data[i]:
                    total = total + 1
    return total

5.1.2. 基于代价调整节点位置

def addOpen(node):
        if len(open) == 0 or node[2] >= open[-1][2]:
            open.append(node)
        else:
            for i in range(len(open)):
                if node[2] < open[i][2]:
                    open.insert(i, node)
                    break

5.1.3. 启发函数(曼哈顿距离)

def manhattan(src, dst):
    total = 0
    pos_src = [[0 for x in range(2)] for y in range(9)]
    pos_dst = [[0 for x in range(2)] for y in range(9)]
    for i in range(9):
        x = i // 3
        y = i % 3
        pos_src[int(dst[i])][0] = x
        pos_src[int(dst[i])][1] = y
        pos_dst[int(src[i])][0] = x
        pos_dst[int(src[i])][1] = y
    for i in range(9):
        total = total + abs(pos_src[i][0]-pos_dst[i][0]) + abs(pos_src[i][1]-pos_dst[i][1])

    return total

 5.1.4. 返回序列中0的位置

def position(src):
    flag = src.index('0')
    row = int(flag // 3)
    col = int(flag % 3)
    return [flag, row, col]

5.1.5. 交换位置

def exchange(src, x, y, x2, y2):
    flag = x * 3 + y
    flag2 = x2 * 3 + y2
    tmp1 = src[flag]
    tmp2 = src[flag2]
    dst = copy.copy(src)
    dst[flag] = tmp2
    dst[flag2] = tmp1
    return dst

5.1.6. 打印节点状态

def prtNum(src):
    for x in range(3):
        for y in range(3):
            print(str(src[x * 3 + y] + ' '), end='')
        print()

5.2. 扩展节点函数

在表中每个节点所携带的信息:

[此节点描述,父节点代号,此节点代价,此节点代号,是否为目标状态]

def expand(src, side):
    global crt
    global nodeid
    pos = position(src)
    x = pos[1]
    y = pos[2]
    rtResult = []

    # 计算节点的历史代价(通过查找其父节点直到初始节点)
    depth = 0
    nodePrt = open[0][4]
    if nodePrt == 0:
        depth = 0
    else:
        while True:
            for i in range(len(closed)):
                if nodePrt == closed[i][3]:
                    depth = depth + 1
                    nodePrt = closed[i][4]
            if nodePrt == 0:
                break

    # 向左扩展
    if side == 'left' or side == '':
        if y > 0:
            if_final = 0
            crtLeft = exchange(src, x, y, x, y - 1)
            leftCost = manhattan(numberFinal, crtLeft) + depth
            nodeid = nodeid + 1
            if manhattan(numberFinal, crtLeft) == 0:
                if_final = 1
            rtResult.append([crtLeft, src, leftCost, nodeid, if_final])
    # 向右扩展
    if side == 'right' or side == '':
        if y < 2:
            if_final = 0
            crtRight = exchange(src, x, y, x, y + 1)
            rightCost = manhattan(numberFinal, crtRight) + depth
            nodeid = nodeid + 1
            if manhattan(numberFinal, crtRight) == 0:
                if_final = 1
            rtResult.append([crtRight, src, rightCost, nodeid, if_final])
    # 向上扩展
    if side == 'up' or side == '':
        if x > 0:
            if_final = 0
            crtUp = exchange(src, x, y, x - 1, y)
            upCost = manhattan(numberFinal, crtUp) + depth
            nodeid = nodeid + 1
            if manhattan(numberFinal, crtUp) == 0:
                if_final = 1
            rtResult.append([crtUp, src, upCost, nodeid,if_final])
    # 向下扩展
    if side == 'down' or side == '':
        if x < 2:
            if_final = 0
            crtDown = exchange(src, x, y, x + 1, y)
            depth = depth + 1
            downCost = manhattan(numberFinal, crtDown) + depth
            nodeid = nodeid + 1
            if manhattan(numberFinal, crtDown) == 0:
                if_final = 1
            rtResult.append([crtDown, src, downCost, nodeid, if_final])
    return rtResult

5.3. open表处理函数

def handleOpen():
    global nodeid # 记录扩展节点数目
    global open
    while True:
        if len(open) == 0:
            break
        x = 0
        tmpOpen = open[0]

        tmp = expand(open[0][0], '')
        for y in range(len(tmp)):
            flag = False
            for j in range(len(open)):
                if tmp[y][0] == open[j][0]:
                    flag = True
            for k in range(len(closed)):
                if tmp[y][0] == closed[k][0]:
                    flag = True
            if not flag:
                addOpen([tmp[y][0], tmp[y][1], tmp[y][2], tmp[y][3], open[x][3]])

            if tmp[y][4] == 1:  # 判断是否到达最终节点
                closed.append(tmpOpen)
                closed.append(open[0])
                open.remove(open[0])
                print('Totally', nodeid, 'nodes ayalyzed,find the result.')
                prtResult()
                print('Success!')
                exit("We find it!")
        closed.append(tmpOpen)
        open.remove(tmpOpen)

5.4. 通过节点代号回溯

# 从close表最后一条开始,查找其前一个节点,直到前一节点为0,并将所有查到的序列写入step,打印出step
def prtResult():
    step = [closed[-1]]
    nodePrt = closed[-1][4]
    while True:
        for x in range(len(closed)):
            if nodePrt == closed[x][3]:
                step.insert(0, closed[x])
                nodePrt = closed[x][4]
        if nodePrt == 0:
            break
    for x in range(len(step)):
        print('Step', x, ':')
        prtNum(step[x][0])
    print('Finished!')
    time.sleep(10)

5.5. 主函数

if __name__ == '__main__':
    # 初始化
    open = []
    closed = []
    nodeid = 1

    # 输入初始和目标序列,并打印出来供确认
    while True:
        print('Please input Original state:', end='\t')
        tmp = input()
        numberOrig = [tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7], tmp[8]]
        print('Please input Final state:', end='\t')
        tmp = input()
        numberFinal = [tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7], tmp[8]]
        print('Orig is')
        prtNum(numberOrig)
        print('Final is')
        prtNum(numberFinal)
        if (judge(numberOrig) + judge(numberFinal)) % 2 == 0:
            print('Have answer! Orig is ', judge(numberOrig), ', Final is', judge(numberFinal))
            # 提取目标节点的数码位置

            # 将初始节点加入open表,开始处理。
            open.append([numberOrig, 'NULL', manhattan(numberOrig, numberFinal), 1, 0, 0])
            handleOpen()
        else:
            print('No answer! Orig is ', judge(numberOrig), ', Final is', judge(numberFinal))

6. 测试结果

Please input Original state:	012345678
Please input Final state:	123450678
Orig is
0 1 2 
3 4 5 
6 7 8 
Final is
1 2 3 
4 5 0 
6 7 8 
Have answer! Orig is  0 , Final is 0
Totally 474 nodes ayalyzed,find the result.
Step 0 :
0 1 2 
3 4 5 
6 7 8 
Step 1 :
1 0 2 
3 4 5 
6 7 8 
Step 2 :
1 2 0 
3 4 5 
6 7 8 
Step 3 :
1 2 5 
3 4 0 
6 7 8 
Step 4 :
1 2 5 
3 0 4 
6 7 8 
Step 5 :
1 2 5 
0 3 4 
6 7 8 
Step 6 :
0 2 5 
1 3 4 
6 7 8 
Step 7 :
2 0 5 
1 3 4 
6 7 8 
Step 8 :
2 5 0 
1 3 4 
6 7 8 
Step 9 :
2 3 5 
1 0 4 
6 7 8 
Step 10 :
2 3 5 
1 4 0 
6 7 8 
Step 11 :
2 3 0 
1 4 5 
6 7 8 
Step 12 :
2 0 3 
1 4 5 
6 7 8 
Step 13 :
0 2 3 
1 4 5 
6 7 8 
Step 14 :
1 2 3 
0 4 5 
6 7 8 
Step 15 :
1 2 3 
4 0 5 
6 7 8 
Step 16 :
1 2 3 
4 5 0 
6 7 8 
Finished!
Success!
We find it!

你可能感兴趣的:(Python实验,python,算法,人工智能)