题意:求第k短路。
上次做图论最短路的时候,无意间看到这求第k短路的算法,然后就顺势初学一下A*算法。
以下转载自魏神:http://blog.csdn.net/sdj222555/article/details/7690081
网上大部分的方法都是用A* + 最短路的方法做的。
对于A* ,估价函数 = 当前值+当前位置到终点的距离,即 F(p)=g(p)+h(p),每次扩展估价函数值中最小的一个。对于k短路来说,g(p)为当前从s到p所走的长度,h(p)为从p到 t 的最短路的长度,则F(p)的意义就是从s按照当前路径走到 p 后要走到终点 t 一共至少要走多远。也就是说我们每次的扩展都是有方向的扩展,这样就可以提高求解速度和降低扩展的状态数目。为了加速计算,h(p)需要从A*搜索之前进行预处理,只要将原图的所有边反向,再从终点 t 做一次单源最短路径就可以得到每个点的h(p)了。
在下面这个代码中:A结构体中,v代表的是当前走到的点,f和g分别为f函数和g函数的值,每次优先搜的是f函数较小的。这样就能保证搜索出来的一定是第K小短路,并且避免了一定的不必要计算。
#include <iostream> #include <cstdio> #include <fstream> #include <algorithm> #include <cmath> #include <deque> #include <vector> #include <list> #include <queue> #include <string> #include <cstring> #include <map> #define PI acos(-1.0) #define mem(a,b) memset(a,b,sizeof(a)) #define sca(a) scanf("%d",&a) #define pri(a) printf("%d\n",a) #define MM 500002 #define MN 1002 #define INF 168430090 using namespace std; typedef long long ll; int n,m,st,en,k,cnt,tmp,vis[MN],d[MN],q[MM*5],head[MN],rhead[MN]; struct edge { int v,w,next; }e[MM],re[MM];//一个正向,一个反向 struct A { int f,g,v; bool operator < (const A a)const { if(a.f==f) return a.g<g; return a.f<f; } }; void add(int u,int v,int w) { e[cnt].v=v; e[cnt].w=w; //正向邻接表 e[cnt].next=head[u]; head[u]=cnt; re[cnt].v=u; re[cnt].w=w; //反向邻接表,为求的是h(p)函数 re[cnt].next=rhead[v]; rhead[v]=cnt++; } void spfa() { int i,l=0,r=1,u,v,w; for(i=1;i<=n;i++) d[i]=INF; mem(vis,0); q[0]=en; d[en]=0; while(l<r) //模拟队列,如果用队列可能会RE { u=q[l++]; vis[u]=0; for(i=rhead[u];i!=-1;i=re[i].next) { v=re[i].v; w=re[i].w; if(d[v]>d[u]+w) { d[v]=d[u]+w; //更新每点的最短路,h(p)在本题中就是d数组 if(!vis[v]) { q[r++]=v; vis[v]=1; } } } } } int Astar() { priority_queue<A>Q; if(st==en) k++; //这句WA了一发,因为源点与终点相同时少了本身的一次,所以要加上 if(d[st]==INF) return -1; A s1,s2; s1.v=st; s1.g=0; s1.f=s1.g+d[st]; //即f(p)=g(p)+h(p) Q.push(s1); while(!Q.empty()) { s2=Q.top(); Q.pop(); if(s2.v==en) { tmp++; if(tmp==k) return s2.g; } for(int i=head[s2.v];i!=-1;i=e[i].next) { s1.v=e[i].v; s1.g=s2.g+e[i].w; s1.f=s1.g+d[s1.v]; Q.push(s1); } } return -1; } int main() { int i,u,v,w; scanf("%d%d",&n,&m); mem(head,-1); mem(rhead,-1); for(i=0;i<m;i++) { scanf("%d%d%d",&u,&v,&w); add(u,v,w); } scanf("%d%d%d",&st,&en,&k); spfa(); printf("%d\n",Astar()); return 0; }