A-Star算法探索和实现(一)

什么是A-Star算法?

A*搜寻算法俗称A星算法,又叫A-Star算法。A*算法是比较流行的启发式搜索算法之一,被广泛应用于路径优化领域(引用自百度百科)。在游戏开发中,我们可以将A*算法作为一种敌人的寻路算法,在伊庭齐志所著的《AI游戏开发和深度学习进阶》中有关于A*算法在各种游戏中的运用。如果对A-Star算法不是很了解,推荐浏览这两篇博文,一篇是英文原著,一篇是针对该英文原著的汉译版本。

英文原著 英文原著汉译版本

A-Star算法的思路

A*算法要求寻找到从起始点到终点的最佳路径,从起始点开始,首先以起始点作为当前点,开始收集当前点下一步可前往的点,保存为一个点集并对这些点的权值进行计算,然后以权值最小的点作为下一步要前往的点,将该点设为当前点再重复上述过程,直至到达终点。

示例代码UML图

A-Star算法探索和实现(一)_第1张图片

代码示例(Python)

Grid.py

import matplotlib.pyplot as pp
import numpy as np
from matplotlib.pyplot import MultipleLocator
from myCodes import SharedData as sd
from myCodes import Node as node


# 视图的构建和绘制
class Grid:
    # 横纵坐标轴的刻度标记
    __xLabels = []
    __yLabels = []
    # 网格单元格的信息
    __data = None
    # 是否完成初始化
    __isInit = False

    def __init__(self, p_rowsCount, p_colsCount):
        if self.__isInit is False:
            # 网格的行列数
            self.__rowsCount = p_rowsCount
            self.__colsCount = p_colsCount
            # Node工厂
            self.__nodeFactory = node.NodeFactory()
            # 将网格的行列数设置为共享信息
            sd.SharedData.rowsCount = self.__rowsCount
            sd.SharedData.colsCount = self.__colsCount
            # 将终点Node设置为共享信息
            v_endNode = self.__nodeFactory.GetEndNode(self.__rowsCount, self.__colsCount)
            sd.SharedData.endNode = v_endNode
            self.__isInit = True

    # 绘制图形
    def Draw(self):
        # 获取坐标轴实例
        v_fig, v_ax = pp.subplots()
        self._AxisSet(v_ax)
        self._GridValueSet(v_ax)
        # 将数据以二维图片的形式进行显示
        v_ax.imshow(self.__data, cmap='Accent', vmin=0, vmax=255)
        # 将网格和刻度线显示在多数artists上方
        v_ax.set_axisbelow(False)
        # 标记起始点和终点
        self._StartPointSet(self.__nodeFactory.GetStartNode())
        self._PathNodeSet()
        self._EndPointSet(sd.SharedData.endNode)
        # 布置网格线
        pp.grid(visible=True, color='w')
        # 采用紧凑型布局
        v_fig.tight_layout()
        pp.show()

    # 标记起始点
    def _StartPointSet(self, p_node: node.Node):
        self._Mark(p_node, 30, 'yo')

    # 标记终点
    def _EndPointSet(self, p_node: node.Node):
        self._Mark(p_node, 30, 'ro')

    # 标记所查找到的路径
    def _PathNodeSet(self):
        v_pathNodes = sd.SharedData.pathNodes
        v_length = len(v_pathNodes)
        for i in range(v_length):
            if 0 < i < v_length - 1:
                self._Mark(v_pathNodes[i], 10, 'bo')
                self._Arrow(v_pathNodes[i - 1], v_pathNodes[i])
            elif i == v_length - 1:
                self._Arrow(v_pathNodes[i - 1], v_pathNodes[i])

    # 标记方法
    def _Mark(self, p_node: node.Node, p_marksize: int, p_fmt: str):
        v_x = p_node.nodePos.x
        v_y = p_node.nodePos.y
        pp.plot(v_x, v_y, p_fmt, markersize=p_marksize, zorder=1)

    # 箭头指向方法
    def _Arrow(self, p_firstNode, p_secondNode):
        v_dx = p_secondNode.nodePos.x - p_firstNode.nodePos.x
        v_dy = p_secondNode.nodePos.y - p_firstNode.nodePos.y
        pp.arrow(p_firstNode.nodePos.x, p_firstNode.nodePos.y, v_dx, v_dy, color='orange', width=0.01, head_width=0.08,
                 zorder=3)

    # 坐标轴设置
    def _AxisSet(self, p_ax):
        v_ax = p_ax
        for i in range(1, self.__rowsCount + 1):
            self.__xLabels.append(str(i))
        for i in range(1, self.__colsCount + 1):
            self.__yLabels.append(str(i))
        # 设置横坐标轴为bottom,纵坐标轴为left
        v_ax.xaxis.set_ticks_position('bottom')
        v_ax.yaxis.set_ticks_position('left')
        # 隐藏边框
        v_ax.spines['bottom'].set(visible=False)
        v_ax.spines['right'].set(visible=False)
        v_ax.spines['top'].set(visible=False)
        v_ax.spines['left'].set(visible=False)
        # 隐藏刻度线
        v_ax.tick_params(left=False, bottom=False, top=False, right=False)
        # 坐标轴的刻度间隔实例
        v_xMinorLocator = MultipleLocator(1)
        v_yMinorLocator = MultipleLocator(1)
        v_low = 1
        if self.__rowsCount > self.__colsCount:
            v_high = self.__rowsCount
        else:
            v_high = self.__colsCount
        self.__data = np.random.randint(v_low, v_high, size=(self.__rowsCount + 1, self.__colsCount + 1))
        # 设置横纵坐标轴的范围
        pp.xlim(1, self.__rowsCount)
        pp.ylim(1, self.__colsCount)
        # 设置坐标轴的刻度标记
        v_ax.set_xticks(np.arange(self.__rowsCount), labels=self.__xLabels, visible=False)
        v_ax.set_yticks(np.arange(self.__colsCount), labels=self.__yLabels, visible=False)
        # 设置坐标轴的次要刻度间隔
        v_ax.xaxis.set_minor_locator(v_xMinorLocator)
        v_ax.yaxis.set_minor_locator(v_yMinorLocator)
        # 设置坐标轴的横纵轴比例相等
        v_ax.set_aspect('equal')

    # 网格内容设置
    def _GridValueSet(self, p_ax):
        v_ax = p_ax
        for i in range(self.__rowsCount + 1):
            for j in range(self.__colsCount + 1):
                v_str = '(' + str(i + 0.5) + ',' + str(j + 0.5) + ')'
                v_ax.text(i + 0.5, j + 0.5, v_str, ha='center', va='center', color='w')

Node.py

import random


# Node坐标类
class NodePosition:
    def __init__(self, p_x: float = 1, p_y: float = 1):
        self.x = p_x
        self.y = p_y

    def __str__(self):
        return '[' + str(self.x) + ',' + str(self.y) + ']'


# Node类
class Node:
    def __init__(self, p_nodeName: str, p_nodePos: NodePosition):
        self.nodeName = p_nodeName
        self.nodePos = p_nodePos
        # f=g+h,g代表从上一个点到该点的代价和,h代表从该点到终点的代价和,f代表总权值weight
        self.f: int = 0
        self.g: int = 0
        self.h: int = 0

    # 设置Node的权值
    def SetWeight(self, p_f: int, p_g: int, p_h: int):
        self.f = p_f
        self.g = p_g
        self.h = p_h

    def __str__(self):
        return '{nodeName:' + self.nodeName + ',nodePos:' + str(self.nodePos) + ',f=' + str(self.f) + ',g=' + str(
            self.g) + ',h=' + str(self.h) + '}'


# Node工厂,用来创建和获取起始点Node和终点Node
class NodeFactory:
    __isCreateEndNode = False
    __isCreateStartNode = False
    __startNode: Node
    __endNode: Node

    def __init__(self):
        pass

    # 获取起始点Node
    def GetStartNode(self):
        if self.__isCreateStartNode is False:
            self._CreateStartNode()
        return self.__startNode

    def _CreateStartNode(self):
        v_nodePos = NodePosition(0.5, 0.5)
        self.__startNode = Node('StartNode', v_nodePos)
        self.__isCreateStartNode = True

    # 获取终点Node
    def GetEndNode(self, p_rowsCount: int, p_colsCount: int):
        if self.__isCreateEndNode is False:
            self._CreateEndNode(p_rowsCount, p_colsCount)
        return self.__endNode

    def _CreateEndNode(self, p_rowsCount, p_colsCount):
        v_x = 0.5
        v_y = 0.5
        while v_x == 0.5 and v_y == 0.5:
            v_i = random.randint(0, p_rowsCount - 1)
            v_x = v_i + 0.5
            v_i = random.randint(0, p_colsCount - 1)
            v_y = v_i + 0.5
        v_nodePos = NodePosition(v_x, v_y)
        self.__endNode = Node('EndNode', v_nodePos)
        self.__isCreateEndNode = True

SearchPath.py

from myCodes import SharedData as sd
from myCodes import Node as node


# 查找最佳路径
class SearchPath:
    # 对角线移动一格的代价
    __diagonalCost = 14
    # 上下或左右移动一格的代价
    __nonDiagonalCost = 10
    __nodeFactory: node.NodeFactory
    __currentNode: node.Node
    # 是否完成了初始化
    __isInit = False
    __openList = []
    __closeList = []
    # 网格行列数
    __rowsCount = 0
    __colsCount = 0
    # Node名称索引
    __nameIndex = 1
    # x和y的最小值
    __minX = 0.5
    __minY = 0.5
    # x和y的最大值
    __maxX = 0
    __maxY = 0
    # 终点Node
    __endNode: node.Node

    def __init__(self):
        if self.__isInit is False:
            self.__nodeFactory = node.NodeFactory()
            self.__currentNode = self.__nodeFactory.GetStartNode()
            self.__openList.append(self.__currentNode)
            self.__rowsCount = sd.SharedData.rowsCount
            self.__colsCount = sd.SharedData.colsCount
            self.__maxX = self.__rowsCount - 0.5
            self.__maxY = self.__colsCount - 0.5
            self.__endNode = sd.SharedData.endNode
            self.__isInit = True

    # 查找最佳路径
    def Search(self):
        while self._UpdateCurrentNode():
            v_list = self._GetNextNodes()
            v_list = self._NodeCheck(v_list)
            v_list = self._CalculateWeight(v_list)
            v_list = self._SortNode(v_list)
            self._AddNode(v_list)
        sd.SharedData.pathNodes = self.__closeList
        self._PrintPath()

    # 获取currentNode下一步可以前往的Node,并将它们保存在一个临时列表v_list中
    def _GetNextNodes(self):
        if self.__currentNode is not None:
            v_p = self.__currentNode.nodePos
            v_posList = [(v_p.x - 1, v_p.y), (v_p.x - 1, v_p.y + 1), (v_p.x, v_p.y + 1), (v_p.x + 1, v_p.y + 1),
                         (v_p.x + 1, v_p.y), (v_p.x + 1, v_p.y - 1), (v_p.x, v_p.y - 1), (v_p.x - 1, v_p.y - 1)]
            v_nameList = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
            v_list = []
            for i in range(len(v_posList)):
                v_nodeName = v_nameList[i] + str(self.__nameIndex)
                v_nodePos = node.NodePosition(v_posList[i][0], v_posList[i][1])
                v_node = node.Node(v_nodeName, v_nodePos)
                if self._isInOpenList(v_node) is False and self._isInCloseList(v_node) is False:
                    v_list.append(v_node)
            self.__nameIndex += 1
            return v_list
        return None

    # 检查临时列表p_list中哪些Node不符合要求,保留符合要求的节点,删除不符合要求的节点
    def _NodeCheck(self, p_list):
        if p_list is not None:
            v_list = []
            for i in range(len(p_list)):
                v_x = p_list[i].nodePos.x
                v_y = p_list[i].nodePos.y
                if self.__minX <= v_x <= self.__maxX and self.__minY <= v_y <= self.__maxY:
                    v_list.append(p_list[i])
            return v_list
        return None

    # 计算临时列表p_list中每个Node的权值
    def _CalculateWeight(self, p_list):
        if p_list is not None:
            v_list = p_list
            v_startNodeX = self.__currentNode.nodePos.x
            v_startNodeY = self.__currentNode.nodePos.y
            v_endNodeX = self.__endNode.nodePos.x
            v_endNodeY = self.__endNode.nodePos.y
            for n in v_list:
                v_x = n.nodePos.x
                v_y = n.nodePos.y
                if v_x == v_y:
                    v_g = abs((v_x - v_startNodeX)) * self.__diagonalCost
                    if v_endNodeX == v_endNodeY:
                        v_h = abs((v_endNodeX - v_x)) * self.__diagonalCost
                    else:
                        v_h = abs((v_endNodeX - v_x)) * self.__nonDiagonalCost + abs(
                            (v_endNodeY - v_y)) * self.__nonDiagonalCost
                else:
                    v_g = abs((v_x - v_startNodeX)) * self.__nonDiagonalCost + abs(
                        (v_y - v_startNodeY)) * self.__nonDiagonalCost
                    v_h = abs((v_endNodeX - v_x)) * self.__nonDiagonalCost + abs(
                        (v_endNodeY - v_y)) * self.__nonDiagonalCost
                v_f = v_g + v_h
                n.SetWeight(v_f, v_g, v_h)
            return v_list
        return None

    # 根据临时列表p_list中每个Node的权值进行排序,权值越小越接近列表尾
    def _SortNode(self, p_list):
        if p_list is not None:
            v_list = p_list
            for i in range(0, len(v_list)):
                for j in range(i + 1, len(v_list)):
                    if v_list[i].f > v_list[j].f:
                        v_node = v_list[i]
                        v_list[i] = v_list[j]
                        v_list[j] = v_node
            return v_list
        return None

    # 将临时列表p_list拼接在openList的列表尾
    def _AddNode(self, p_list):
        if p_list is not None:
            v_list = []
            for i in range(len(p_list)):
                v_n = p_list[i]
                if self._isInOpenList(v_n) is False and self._isInCloseList(v_n) is False:
                    v_list.append(v_n)
            self.__openList.append(v_list[0])
            return True
        return False

    # 判断p_node是否为终点Node
    def _ReachEndNodeCheck(self, p_node: node.Node):
        v_x = p_node.nodePos.x
        v_y = p_node.nodePos.y
        if v_x == self.__endNode.nodePos.x and v_y == self.__endNode.nodePos.y:
            return True
        return False

    # 判断p_node是否在openList中
    def _isInOpenList(self, p_node: node.Node):
        v_x = p_node.nodePos.x
        v_y = p_node.nodePos.y
        for n in self.__openList:
            if v_x == n.nodePos.x and v_y == n.nodePos.y:
                return True
        return False

    # 判断p_node是否在closeList中
    def _isInCloseList(self, p_node: node.Node):
        v_x = p_node.nodePos.x
        v_y = p_node.nodePos.y
        for n in self.__closeList:
            if v_x == n.nodePos.x and v_y == n.nodePos.y:
                return True
        return False

    # 打印最佳路径
    def _PrintPath(self):
        v_str = ''
        v_length = len(self.__closeList)
        for i in range(v_length):
            if i < v_length - 1:
                v_str += str(self.__closeList[i]) + '-->'
            else:
                v_str += str(self.__closeList[i])
        print(v_str)

    # 从openList中获取列表尾的元素,将之作为currentNode并加入closeList中,然后将其从openList中移除
    def _UpdateCurrentNode(self):
        if len(self.__openList) > 0:
            self.__currentNode = self.__openList[0]
            if self._isInCloseList(self.__currentNode) is False:
                self.__closeList.append(self.__currentNode)
                del self.__openList[0]
                if self._ReachEndNodeCheck(self.__currentNode):
                    return False
                return True
        return False

SharedData.py

from myCodes import Node as node


# 共享信息类
class SharedData:
    # 网格行列数
    rowsCount = 0
    colsCount = 0
    # 终点Node
    endNode: node.Node
    # 所查找的路径Node集
    pathNodes = []

Main.py

from myCodes import SearchPath as sp
from myCodes import Grid as grid

g = grid.Grid(5, 5)
sp.SearchPath().Search()
g.Draw()

运行截图

A-Star算法探索和实现(一)_第2张图片
A-Star算法探索和实现(一)_第3张图片

代码解说

在这个示例中,我们分为两个层,一个是视图层,一个是数据层,视图层包括Grid类,数据层包括Node类、NodePosition类、NodeFactory类、SharedData类、SearchPath类。Grid类负责对A_Star算法进行可视化展现,包括基础的网格、文本和标记的构建和显示;Node类作为节点类,对应着视图层的每一个网格单元;NodePosition类用于记录节点的坐标信息;NodeFactory类用于创建和获取起始点Node和终点Node;SharedData类用于全局信息的共享;SearchPath类是路径查找类,该类中主要编写了A_Star算法实现的逻辑,包括获取下一步可前往的Node集、对Node集中的Node进行筛选、计算Node的权值、对Node集进行排序、将Node集添加至OpenList中以及从OpenList中读取Node等。主要的调用顺序:1.在Main方法中创建Grid的实例,并且传递网格的行列数;2.执行SearchPath中的Search方法查找最佳路径;3.调用Grid中的Draw方法对最佳路径进行显示。

下一篇将讲解添加障碍物后的A-Star算法探索和实现

如果这篇文章对你有帮助,请给作者点个赞吧!

你可能感兴趣的:(算法探索,算法)