POJ2449-第K短路,Astar,优先队列

#include <cstring>
#include <iostream>
#include <cstdio>
#include <queue>
using namespace std;

const int NN=1005;
const int MM=100010;
const int INF=0x1fffffff;

struct node{   //第一次用优先队列
   int v,g,h;
   bool operator <(node a) const
   {
       return a.g+a.h<g+h;    //Astar算法的启发函数设计
   }
};
struct Edge{
   int u,v,dis,next,next2;
}edge[MM];
int head[NN],head2[NN],h[NN];
int n,m,ecnt,S,T,k;

void addedge(int u,int v,int dis)
{
    edge[ecnt].u=u;
    edge[ecnt].v=v;
    edge[ecnt].dis=dis;
    edge[ecnt].next=head[u];
    edge[ecnt].next2=head2[v];  //next2域为了反向spfa求最短路用
    head[u]=ecnt;
    head2[v]=ecnt++;
}

bool spfa()  //反向求到T的距离,即求估计函数h[i]
{
    bool inq[NN];
    memset(inq,false,sizeof(inq));
    for (int i=1; i<=n; i++) h[i]=INF;
    h[T]=0;
    queue<int> q;
    q.push(T);
    while (!q.empty())
    {
        int u=q.front();
        q.pop();
        inq[u]=false;
        for (int i=head2[u]; i!=-1; i=edge[i].next2)
        {
            int v=edge[i].u;
            if (h[v]>h[u]+edge[i].dis)
            {
                h[v]=h[u]+edge[i].dis;
                if (!inq[v])
                {
                    inq[v]=true;
                    q.push(v);
                }
            }
        }
    }
    if (h[S]==INF) return false;
    return true;
}

int kth_shortest()
{
    if (!spfa()) return -1;   //S,T不连通时可提前结束
    if (S==T) k++;          //题意S,T相同时,S到T的最短路不是0,而这儿算了0,因此要加1
    int cou[NN];             //计算出队次数
    memset(cou,0,sizeof(cou));
    priority_queue<node> q;
    node x,y;
    x.v=S; x.g=0; x.h=h[S];
    q.push(x);
    while (!q.empty())
    {
        x=q.top();
        q.pop(); 
        if (cou[x.v]==k) continue; //出队次数已经达到k了,就不用处理了
        if (++cou[x.v]==k && x.v==T) return x.g;   //第几次出队,出队时的距离就是最几短距离
        for (int i=head[x.v]; i!=-1; i=edge[i].next)
        {
            y.v=edge[i].v;
            if (cou[y.v]==k) continue;
            y.g=x.g+edge[i].dis;
            y.h=h[y.v];
            q.push(y);
        }
    }
    return -1;  //没加时也过了,数据原因。。。
}

int main()
{
    ecnt=0;
    memset(head,-1,sizeof(head));
    memset(head2,-1,sizeof(head2));
    scanf("%d%d",&n,&m);
    int x,y,z;
    for (int i=1; i<=m; i++)
    {
        scanf("%d%d%d",&x,&y,&z);
        addedge(x,y,z);
    }
    scanf("%d%d%d",&S,&T,&k);
    printf("%d\n",kth_shortest());
    return 0;
}


你可能感兴趣的:(算法,struct,include)