A* 路径搜索算法介绍及完整代码

1.简介

A* (A-Star) 算法一种静态路网中求解最短路径最有效的方法之一, 是一种常用的启发式算法.

启发式算法:通过启发函数(heruistic)计算出键值, 引导算法的搜索方向.

2. 算法描述

Ray Wenderlich - Introduction to A* Pathfinding

此文非常好的介绍了A*算法的逻辑及其中的关键点, 而且英文通俗易懂, 因此本文并不会详细的翻译此篇博文.

简单说来, A*算法是从起点开始, 通过对相邻可到达点的键值计算以及和其它可选点的键值对比, 从中选出最小键值的点, 然后继续计算此点的相邻点的键值, 直到抵达终点. 想象如同一个石子落入水中, 水的波纹外扩散, 一圈一圈的波纹可看作键值相同的点. A*算法并不需要遍历所有的点, 相反, 如果启发函数即当前点距离估计与实际值越接近, A*算法即会沿最短路径向终点移动, 避免遍历键值高的点.

因此算法关键在于利用OPEN和CLOSE列表记录访问点的情况以及如何计算键值

2.1 OPEN和CLOSE列表

OPEN列表: 所有正在被考虑的备选点列表, 此表并不是所有的点的列表, 而是A*算法在计算相邻点时, 可到达的点.

CLOSE列表: 所有访问过的点列表.

当从所有相邻可到达点中选中最小键值的点, 此点则放入CLOSE列表, 其余的点则放入OPEN列表.

2.2 键值计算 F = G+H

G值: 代表从起点到当前点的代价

H值: 代表从当前点到终点的代价, 由于从当前点出发, 并不知道最终会如何抵达终点, 所以此值为启发函数估计的值.

需要注意的是, 如果当计算相邻点(相邻点已在OPEN中)的键值F小于已存储的键值, 代表有更好的路径到达相邻点, 那么需要更新相邻点的键值以及G值.

2.3 H (heuristic)

百度百科关于h(n)的选择

以h(n)表达状态n到目标状态估计的距离, h*(n) 代表实际距离,那么h(n)的选取大致有如下三种情况:

  1. 如果h(n)< h*(n),这种情况,搜索的点数多,搜索范围大,效率低。但能得到最优解。

  2. 如果h(n)=h*(n),此时的搜索效率是最高的。

  3. 如果 h(n)>h*(n),搜索的点数少,搜索范围小,效率高,但不能保证得到最优解。

通常的启发式函数可以有曼哈顿距离, 对角线距离, 欧几里得距离等等, 或者根据实际情况建立启发式方程.

堪称最好最全的A*算法详解(译文) 此博文对于启发式函数的选择有更深入的讨论.

曼哈顿距离

h(n) = D*(abs(a.x - b.x ) + abs(a.y - b.y ))

对角线距离

h(n) = D*max(abs(a.x - b.x), abs(a.y - b.y))

欧几里得距离

h(n) = D*sqrt((a.x - b.x)^2 + (a.y - b.y)^2)

注意启发式函数选择和如何抵达相邻的点也有关系, 比如如果只能上下左右移动到达相邻的点, 那么对角线距离的选择就是不合适的.

2.4 算法循环退出条件:

1) 找到一条从起点到终点的最短路径

2) 遍历所有点无法找到路径 (OPEN列表最终为空)

3. 举例

寻找下图中从(7, 3) 到(7, 13) 的最短路径

A* 路径搜索算法介绍及完整代码_第1张图片

4. Python代码实现

import numpy as np
import heapq
import matplotlib.pyplot as plt
import time


# def heuristic(a, b):
#     return np.sqrt((b[0] - a[0]) ** 2 + (b[1] - a[1]) ** 2)

def heuristic(a, b):
    return abs(a[0]-b[0]) + abs(a[1]-b[1])


def astar(array, start, goal):
    close_set = set()
    parent_from_dict = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    open_list = []
    heapq.heappush(open_list, (fscore[start], start))
    while open_list:
        current = heapq.heappop(open_list)[1]  # return the smallest value
        if current == goal:
            route_loc = []
            while current in parent_from_dict:
                route_loc.append(current)
                current = parent_from_dict[current]
            # append start point to route and reorder route backward to get correct node sequence
            route_loc = route_loc + [start]
            route_loc = route_loc[::-1]
            return route_loc, close_set, open_list

        close_set.add(current)

        for i, j in neighbor_direction:
            neighbor = current[0] + i, current[1] + j
            tentative_g_score = gscore[current] + heuristic(current, neighbor)
            if 0 <= neighbor[0] < array.shape[0]:
                if 0 <= neighbor[1] < array.shape[1]:
                    if array[neighbor[0]][neighbor[1]] == 1:
                        # print('neighbor %s hit wall' % str(neighbor))
                        continue
                else:
                    # array bound y walls
                    # print('neighbor %s hit y walls' % str(neighbor))
                    continue
            else:
                # array bound x walls
                # print('neighbor %s hit x walls' % str(neighbor))
                continue

            if neighbor in close_set and tentative_g_score >= gscore.get(neighbor, 0):
                continue

            if tentative_g_score < gscore.get(neighbor, 0) or neighbor not in [i[1] for i in open_list]:
                parent_from_dict[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heapq.heappush(open_list, (fscore[neighbor], neighbor))
    return None, close_set, open_list


def plot(layout, path=None, close_set=None, open_set=None):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(layout, cmap=plt.cm.Dark2)

    ax.scatter(start[1], start[0], marker="o", color="red", s=200)
    ax.scatter(goal[1], goal[0], marker="*", color="green", s=200)

    if path:
        # extract x and y coordinates from route list
        x_coords = []
        y_coords = []
        for k in (range(0, len(path))):
            x = path[k][0]
            y = path[k][1]
            x_coords.append(x)
            y_coords.append(y)

        ax.plot(y_coords, x_coords, color="black")

    if close_set:
        for c in close_set:
            ax.scatter(c[1], c[0], marker='o', color='green')

    if open_set:
        for p in open_set:
            loc = p[1]
            ax.scatter(loc[1], loc[0], marker='x', color='blue')

    ax.xaxis.set_ticks(np.arange(0, layout.shape[1], 1))
    ax.yaxis.set_ticks(np.arange(0, layout.shape[0], 1))
    ax.xaxis.tick_top()
    plt.grid()
    plt.show()


if __name__ == '__main__':
    # neighbor_direction = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    neighbor_direction = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    S = np.array([
        [1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1],
        [0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0],
        [1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0],
        [0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1],
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0],
        [0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1],
        [0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1],
        [1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0],
        [1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
        [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
        [1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0],
    ])

    start = (7, 3)
    goal = (7, 17)

    # S[(7, 12)] = 1
    # S[(7, 15)] = 1

    t_start = time.perf_counter()
    route, c_set, open_l = astar(S, start, goal)
    t_end = time.perf_counter()
    print(f"astar finished in {t_end - t_start:0.4f} seconds")
    if route:
        print(route)
        plot(S, route, c_set, open_l)
    else:
        plot(S, None, c_set, open_l)

A* 路径搜索算法介绍及完整代码_第2张图片

 从上图可以发现, 算法仅仅计算了图中非常少的一些点即找到了终点(绿色"o"代表CLOSE里的点, 蓝色"x"代表OPEN里的点)

A* 路径搜索算法介绍及完整代码_第3张图片

上图为当(7,12)有障碍后的最短路径, 这时可已看出A*算法计算的点大大超过无障碍的情况, 但算法始终有效的没有遍历所有的点.

A* 路径搜索算法介绍及完整代码_第4张图片

上图为当A*算法无法找路径的情况, 可以看出算法遍历了所有可能经过的点(绿色"o"代表CLOSE里的点).

4. 算法总结

A*算法能够快速的寻找到从起点到终点的路径, 而且通过不同的启发式函数, 能控制决定是否一定要得到最短路径或者一个次优路径, 当计算时间非常长的情况下.

你可能感兴趣的:(最短路径,python,算法,启发式算法)