题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1688
题意:给出一个有重边的有向图,求给出的2个点间有多少条最短路以及和最短路的路程差1的次短路
思路:dijstra算法魔改一下,用队列进行操作,dist[i][0]记录到i点最短路,再用dist[i][1]记录次短路,cnt[i][0]和cnt[i][1]数组则记录最短路和次短路的个数。
每次进行松弛操作的时候有4种情况分别是:
1.到i的距离比记录的最短距离要短,更新dist[i][0]以及dist[i][1],入队进行松弛操作
2.到i的距离比记录的次短路要短但比最短路长,更新dist[i][1],入队进行松弛操作
3.到i的距离和记录最短路一样长,更新最短路的个数
4.到i的距离和记录次短路一样长,更新次短路的个数
#include <iostream> #include <cstdio> #include <algorithm> #include <cstring> #include <queue> #include <vector> #define maxn 10030 #define inf 0x3f3f3f3f using namespace std; struct Edge { int v,w; Edge(int a,int b):v(a),w(b){}; }; struct Node { int k,dis,pos; bool operator <(const Node & q)const { return dis>q.dis; } }st; priority_queue<Node> que; int vis[1030][2],dist[1030][2],cnt[1030][2]; vector<Edge> edge[1030]; void init() { memset(edge,0,sizeof(edge)); memset(vis,0,sizeof(vis)); memset(dist,inf,sizeof(dist)); memset(cnt,0,sizeof(cnt)); while (!que.empty()) que.pop(); } void dijstra() { que.push(st); while (!que.empty()) { Node tmp=que.top(); que.pop(); int u=tmp.pos,k=tmp.k; if (vis[u][k]) continue; else vis[u][k]=1; for (int i=0;i<edge[u].size();i++) { int dis=edge[u][i].w+tmp.dis; int v=edge[u][i].v; Node nxt; if (dis<dist[v][0]) { cnt[v][1]=cnt[v][0]; dist[v][1]=dist[v][0]; nxt.dis=dist[v][1]; nxt.pos=v; nxt.k=1; que.push(nxt); cnt[v][0]=cnt[u][0]; dist[v][0]=dis; nxt.dis=dist[v][0]; nxt.pos=v; nxt.k=0; que.push(nxt); } else if (dis==dist[v][0]) { cnt[v][0]+=cnt[u][0]; } else if (dis<dist[v][1]) { cnt[v][1]=cnt[u][k]; dist[v][1]=dis; nxt.dis=dis; nxt.pos=v; nxt.k=1; que.push(nxt); } else if (dis==dist[v][1]) { cnt[v][1]+=cnt[u][k]; } } } return; } int main() { int t,n,m; scanf("%d",&t); while (t--) { init(); scanf("%d%d",&n,&m); for (int i=0;i<m;i++) { int u,v,w; scanf("%d%d%d",&u,&v,&w); edge[u].push_back(Edge(v,w)); } int stay,end; scanf("%d%d",&stay,&end); st.pos=stay; st.dis=0; st.k=0; dist[stay][0]=0; cnt[stay][0]=1; dijstra(); int res=cnt[end][0];//cout<<dist[end][0]<<":"<<dist[end][1]<<endl; if (dist[end][1]==dist[end][0]+1) res+=cnt[end][1]; printf("%d\n",res); } }