A星(AStar)算法的实现

关于AStar的原理,这里简述一下:

首先有一张地图,然后准备一个open list和close list。open list存放所有可能的路径点,需要注意的是这个列表是动态增加的,也就是每走一步就把当前可能的路径点都加进去;然后每次从open list中取出一个代价最小的点,作为下一步的路径,并用该点计算之后可能的路径点,再加入到open list中去;接着把这个代价最小的点放到close list中,表示该点已经走过了。如此循环,直到从open list取出的最小代价点为终点,此时沿着每个点的父节点回溯即可找到路径。代价的计算方法具有多样性,常见形式为 f = g + h,其中 g 是从起点到当前点的实际代价,h 是当前点到终点的估计代价,具体如何计算可以自行搜索(百度)相关资料;

首先是Python的实现:

import numpy as np
import heapq


class Node:
    """A search-tree node: one grid cell together with its path costs.

    ``coord`` is an ``(x, y)`` pair and ``parent`` links back toward the start
    node. ``g`` is the accumulated step cost and ``h`` the heuristic estimate.
    ``f = g + h`` is the priority used by the heap. The comparison operators
    (``<``, ``>``, ``==``) look at ``f`` only, which is all ``heapq`` needs.
    """

    def __init__(self, coord, parent=None, g=0, h=0):
        self.coord = coord
        self.parent: Node = parent
        self.g, self.h = g, h
        self.f = self.g + self.h
        self.master = 0  # set here but not read anywhere in the visible code

    def __lt__(self, other):
        return self.f < other.f

    def __gt__(self, other):
        return self.f > other.f

    def __eq__(self, other):
        return self.f == other.f

    def __str__(self):
        x, y = self.coord[0], self.coord[1]
        return f"<({x},{y}) G: {self.g},H: {self.h},F: {self.f}> "

    __repr__ = __str__


class Walker:
    """Single-pass iterator over a node followed by all of its ancestors.

    Iteration starts at *node* and follows ``.parent`` links until it reaches
    ``None``; passing ``None`` yields nothing. The generator is created once in
    ``__init__`` and ``__iter__`` always hands back that same generator, so a
    Walker can be consumed only once.
    """

    def __init__(self, node):
        self.node = node
        self.iter = self.gen_data()

    def gen_data(self):
        current = self.node
        while current is not None:
            yield current
            current = current.parent

    def __iter__(self):
        return self.iter


class AStar:
    """Simplified A* path finder on a 2-D grid.

    ``world_map`` is a 2-D numpy array indexed ``world_map[y, x]``. A cell whose
    value equals ``bar_value`` is a wall. Each step goes to one of the 8
    surrounding cells. Straight steps cost ``direct_cost`` and diagonal steps
    cost ``oblique_cost``.

    NOTE: this trades optimality for speed compared with textbook A*:

    * A cell is locked the first time it is discovered, and its ``g`` is never
      lowered later.
    * The heuristic is Manhattan distance times ``direct_cost``. It can
      overestimate when a diagonal step is cheaper than two straight steps.

    As a result, the returned path is not guaranteed to be the cheapest.
    """

    def __init__(self, world_map, bar_value, direct_cost, oblique_cost):
        self.world_map = world_map
        self.bar_value = bar_value
        self.direct_cost = direct_cost
        self.oblique_cost = oblique_cost
        # Per-cell search state, same shape as the map:
        # 0 = not seen yet, 1 = queued in the open heap, 2 = expanded.
        self.map_cache = np.zeros_like(world_map)

    def new_node(self, prev_node: Node, end_node, x, y, weight):
        """Return a child of *prev_node* at ``(x, y)``, or None if unusable.

        Args:
            prev_node: node being expanded; becomes the child's parent.
            end_node: goal node, used only for the heuristic.
            x, y: candidate cell.
            weight: cost of the step from *prev_node* to ``(x, y)``.

        Returns:
            A new ``Node``, or None when the cell is off the map or a wall.
        """
        height, width = self.world_map.shape
        if x < 0 or x >= width or y < 0 or y >= height:
            return None
        if self.world_map[y, x] == self.bar_value:
            return None
        H = abs(end_node.coord[0] - x) * self.direct_cost + abs(end_node.coord[1] - y) * self.direct_cost
        G = prev_node.g + weight
        return Node((x, y), prev_node, G, H)

    def get_neighbors(self, p_node, end_node):
        """Return the 8 candidate children of *p_node*.

        The tuple is ordered up, down, left, right, left_up, right_up,
        left_down, right_down. Unusable cells appear as None.
        """
        x, y = p_node.coord
        d, o = self.direct_cost, self.oblique_cost
        moves = (
            (0, -1, d),   # up
            (0, 1, d),    # down
            (-1, 0, d),   # left
            (1, 0, d),    # right
            (-1, -1, o),  # left_up
            (1, -1, o),   # right_up
            (-1, 1, o),   # left_down
            (1, 1, o),    # right_down
        )
        return tuple(
            self.new_node(p_node, end_node, x + dx, y + dy, cost)
            for dx, dy, cost in moves
        )

    def find_path(self, start_node, end_node):
        """Search from *start_node* to *end_node*.

        Returns:
            The node reached at the goal's coordinates, or None. Follow its
            ``.parent`` links (e.g. with ``Walker``) to recover the path back
            to the start. None is returned when the goal is unreachable or the
            start lies outside the map. Prints "success" or "failed!".
        """
        height, width = self.world_map.shape
        sx, sy = start_node.coord
        if not (0 <= sx < width and 0 <= sy < height):
            # A negative index would silently wrap around in numpy.
            print("failed!")
            return None

        open_ls = []
        self.map_cache[:, :] = 0
        # Seed the heap with the start_node argument itself (not a global).
        heapq.heappush(open_ls, start_node)
        self.map_cache[sy, sx] = 1
        while open_ls:
            cur_node: Node = heapq.heappop(open_ls)
            if cur_node.coord == end_node.coord:
                print("success")
                return cur_node
            for node in self.get_neighbors(cur_node, end_node):
                if node is None:
                    continue
                nx, ny = node.coord
                # Cells already queued or expanded are never re-opened.
                if self.map_cache[ny, nx] != 0:
                    continue
                heapq.heappush(open_ls, node)
                self.map_cache[ny, nx] = 1
            self.map_cache[cur_node.coord[1], cur_node.coord[0]] = 2
        print("failed!")
        return None


# #
if __name__ == '__main__':
    # Demo: build an obstacle map, search it, and show the result with OpenCV.
    import time
    import cv2
    default_bar_value = 70
    default_set_path = 255
    DIRECT_WEIGHT = 10
    OBLIQUE_WEIGHT = 14
    # create a map: rows are y, columns are x; 1 = free cell, bar value = wall
    maps = np.zeros((650, 750), np.intc) + 1
    maps[40:, 20:30] = default_bar_value
    maps[:400, 100:110] = default_bar_value
    maps[100:, 200:210] = default_bar_value
    maps[:200, 300:310] = default_bar_value
    maps[220:230, 210:710] = default_bar_value
    maps[220:600, 710:720] = default_bar_value
    maps[600:610, 300:720] = default_bar_value
    maps[300:610, 300:310] = default_bar_value

    start = Node((10, 10))  # start coord
    end = Node((600, 400))  # end coord
    finder = AStar(maps, default_bar_value, DIRECT_WEIGHT, OBLIQUE_WEIGHT)
    t0 = time.perf_counter()  # monotonic, high-resolution timer for benchmarking
    node = finder.find_path(start, end)
    print("耗时:", time.perf_counter() - t0)

    # Walker(None) yields nothing, so a failed search simply marks no cells.
    for node in Walker(node):
        maps[node.coord[1], node.coord[0]] = default_set_path

    # Gray map -> 3-channel image.
    maps = maps.astype(np.uint8)
    maps = np.repeat(maps[:, :, np.newaxis], 3, axis=2)
    # Recolor path and walls before drawing the markers. The end marker's color
    # (255, 0, 0) has channel 0 == default_set_path, so drawing it first would
    # get it repainted as path by the mask below.
    maps[maps[:, :, 0] == default_set_path] = 50, 255, 50
    maps[maps[:, :, 0] == default_bar_value] = 0, 0, 255
    cv2.circle(maps, tuple(start.coord), 5, (0, 255, 0), 5)
    cv2.circle(maps, tuple(end.coord), 5, (255, 0, 0), 5)
    cv2.imshow("result", maps)
    cv2.waitKey()
    cv2.destroyAllWindows()

需要注意的是,这里为了提高性能省去了一些步骤,所以与标准的算法略有区别:一个点在第一次被发现时就会被标记,之后即使找到代价更小的到达方式也不会再更新它,因此得到的路径不一定是代价最小的路径。

以下是C++的实现:

// dllmain.cpp : 定义 DLL 应用程序的入口点。
#include "pch.h"
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

using namespace std;

// Integer grid cell; x is the column and y is the row.
struct Coord
{
    int x;
    int y;

    // Cell equality. Declared const so it also works on const Coord values
    // and const references.
    bool operator==(const Coord& c) const {
        return x == c.x && y == c.y;
    }
    Coord() : x(0), y(0) {}
    Coord(int x, int y) : x(x), y(y) {}
};

// One element of a singly linked list of grid points.
struct Point
{
    int x;
    int y;
    Point* next;  // following element, nullptr at the tail

    Point(int x, int y) : x(x), y(y), next(nullptr) {}
};

// Search-tree node: a grid cell plus its path costs.
struct Node
{
    Coord coord;    // grid cell of this node
    Node* parent;   // predecessor toward the start; nullptr for a root node
    int g;          // accumulated step cost
    int h;          // heuristic estimate
    int f;          // g + h, the heap priority

    // Fully specified node; f is derived as g + h.
    Node(Coord coord, Node* parent, int g, int h)
        : coord(coord), parent(parent), g(g), h(h), f(g + h) {}
    // Root node with zero costs.
    Node(Coord coord) : Node(coord, nullptr, 0, 0) {}
    // Child node with zero costs.
    Node(Coord coord, Node* parent) : Node(coord, parent, 0, 0) {}
};

// Comparator for the std heap algorithms. It reports "x orders before y" when
// x has the larger f, so the node with the smallest f ends up at the heap
// front (a min-heap on f).
class HeapCompare_f
{
public:
    bool operator() (const Node* x, const Node* y) const
    {
        return y->f < x->f;
    }
};

class AStar
{
public:
    int* world_map;
    int width;
    int height;
    int bar_value;
    int direct_cost;
    int oblique_cost;
    int set_path_value;
    Node** neighbors;
    Node** map_cache;
    vector open_ls;
    vector close_ls;
    AStar(int* world_map, int width, int height, int bar_value, int dircect_cost, int oblique_cost, int set_path_value);
    ~AStar();
    Node* new_node(Node* parent, Coord end, int x, int y, int cost);
    void get_neighbors(Node* parent, Coord coord);
    Point* find_path(Coord start, Coord end);
    void free_ls();
};

// Stores the search parameters and allocates the scratch buffers.
// Both arrays are value-initialized to nullptr, so map_cache never holds
// indeterminate pointers, even before the first search resets it.
AStar::AStar(int* world_map, int width, int height, int bar_value, int direct_cost, int oblique_cost, int set_path_value)
    : world_map(world_map),
      width(width),
      height(height),
      bar_value(bar_value),
      direct_cost(direct_cost),
      oblique_cost(oblique_cost),
      set_path_value(set_path_value),
      neighbors(new Node*[8]()),
      map_cache(new Node*[(size_t)width * (size_t)height]())
{
}


AStar::~AStar()
{
    // Release the two arrays the constructor allocated with new[].
    // world_map belongs to the caller and is left alone.
    delete[] this->neighbors;
    delete[] this->map_cache;
}

// Builds the child of `parent` at (x, y) reached with a step of `cost`.
// Returns nullptr for off-map cells and for walls (value >= bar_value).
// h is the Manhattan distance to `end` scaled by direct_cost.
Node* AStar::new_node(Node* parent, Coord end, int x, int y, int cost)
{
    const bool inside = x >= 0 && x < width && y >= 0 && y < height;
    // Short-circuit keeps the map read in bounds.
    if (!inside || world_map[y * width + x] >= bar_value)
        return nullptr;
    const int h = abs(end.x - x) * direct_cost + abs(end.y - y) * direct_cost;
    const int g = parent->g + cost;
    return new Node(Coord(x, y), parent, g, h);
}

// Fills neighbors[0..7] with the children of `parent`; unusable cells become nullptr.
// Slot order: up, down, left, right, left-up, right-up, left-down, right-down.
void AStar::get_neighbors(Node* parent, Coord end)
{
    static const int dx[8] = { 0, 0, -1, 1, -1, 1, -1, 1 };
    static const int dy[8] = { -1, 1, 0, 0, -1, -1, 1, 1 };
    const Coord origin = parent->coord;
    for (int k = 0; k < 8; ++k)
    {
        // The first four moves are straight; the last four are diagonal.
        const int step_cost = (k < 4) ? direct_cost : oblique_cost;
        neighbors[k] = new_node(parent, end, origin.x + dx[k], origin.y + dy[k], step_cost);
    }
}

void AStar::free_ls() {
    typename vector< Node* >::iterator iter_node;
    for (iter_node = open_ls.begin(); iter_node != open_ls.end(); iter_node++)
    {
        //if (*iter_node != nullptr)
        delete (*iter_node);
    }
    for (iter_node = close_ls.begin(); iter_node != close_ls.end(); iter_node++)
    {
        //if (*iter_node != nullptr)
        delete (*iter_node);
    }
    open_ls.clear();
    close_ls.clear();
}

// Runs the simplified A* search from `start` to `end` over world_map.
//
// Returns a linked list of Points ordered start -> end, or nullptr when the
// start is off the map or the goal is unreachable. The list is allocated with
// new; release it with free_node_ls.
//
// When set_path_value > 0, every path cell in world_map is overwritten with
// set_path_value. All search nodes are freed before returning.
Point* AStar::find_path(Coord start, Coord end)
{
    if (start.x < 0 || start.x >= width || start.y < 0 || start.y >= height) {
        // An off-map start would index map_cache out of bounds.
        printf("failed.\n");
        return nullptr;
    }
    // Clear all width*height pointer slots. A byte count passed to memset
    // would clear only a fraction of them and leave dangling pointers from
    // earlier searches.
    std::fill_n(map_cache, (size_t)width * (size_t)height, nullptr);

    Node* start_node = new Node(start);
    open_ls.push_back(start_node);  // a one-element vector is already a valid heap
    map_cache[start.y * width + start.x] = start_node;

    Node* goal = nullptr;
    while (!open_ls.empty())
    {
        // Move the smallest-f node to the back and take it off the heap.
        pop_heap(open_ls.begin(), open_ls.end(), HeapCompare_f());
        Node* cur_node = open_ls.back();
        open_ls.pop_back();
        // Keep every popped node, the goal included, so free_ls() deletes it.
        close_ls.push_back(cur_node);
        if (cur_node->coord == end) {
            goal = cur_node;
            break;
        }
        get_neighbors(cur_node, end);
        for (size_t i = 0; i < 8; i++)
        {
            Node* node = neighbors[i];
            if (node == nullptr) continue;
            Node** cache_node = &map_cache[node->coord.y * width + node->coord.x];
            if (*cache_node != nullptr)
            {
                // Cell already discovered; this simplified search never re-opens it.
                delete node;
                continue;
            }
            open_ls.push_back(node);
            push_heap(open_ls.begin(), open_ls.end(), HeapCompare_f());
            *cache_node = node;
        }
    }

    Point* head = nullptr;
    if (goal == nullptr) {
        printf("failed.\n");
    } else {
        printf("success.\n");
        // Walk parent links from the goal back to the start, prepending each
        // cell so the list comes out ordered start -> goal.
        for (Node* n = goal; n != nullptr; n = n->parent) {
            if (set_path_value > 0)
                world_map[n->coord.y * width + n->coord.x] = set_path_value;
            Point* p = new Point(n->coord.x, n->coord.y);
            p->next = head;
            head = p;
        }
    }
    free_ls();
    return head;
}



// C-ABI entry points exported from the DLL (extern "C" keeps the names unmangled).
extern "C" {
    // Allocates an AStar finder over the caller's width*height int grid and
    // returns an integer id that FindPath / AStarFree use to index FINDER_OBJ_LIST.
    __declspec(dllexport)int CreateAStar(int* world_map, int width, int height, int bar_value, int direct_cost, int oblique_cost, int set_path_value);
    // Runs AStar::find_path on finder `id`. The returned Point list is
    // allocated with new and should be released with free_node_ls.
    __declspec(dllexport)Point* FindPath(int id, int start_x, int start_y, int end_x, int end_y);
    // Deletes finder `id` and clears its slot.
    __declspec(dllexport)void AStarFree(int id);
    // Deletes every Point in the list *node_var refers to and sets *node_var to 0.
    __declspec(dllexport)void free_node_ls(Point** node_var);
}
// Capacity of the finder slot table.
#define MAX_NUMS 64
// Slot table of finders; a 0 entry is a free slot (namespace-scope array, so zero-initialized).
AStar* FINDER_OBJ_LIST[MAX_NUMS];
// Slot CreateAStar tries first; incremented after each creation.
int CUR_INDEX = 0;
// Creates a finder and returns the slot index it was stored in.
// Slot choice:
//   1. Try CUR_INDEX, wrapping to 0 at MAX_NUMS.
//   2. If that slot is occupied, use the first free slot.
//   3. If every slot is occupied, the finder in CUR_INDEX is deleted and
//      replaced, so an id previously returned for it now names the new finder.
int CreateAStar(int* world_map, int width, int height, int bar_value, int direct_cost, int oblique_cost, int set_path_value) {
    std::cout << "Hello CreateAStar!\n";
    if (CUR_INDEX >= MAX_NUMS)
        CUR_INDEX = 0;
    if (FINDER_OBJ_LIST[CUR_INDEX] != 0)
    {
        int i;
        for (i = 0; i < MAX_NUMS; i++)
        {
            if (FINDER_OBJ_LIST[i] == 0)
                break;
        }
        if (i != MAX_NUMS) {
            CUR_INDEX = i;
        }
        else {
            delete FINDER_OBJ_LIST[CUR_INDEX];
            FINDER_OBJ_LIST[CUR_INDEX] = 0;
        }
    }
    // Capture the id only after the slot has been chosen; it must name the
    // slot that actually receives the new finder.
    int id = CUR_INDEX;
    FINDER_OBJ_LIST[id] = new AStar(world_map, width, height, bar_value, direct_cost, oblique_cost, set_path_value);
    CUR_INDEX++;
    return id;
}

// Runs a search on finder `id`. Returns the Point list from AStar::find_path,
// or 0 for an invalid id or an empty slot.
Point* FindPath(int id, int start_x, int start_y, int end_x, int end_y) {
    // Valid slots are 0..MAX_NUMS-1; id == MAX_NUMS would read past the array.
    if (id >= MAX_NUMS || id < 0) return 0;
    auto a = FINDER_OBJ_LIST[id];
    if (a == 0) return 0;
    std::cout << "Hello FindPath!\n";
    return a->find_path(Coord(start_x, start_y), Coord(end_x, end_y));
}

// Deletes finder `id` and marks its slot free. Invalid ids and empty slots are ignored.
void AStarFree(int id) {
    // Valid slots are 0..MAX_NUMS-1; id == MAX_NUMS would read past the array.
    if (id >= MAX_NUMS || id < 0) return;
    auto a = FINDER_OBJ_LIST[id];
    if (a == 0) return;
    std::cout << "Hello AStarFree!\n";
    delete a;
    FINDER_OBJ_LIST[id] = 0;
}

// Deletes every Point in the list *node_var and resets *node_var to null,
// so the caller's pointer cannot dangle. A null node_var is ignored.
void free_node_ls(Point** node_var) {
    if (node_var == nullptr)
        return;
    Point* node = *node_var;
    while (node != nullptr)
    {
        Point* next = node->next;
        delete node;
        node = next;
    }
    *node_var = nullptr;
}

int main()
{

    int MAP_WIDTH = 20;
    int MAP_HEIGHT = 20;

    int BarValue = 9;
    int AreaValue = 1;


    int direct_cost = 10;
    int oblique_cost = 14;

    int is_direction = 0;

    int map[] = {
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,   // 00
        1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 1,   // 01
        1, 9, 9, 1, 1, 9, 9, 9, 1, 9, 1, 9, 1, 9, 1, 9, 9, 9, 1, 1,   // 02
        1, 9, 9, 1, 1, 9, 9, 9, 1, 9, 1, 9, 1, 9, 1, 9, 9, 9, 1, 1,   // 03
        1, 9, 1, 1, 1, 1, 9, 9, 1, 9, 1, 9, 1, 1, 1, 1, 9, 9, 1, 1,   // 04
        1, 9, 1, 1, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1,   // 05
        1, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 1, 1, 1, 1, 1,   // 06
        1, 9, 9, 9, 9, 9, 9, 9, 9, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 1,   // 07
        1, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1, 1, 1, 1,   // 08
        1, 9, 1, 9, 9, 9, 9, 9, 9, 9, 1, 1, 9, 9, 9, 9, 9, 9, 9, 1,   // 09
        1, 9, 1, 1, 1, 1, 9, 1, 1, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,   // 10
        1, 9, 9, 9, 9, 9, 1, 9, 1, 9, 1, 9, 9, 9, 9, 9, 1, 1, 1, 1,   // 11
        1, 9, 1, 9, 1, 9, 9, 9, 1, 9, 1, 9, 1, 9, 1, 9, 9, 9, 1, 1,   // 12
        1, 9, 1, 9, 1, 9, 9, 9, 1, 9, 1, 9, 1, 9, 1, 9, 9, 9, 1, 1,   // 13
        1, 9, 1, 1, 1, 1, 9, 9, 1, 9, 1, 9, 1, 1, 1, 1, 9, 9, 1, 1,   // 14
        1, 9, 1, 1, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1,   // 15
        1, 9, 9, 9, 9, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 1, 1, 1, 1, 1,   // 16
        1, 1, 9, 9, 9, 9, 9, 9, 9, 1, 1, 1, 9, 9, 9, 1, 9, 9, 9, 9,   // 17
        1, 9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 1, 1, 1, 1, 1, 1, 1, 1,   // 18
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,   // 19
    };
    AStar a = AStar(map, MAP_WIDTH, MAP_HEIGHT, BarValue, direct_cost, oblique_cost, 7);
    a.find_path(Coord(0, 0), Coord(3, 3));
    for (int i = 0; i < MAP_HEIGHT; i++)
    {
        for (int j = 0; j < MAP_WIDTH; j++)
        {
            printf("%d  ", map[i * MAP_WIDTH + j]);
        }
        cout << endl;
    }
}

BOOL APIENTRY DllMain( HMODULE hModule,
                       DWORD  ul_reason_for_call,
                       LPVOID lpReserved
                     )
{
    // No per-process or per-thread setup is needed. Every notification
    // (process/thread attach and detach) is simply acknowledged with TRUE.
    (void)hModule;
    (void)ul_reason_for_call;
    (void)lpReserved;
    return TRUE;
}

C++的实现包含了导出为DLL的接口以及测试代码。

效果图如下:

A星(AStar)算法的实现_第1张图片

你可能感兴趣的:(c/c++,计算机视觉,Python,AStar,A星,c++,python,算法)