A*搜寻算法俗称A星算法,又叫A-Star算法。A*算法是比较流行的启发式搜索算法之一,被广泛应用于路径优化领域(引用自百度百科)。在游戏开发中,我们可以将A*算法作为一种敌人的寻路算法,在伊庭齐志所著的《AI游戏开发和深度学习进阶》中有关于A*算法在各种游戏中的运用。如果对A-Star算法不是很了解,推荐浏览这两篇博文,一篇是英文原著,一篇是针对该英文原著的汉译版本。
英文原著 英文原著汉译版本
A*算法要求寻找到从起始点到终点的最佳路径,从起始点开始,首先以起始点作为当前点,开始收集当前点下一步可前往的点,保存为一个点集并对这些点的权值进行计算,然后以权值最小的点作为下一步要前往的点,将该点设为当前点再重复上述过程,直至到达终点。
Grid.py
import matplotlib.pyplot as pp
import numpy as np
from matplotlib.pyplot import MultipleLocator
from myCodes import SharedData as sd
from myCodes import Node as node
# 视图的构建和绘制
class Grid:
# 横纵坐标轴的刻度标记
__xLabels = []
__yLabels = []
# 网格单元格的信息
__data = None
# 是否完成初始化
__isInit = False
def __init__(self, p_rowsCount, p_colsCount):
if self.__isInit is False:
# 网格的行列数
self.__rowsCount = p_rowsCount
self.__colsCount = p_colsCount
# Node工厂
self.__nodeFactory = node.NodeFactory()
# 将网格的行列数设置为共享信息
sd.SharedData.rowsCount = self.__rowsCount
sd.SharedData.colsCount = self.__colsCount
# 将终点Node设置为共享信息
v_endNode = self.__nodeFactory.GetEndNode(self.__rowsCount, self.__colsCount)
sd.SharedData.endNode = v_endNode
self.__isInit = True
# 绘制图形
def Draw(self):
# 获取坐标轴实例
v_fig, v_ax = pp.subplots()
self._AxisSet(v_ax)
self._GridValueSet(v_ax)
# 将数据以二维图片的形式进行显示
v_ax.imshow(self.__data, cmap='Accent', vmin=0, vmax=255)
# 将网格和刻度线显示在多数artists上方
v_ax.set_axisbelow(False)
# 标记起始点和终点
self._StartPointSet(self.__nodeFactory.GetStartNode())
self._PathNodeSet()
self._EndPointSet(sd.SharedData.endNode)
# 布置网格线
pp.grid(visible=True, color='w')
# 采用紧凑型布局
v_fig.tight_layout()
pp.show()
# 标记起始点
def _StartPointSet(self, p_node: node.Node):
self._Mark(p_node, 30, 'yo')
# 标记终点
def _EndPointSet(self, p_node: node.Node):
self._Mark(p_node, 30, 'ro')
# 标记所查找到的路径
def _PathNodeSet(self):
v_pathNodes = sd.SharedData.pathNodes
v_length = len(v_pathNodes)
for i in range(v_length):
if 0 < i < v_length - 1:
self._Mark(v_pathNodes[i], 10, 'bo')
self._Arrow(v_pathNodes[i - 1], v_pathNodes[i])
elif i == v_length - 1:
self._Arrow(v_pathNodes[i - 1], v_pathNodes[i])
# 标记方法
def _Mark(self, p_node: node.Node, p_marksize: int, p_fmt: str):
v_x = p_node.nodePos.x
v_y = p_node.nodePos.y
pp.plot(v_x, v_y, p_fmt, markersize=p_marksize, zorder=1)
# 箭头指向方法
def _Arrow(self, p_firstNode, p_secondNode):
v_dx = p_secondNode.nodePos.x - p_firstNode.nodePos.x
v_dy = p_secondNode.nodePos.y - p_firstNode.nodePos.y
pp.arrow(p_firstNode.nodePos.x, p_firstNode.nodePos.y, v_dx, v_dy, color='orange', width=0.01, head_width=0.08,
zorder=3)
# 坐标轴设置
def _AxisSet(self, p_ax):
v_ax = p_ax
for i in range(1, self.__rowsCount + 1):
self.__xLabels.append(str(i))
for i in range(1, self.__colsCount + 1):
self.__yLabels.append(str(i))
# 设置横坐标轴为bottom,纵坐标轴为left
v_ax.xaxis.set_ticks_position('bottom')
v_ax.yaxis.set_ticks_position('left')
# 隐藏边框
v_ax.spines['bottom'].set(visible=False)
v_ax.spines['right'].set(visible=False)
v_ax.spines['top'].set(visible=False)
v_ax.spines['left'].set(visible=False)
# 隐藏刻度线
v_ax.tick_params(left=False, bottom=False, top=False, right=False)
# 坐标轴的刻度间隔实例
v_xMinorLocator = MultipleLocator(1)
v_yMinorLocator = MultipleLocator(1)
v_low = 1
if self.__rowsCount > self.__colsCount:
v_high = self.__rowsCount
else:
v_high = self.__colsCount
self.__data = np.random.randint(v_low, v_high, size=(self.__rowsCount + 1, self.__colsCount + 1))
# 设置横纵坐标轴的范围
pp.xlim(1, self.__rowsCount)
pp.ylim(1, self.__colsCount)
# 设置坐标轴的刻度标记
v_ax.set_xticks(np.arange(self.__rowsCount), labels=self.__xLabels, visible=False)
v_ax.set_yticks(np.arange(self.__colsCount), labels=self.__yLabels, visible=False)
# 设置坐标轴的次要刻度间隔
v_ax.xaxis.set_minor_locator(v_xMinorLocator)
v_ax.yaxis.set_minor_locator(v_yMinorLocator)
# 设置坐标轴的横纵轴比例相等
v_ax.set_aspect('equal')
# 网格内容设置
def _GridValueSet(self, p_ax):
v_ax = p_ax
for i in range(self.__rowsCount + 1):
for j in range(self.__colsCount + 1):
v_str = '(' + str(i + 0.5) + ',' + str(j + 0.5) + ')'
v_ax.text(i + 0.5, j + 0.5, v_str, ha='center', va='center', color='w')
Node.py
import random
# Node坐标类
class NodePosition:
def __init__(self, p_x: float = 1, p_y: float = 1):
self.x = p_x
self.y = p_y
def __str__(self):
return '[' + str(self.x) + ',' + str(self.y) + ']'
# Node类
class Node:
def __init__(self, p_nodeName: str, p_nodePos: NodePosition):
self.nodeName = p_nodeName
self.nodePos = p_nodePos
# f=g+h,g代表从上一个点到该点的代价和,h代表从该点到终点的代价和,f代表总权值weight
self.f: int = 0
self.g: int = 0
self.h: int = 0
# 设置Node的权值
def SetWeight(self, p_f: int, p_g: int, p_h: int):
self.f = p_f
self.g = p_g
self.h = p_h
def __str__(self):
return '{nodeName:' + self.nodeName + ',nodePos:' + str(self.nodePos) + ',f=' + str(self.f) + ',g=' + str(
self.g) + ',h=' + str(self.h) + '}'
# Node工厂,用来创建和获取起始点Node和终点Node
class NodeFactory:
__isCreateEndNode = False
__isCreateStartNode = False
__startNode: Node
__endNode: Node
def __init__(self):
pass
# 获取起始点Node
def GetStartNode(self):
if self.__isCreateStartNode is False:
self._CreateStartNode()
return self.__startNode
def _CreateStartNode(self):
v_nodePos = NodePosition(0.5, 0.5)
self.__startNode = Node('StartNode', v_nodePos)
self.__isCreateStartNode = True
# 获取终点Node
def GetEndNode(self, p_rowsCount: int, p_colsCount: int):
if self.__isCreateEndNode is False:
self._CreateEndNode(p_rowsCount, p_colsCount)
return self.__endNode
def _CreateEndNode(self, p_rowsCount, p_colsCount):
v_x = 0.5
v_y = 0.5
while v_x == 0.5 and v_y == 0.5:
v_i = random.randint(0, p_rowsCount - 1)
v_x = v_i + 0.5
v_i = random.randint(0, p_colsCount - 1)
v_y = v_i + 0.5
v_nodePos = NodePosition(v_x, v_y)
self.__endNode = Node('EndNode', v_nodePos)
self.__isCreateEndNode = True
SearchPath.py
from myCodes import SharedData as sd
from myCodes import Node as node
# 查找最佳路径
class SearchPath:
# 对角线移动一格的代价
__diagonalCost = 14
# 上下或左右移动一格的代价
__nonDiagonalCost = 10
__nodeFactory: node.NodeFactory
__currentNode: node.Node
# 是否完成了初始化
__isInit = False
__openList = []
__closeList = []
# 网格行列数
__rowsCount = 0
__colsCount = 0
# Node名称索引
__nameIndex = 1
# x和y的最小值
__minX = 0.5
__minY = 0.5
# x和y的最大值
__maxX = 0
__maxY = 0
# 终点Node
__endNode: node.Node
def __init__(self):
if self.__isInit is False:
self.__nodeFactory = node.NodeFactory()
self.__currentNode = self.__nodeFactory.GetStartNode()
self.__openList.append(self.__currentNode)
self.__rowsCount = sd.SharedData.rowsCount
self.__colsCount = sd.SharedData.colsCount
self.__maxX = self.__rowsCount - 0.5
self.__maxY = self.__colsCount - 0.5
self.__endNode = sd.SharedData.endNode
self.__isInit = True
# 查找最佳路径
def Search(self):
while self._UpdateCurrentNode():
v_list = self._GetNextNodes()
v_list = self._NodeCheck(v_list)
v_list = self._CalculateWeight(v_list)
v_list = self._SortNode(v_list)
self._AddNode(v_list)
sd.SharedData.pathNodes = self.__closeList
self._PrintPath()
# 获取currentNode下一步可以前往的Node,并将它们保存在一个临时列表v_list中
def _GetNextNodes(self):
if self.__currentNode is not None:
v_p = self.__currentNode.nodePos
v_posList = [(v_p.x - 1, v_p.y), (v_p.x - 1, v_p.y + 1), (v_p.x, v_p.y + 1), (v_p.x + 1, v_p.y + 1),
(v_p.x + 1, v_p.y), (v_p.x + 1, v_p.y - 1), (v_p.x, v_p.y - 1), (v_p.x - 1, v_p.y - 1)]
v_nameList = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
v_list = []
for i in range(len(v_posList)):
v_nodeName = v_nameList[i] + str(self.__nameIndex)
v_nodePos = node.NodePosition(v_posList[i][0], v_posList[i][1])
v_node = node.Node(v_nodeName, v_nodePos)
if self._isInOpenList(v_node) is False and self._isInCloseList(v_node) is False:
v_list.append(v_node)
self.__nameIndex += 1
return v_list
return None
# 检查临时列表p_list中哪些Node不符合要求,保留符合要求的节点,删除不符合要求的节点
def _NodeCheck(self, p_list):
if p_list is not None:
v_list = []
for i in range(len(p_list)):
v_x = p_list[i].nodePos.x
v_y = p_list[i].nodePos.y
if self.__minX <= v_x <= self.__maxX and self.__minY <= v_y <= self.__maxY:
v_list.append(p_list[i])
return v_list
return None
# 计算临时列表p_list中每个Node的权值
def _CalculateWeight(self, p_list):
if p_list is not None:
v_list = p_list
v_startNodeX = self.__currentNode.nodePos.x
v_startNodeY = self.__currentNode.nodePos.y
v_endNodeX = self.__endNode.nodePos.x
v_endNodeY = self.__endNode.nodePos.y
for n in v_list:
v_x = n.nodePos.x
v_y = n.nodePos.y
if v_x == v_y:
v_g = abs((v_x - v_startNodeX)) * self.__diagonalCost
if v_endNodeX == v_endNodeY:
v_h = abs((v_endNodeX - v_x)) * self.__diagonalCost
else:
v_h = abs((v_endNodeX - v_x)) * self.__nonDiagonalCost + abs(
(v_endNodeY - v_y)) * self.__nonDiagonalCost
else:
v_g = abs((v_x - v_startNodeX)) * self.__nonDiagonalCost + abs(
(v_y - v_startNodeY)) * self.__nonDiagonalCost
v_h = abs((v_endNodeX - v_x)) * self.__nonDiagonalCost + abs(
(v_endNodeY - v_y)) * self.__nonDiagonalCost
v_f = v_g + v_h
n.SetWeight(v_f, v_g, v_h)
return v_list
return None
# 根据临时列表p_list中每个Node的权值进行排序,权值越小越接近列表尾
def _SortNode(self, p_list):
if p_list is not None:
v_list = p_list
for i in range(0, len(v_list)):
for j in range(i + 1, len(v_list)):
if v_list[i].f > v_list[j].f:
v_node = v_list[i]
v_list[i] = v_list[j]
v_list[j] = v_node
return v_list
return None
# 将临时列表p_list拼接在openList的列表尾
def _AddNode(self, p_list):
if p_list is not None:
v_list = []
for i in range(len(p_list)):
v_n = p_list[i]
if self._isInOpenList(v_n) is False and self._isInCloseList(v_n) is False:
v_list.append(v_n)
self.__openList.append(v_list[0])
return True
return False
# 判断p_node是否为终点Node
def _ReachEndNodeCheck(self, p_node: node.Node):
v_x = p_node.nodePos.x
v_y = p_node.nodePos.y
if v_x == self.__endNode.nodePos.x and v_y == self.__endNode.nodePos.y:
return True
return False
# 判断p_node是否在openList中
def _isInOpenList(self, p_node: node.Node):
v_x = p_node.nodePos.x
v_y = p_node.nodePos.y
for n in self.__openList:
if v_x == n.nodePos.x and v_y == n.nodePos.y:
return True
return False
# 判断p_node是否在closeList中
def _isInCloseList(self, p_node: node.Node):
v_x = p_node.nodePos.x
v_y = p_node.nodePos.y
for n in self.__closeList:
if v_x == n.nodePos.x and v_y == n.nodePos.y:
return True
return False
# 打印最佳路径
def _PrintPath(self):
v_str = ''
v_length = len(self.__closeList)
for i in range(v_length):
if i < v_length - 1:
v_str += str(self.__closeList[i]) + '-->'
else:
v_str += str(self.__closeList[i])
print(v_str)
# 从openList中获取列表尾的元素,将之作为currentNode并加入closeList中,然后将其从openList中移除
def _UpdateCurrentNode(self):
if len(self.__openList) > 0:
self.__currentNode = self.__openList[0]
if self._isInCloseList(self.__currentNode) is False:
self.__closeList.append(self.__currentNode)
del self.__openList[0]
if self._ReachEndNodeCheck(self.__currentNode):
return False
return True
return False
SharedData.py
from myCodes import Node as node
# 共享信息类
class SharedData:
# 网格行列数
rowsCount = 0
colsCount = 0
# 终点Node
endNode: node.Node
# 所查找的路径Node集
pathNodes = []
Main.py
from myCodes import SearchPath as sp
from myCodes import Grid as grid
g = grid.Grid(5, 5)
sp.SearchPath().Search()
g.Draw()
在这个示例中,我们分为两个层,一个是视图层,一个是数据层,视图层包括Grid类,数据层包括Node类、NodePosition类、NodeFactory类、SharedData类、SearchPath类。Grid类负责对A_Star算法进行可视化展现,包括基础的网格、文本和标记的构建和显示;Node类作为节点类,对应着视图层的每一个网格单元;NodePosition类用于记录节点的坐标信息;NodeFactory类用于创建和获取起始点Node和终点Node;SharedData类用于全局信息的共享;SearchPath类是路径查找类,该类中主要编写了A_Star算法实现的逻辑,包括获取下一步可前往的Node集、对Node集中的Node进行筛选、计算Node的权值、对Node集进行排序、将Node集添加至OpenList中以及从OpenList中读取Node等。主要的调用顺序:1.在Main方法中创建Grid的实例,并且传递网格的行列数;2.执行SearchPath中的Search方法查找最佳路径;3.调用Grid中的Draw方法对最佳路径进行显示。
下一篇将讲解添加障碍物后的A-Star算法探索和实现
如果这篇文章对你有帮助,请给作者点个赞吧!