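// 1. Run SPFA from the source to compute the shortest distance dist[] to every vertex.
// 2. A road u->v of length cost[u][v] lies on a shortest path exactly when dist[v] == dist[u] + cost[u][v]; for every such road add an arc u->v of capacity 1 to a flow network.
// 3. Run SAP (max flow) once from source to sink; the flow value is the number of edge-disjoint shortest paths.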
// Counts edge-disjoint shortest paths from vs to vt in a directed weighted graph:
//   1. SPFA computes dist[v] = shortest distance from vs to every vertex v.
//   2. Every road u->v of length w with dist[v] == dist[u] + w lies on a
//      shortest path from vs; each such road becomes a capacity-1 arc.
//   3. The max flow from vs to vt in that network (ISAP/SAP with the gap
//      heuristic) is printed for each test case.
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;

const int MAXN = 2222;    // capacity of per-vertex arrays (vertices are 1..n)
const int MAXM = 333333;  // capacity of the residual-arc array (2 arcs per add())
const int inf = 1 << 30;  // "unreachable" distance sentinel

// Residual arc: head vertex, remaining capacity, index of next arc out of the same tail.
struct Edge {
    int v, cap, next;
} edge[MAXM];

int head[MAXN], pre[MAXN], cur[MAXN], level[MAXN], gap[MAXN];
int NE, NV, n, m;

// Reset the flow network to an empty graph.
void init() {
    NE = 0;
    memset(head, -1, sizeof(head));
}

// Add arc u->v with capacity cap together with its zero-capacity reverse arc.
// The pair is stored at indices (2k, 2k+1), so arc e's partner is e ^ 1.
void add(int u, int v, int cap) {
    edge[NE].v = v; edge[NE].cap = cap; edge[NE].next = head[u]; head[u] = NE++;
    edge[NE].v = u; edge[NE].cap = 0;   edge[NE].next = head[v]; head[v] = NE++;
}

// Maximum flow from vs to vt (ISAP: distance labels, gap heuristic and
// current-arc pointers). NV is the number of vertices in the network and
// bounds the distance labels. Returns the flow value.
int SAP(int vs, int vt, int NV) {
    memset(pre, -1, sizeof(pre));
    memset(level, 0, sizeof(level));
    memset(gap, 0, sizeof(gap));
    // Initialise the current-arc pointer of every vertex slot; a loop over
    // 0..NV-1 would miss vertex n when vertices are numbered 1..n.
    memcpy(cur, head, sizeof(head));

    int u = vs, maxflow = 0;
    pre[vs] = vs;
    gap[0] = NV;
    while (level[vs] < NV) {
        // Advance along the first admissible arc (positive residual capacity,
        // head label exactly one lower), resuming from u's current arc.
        bool advanced = false;
        for (int &i = cur[u]; i != -1; i = edge[i].next) {
            int v = edge[i].v;
            if (edge[i].cap > 0 && level[u] == level[v] + 1) {
                pre[v] = u;
                u = v;
                advanced = true;
                break;  // cur[old u] keeps pointing at the arc just used
            }
        }
        if (advanced) {
            if (u == vt) {
                // The path vs -> ... -> vt is recorded by pre[] and the arcs
                // cur[pre[x]]; find its bottleneck, then push that much flow.
                int aug = -1;
                for (int x = vt; x != vs; x = pre[x]) {
                    int c = edge[cur[pre[x]]].cap;
                    if (aug == -1 || c < aug) aug = c;
                }
                for (int x = vt; x != vs; x = pre[x]) {
                    edge[cur[pre[x]]].cap -= aug;
                    edge[cur[pre[x]] ^ 1].cap += aug;
                }
                maxflow += aug;
                u = vs;  // restart the search from the source
            }
            continue;
        }

        // Dead end: relabel u to one more than its lowest residual neighbour.
        int minlevel = NV;
        for (int i = head[u]; i != -1; i = edge[i].next) {
            if (edge[i].cap > 0 && level[edge[i].v] < minlevel) {
                minlevel = level[edge[i].v];
                cur[u] = i;
            }
        }
        // Gap heuristic: if no vertex keeps u's old label, vt is cut off.
        if (--gap[level[u]] == 0) break;
        level[u] = minlevel + 1;
        ++gap[level[u]];
        u = pre[u];  // retreat one step (pre[vs] == vs keeps u at the source)
    }
    return maxflow;
}

// Road in the input graph: target vertex and length.
struct Node {
    int v, w;
};
vector<Node> vet[MAXN];
int dist[MAXN];
bool vis[MAXN];  // vis[v] == true while v is in the SPFA queue

// Single-source shortest paths from vs over vet[1..n] (SPFA).
// Fills dist[1..n]; unreachable vertices keep dist == inf.
// vt is accepted for interface compatibility but not used.
void spfa(int vs, int vt) {
    (void)vt;
    memset(vis, 0, sizeof(vis));
    for (int i = 1; i <= n; i++) dist[i] = inf;
    dist[vs] = 0;
    queue<int> q;
    q.push(vs);
    vis[vs] = true;
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        vis[u] = false;
        for (size_t k = 0; k < vet[u].size(); k++) {
            int v = vet[u][k].v;
            int w = vet[u][k].w;
            if (dist[u] + w < dist[v]) {
                dist[v] = dist[u] + w;
                if (!vis[v]) {
                    vis[v] = true;
                    q.push(v);
                }
            }
        }
    }
}

// Input: T test cases; each gives n m, then m roads "u v d", then "vs vt".
// Output: one line per test case with the number of edge-disjoint shortest
// paths from vs to vt.
int main() {
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        int u, v, d, vs, vt;
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++) vet[i].clear();
        for (int i = 1; i <= m; i++) {
            scanf("%d%d%d", &u, &v, &d);
            if (u == v) continue;  // self-loops are discarded
            Node p;
            p.v = v;
            p.w = d;
            vet[u].push_back(p);
        }
        scanf("%d%d", &vs, &vt);
        spfa(vs, vt);

        // Build the shortest-path network: one unit arc per road on a
        // shortest path from vs. Unreachable tails are skipped, which also
        // avoids evaluating inf + w.
        init();
        NV = n;
        for (int i = 1; i <= n; i++) {
            if (dist[i] == inf) continue;
            for (size_t j = 0; j < vet[i].size(); j++) {
                const Node &e = vet[i][j];
                if (dist[e.v] == dist[i] + e.w) add(i, e.v, 1);
            }
        }
        printf("%d\n", SAP(vs, vt, NV));
    }
    return 0;
}