A*搜索算法的Python实现(八数码为例)

1、A*搜索算法介绍

A*搜寻算法,俗称A星算法,作为启发式搜索算法中的一种,这是一种在图形平面上,有多个节点的路径,求出最低通过成本的算法。常用于游戏中的NPC的移动计算,或线上游戏的BOT的移动计算上。

算法核心:
A*算法最为核心的部分,就在于它的一个估值函数的设计上:

f(n)=g(n)+h(n)

其中f(n)是每个可能试探点的估值,它有两部分组成:

  • g(n):表示从起始搜索点到当前点的代价(通常用某结点在搜索树中的深度来表示)。
  • h(n):表示启发式搜索中最为重要的一部分,即当前结点到目标结点的估值,h(n)设计的好坏,直接影响着具有此种启发式函数的启发式算法的是否能称为A*算法。

一种具有f(n)=g(n)+h(n)策略的启发式算法能成为A*算法的充分条件是:
1、搜索树上存在着从起始点到终了点的最优路径。
2、问题域是有限的。
3、所有结点的子结点的搜索代价值>0。
4、h(n)= 当此四个条件都满足时,一个具有f(n)=g(n)+h(n)策略的启发式算法能成为A*算法,并一定能找到最优解。

算法流程:
首先将起始结点S放入OPEN表,CLOSE表置空,算法开始时:
1、如果OPEN表不为空,从表头取一个结点n,如果为空算法失败。
2、n是目标解吗?是,找到一个解(继续寻找,或终止算法)。
3、将n的所有后继结点展开,就是从n可以直接关联的结点(子结点),如果不在CLOSE表中,就将它们放入OPEN表,并把S放入CLOSE表,同时计算每一个后继结点的估价值f(n),将OPEN表按f(x)排序,最小的放在表头,重复算法,回到1

2、Python实现(以八数码为例)

首先定义初始数码状态与目标数码状态:

a = [[1, 3, 4],
     [2, 6, 0],
     [7, 5, 8]]
b = [[3, 0, 2],
     [6, 1, 5],
     [8, 7, 4]]

需要注意的是,初始状态与目标状态序列的逆序对数需要同奇偶,例如a的逆序对数为4,b为10,同为偶,这样可以保证有解。然后我们定义AStar类,在初始化函数中接受S0(初始状态)和G(目标状态),获取行列等信息,然后初始化g(n)、访问过的节点数nodes以及open表和close表。

class AStar:
    def __init__(self, S0, G):
        self.S0 = S0
        self.G = G
        self.max_row = len(S0)
        self.max_col = len(S0[0])
        self.gn = 0
        self.nodes = 0
        self.close_list = {'Sn': [S0],
                           'gn': [0],
                           'hn': [self.get_h(S0)],
                           'fn': [self.get_h(S0)]}
        self.open_list = {'Sn': [], 'gn': [], 'hn': [], 'fn': []}

然后,我们定义一些常用的函数:

(1)get_loc: 获取某个值在数码表中的位置

    def get_loc(array, num):
        for i in array:
            for j in i:
                if j == num:
                    row = array.index(i) + 1
                    col = i.index(j) + 1
                    return row, col
        return None

(2)get_h: 获取某个状态的启发函数h(n)值,这里h(n)为哈密尔顿距离

    def get_h(self, Sn):
        h = 0
        for i in Sn:
            for j in i:
                if not j:
                    continue
                temp = j
                row_n, col_n = self.get_loc(Sn, temp)
                row_g, col_g = self.get_loc(self.G, temp)
                h = h + abs(row_n - row_g) + abs(col_n - col_g)
        return h

(3)add_open_list: 将可走的某个状态写入open表中
(注:如果该状态已在close表中出现过,则跳过)

    def add_open_list(self, temp_s):
        if temp_s in self.close_list['Sn']:
            return 0
        else:
            self.open_list['Sn'].append(temp_s)
            self.open_list['gn'].append(self.gn)
            self.open_list['hn'].append(self.get_h(temp_s))
            self.open_list['fn'].append(self.gn + self.get_h(temp_s))

常用函数准备好之后,接下来就是算法的主要部分,状态的搜索及变化,这里定义move函数来实现这一功能。

  • 在最开始,是获取当前的深度并+1,由于算法每次将在open表中挑选fn值最小的走法,深度的不一定是线性增长,所以每次需要获取最后放入close表中的状态所对应的深度。
  • 其后则是向上下左右四个方向的移动过程,deepcopy用来复制一个新的列表,现有的方法比如b = a[:]在复制后修改b仍会使a的值发生变化,复制的是指针而不是列表值。
	from copy import deepcopy as dp
	
    def move(self, Sn):
        # restore gn
        self.gn = self.close_list['gn'][-1]
        self.gn += 1
        row_0, col_0 = self.get_loc(Sn, 0)
        # up
        if row_0 > 1:
            temp_n = Sn[row_0 - 2][col_0 - 1]
            temp_s = dp(Sn)
            temp_s[row_0 - 2][col_0 - 1] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # down
        if row_0 < self.max_row:
            temp_n = Sn[row_0][col_0 - 1]
            temp_s = dp(Sn)
            temp_s[row_0][col_0 - 1] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # left
        if col_0 > 1:
            temp_n = Sn[row_0 - 1][col_0 - 2]
            temp_s = dp(Sn)
            temp_s[row_0 - 1][col_0 - 2] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # right
        if col_0 < self.max_col:
            temp_n = Sn[row_0 - 1][col_0]
            temp_s = dp(Sn)
            temp_s[row_0 - 1][col_0] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)

在可走的状态都放入open表中后,我们需要:

  • 判断f(n)值最小的走法,找到最佳状态
  • 将最佳状态写入close表的最后
  • 从open表中删除这一状态,以免重复走同一步
  • 返回移动后的新状态
		# get best move
		fns = self.open_list['fn']
		best_idx = fns.index(min(fns))
		new_s = self.open_list['Sn'][best_idx]
		# update list
		self.close_list['Sn'].append(new_s)
		self.close_list['gn'].append(self.open_list['gn'][best_idx])
		self.close_list['hn'].append(self.open_list['hn'][best_idx])
		self.close_list['fn'].append(fns[best_idx])
		for key in self.open_list.keys():
		    self.open_list[key].pop(best_idx)
		return new_s

以上定义的move函数实现了单步的判断和行走,我们需要用一个主函数来调用他走完所有的状态,最终达到目标状态。其中搜索停止的条件是h(n)值为0,即目前状态与目标状态一致。

    def run(self):
        Sn = self.S0
        start = time.clock()
        while self.close_list['hn'][-1]:
            Sn = self.move(Sn)
            self.nodes += 1
        elapsed = time.clock() - start
        print(f'步数: {self.gn}, 访问节点数: {self.nodes}, 耗时: {elapsed}')
        return self.close_list

最后实例化AStar:

if __name__ == '__main__':
    method = AStar(a, b)
    method.run()

运行得到结果如下,最佳搜索步数为22步,访问过的总结点数为983。

步数: 22, 访问节点数: 983, 耗时: 0.283306

以上就是A*搜索算法的Python实现,最后附上完整代码

# A-star searching method
from copy import deepcopy as dp
import time

a = [[1, 3, 4],
     [2, 6, 0],
     [7, 5, 8]]
b = [[3, 0, 2],
     [6, 1, 5],
     [8, 7, 4]]

class AStar:
    def __init__(self, S0, G):
        self.S0 = S0
        self.G = G
        self.max_row = len(S0)
        self.max_col = len(S0[0])
        self.gn = 0
        self.nodes = 0
        self.close_list = {'Sn': [S0],
                           'gn': [0],
                           'hn': [self.get_h(S0)],
                           'fn': [self.get_h(S0)]}
        self.open_list = {'Sn': [], 'gn': [], 'hn': [], 'fn': []}

    def run(self):
        Sn = self.S0
        start = time.clock()
        while self.close_list['hn'][-1]:
            Sn = self.move(Sn)
            self.nodes += 1
        elapsed = time.clock() - start
        print(f'步数: {self.gn}, 访问节点数: {self.nodes}, 耗时: {elapsed}')
        return self.close_list

    def move(self, Sn):
        # restore gn
        self.gn = self.close_list['gn'][-1]
        self.gn += 1
        row_0, col_0 = self.get_loc(Sn, 0)
        # up
        if row_0 > 1:
            temp_n = Sn[row_0 - 2][col_0 - 1]
            temp_s = dp(Sn)
            temp_s[row_0 - 2][col_0 - 1] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # down
        if row_0 < self.max_row:
            temp_n = Sn[row_0][col_0 - 1]
            temp_s = dp(Sn)
            temp_s[row_0][col_0 - 1] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # left
        if col_0 > 1:
            temp_n = Sn[row_0 - 1][col_0 - 2]
            temp_s = dp(Sn)
            temp_s[row_0 - 1][col_0 - 2] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # right
        if col_0 < self.max_col:
            temp_n = Sn[row_0 - 1][col_0]
            temp_s = dp(Sn)
            temp_s[row_0 - 1][col_0] = 0
            temp_s[row_0 - 1][col_0 - 1] = temp_n
            self.add_open_list(temp_s)
        # get best move
        fns = self.open_list['fn']
        best_idx = fns.index(min(fns))
        new_s = self.open_list['Sn'][best_idx]
        # update list
        self.close_list['Sn'].append(new_s)
        self.close_list['gn'].append(self.open_list['gn'][best_idx])
        self.close_list['hn'].append(self.open_list['hn'][best_idx])
        self.close_list['fn'].append(fns[best_idx])
        for key in self.open_list.keys():
            self.open_list[key].pop(best_idx)
        return new_s

    def add_open_list(self, temp_s):
        if temp_s in self.close_list['Sn']:
            return 0
        else:
            self.open_list['Sn'].append(temp_s)
            self.open_list['gn'].append(self.gn)
            self.open_list['hn'].append(self.get_h(temp_s))
            self.open_list['fn'].append(self.gn + self.get_h(temp_s))

    def get_h(self, Sn):
        h = 0
        for i in Sn:
            for j in i:
                if not j:
                    continue
                temp = j
                row_n, col_n = self.get_loc(Sn, temp)
                row_g, col_g = self.get_loc(self.G, temp)
                h = h + abs(row_n - row_g) + abs(col_n - col_g)
        return h

    @staticmethod
    def get_loc(array, num):
        for i in array:
            for j in i:
                if j == num:
                    row = array.index(i) + 1
                    col = i.index(j) + 1
                    return row, col
        return None

if __name__ == '__main__':
    method = AStar(a, b)
    method.run()

你可能感兴趣的:(python算法,python,启发式算法)