∙ \bullet ∙在最朴素的迪杰斯特拉中,我们每次都要跑一层循环来找到最小的 d [ i ] d[ i ] d[i]( d [ i ] d[ i ] d[i]代表起点到 i i i这个点的最小距离)然后再更新与 i i i点有边相连的没有被走过的点 j j j的 d [ j ] d[j] d[j]
void dij(int start)
{
d[start]=0;
for(int i=1;i<=n;i++){
int minpos=-1;
for(int j=1;j<=n;j++){
if((vis[j]==0)&&(minpos<0||d[j]<d[minpos]))
minpos=j;
}
vis[minpos]=1;
if(d[minpos]==inf)break;
for(int i=head[minpos];i!=-1;i=edge[i].next)
if(d[minpos]+edge[i].val<d[edge[i].to])
d[edge[i].to]=d[minpos]+edge[i].val;
}
}
∙ \bullet ∙优先队列优化就是省去了跑一层循环来找到最小的 d [ i ] d[ i ] d[i]
首先需要一个结构体,这个结构体的作用就相当于数组 d [ i ] d[ i ] d[i]的作用,结构体里的 i d id id=数组下标 i i i,结构体里的 d d d= d [ i ] d[i] d[i],就是换了一种形式表示,但是原始的 d [ i ] d[ i ] d[i]数组还是要保留的,因为它存的是每个点到起点的最小距离。然后再用一个优先队列,优先队列里面装的是 d i s dis dis结构体,因为优先队列是要对里面的元素进行自动排序,并且排序要用到 < < <号,所以在结构体里重载了小于号,我写的这个重载小于号就代表 d d d小的值放在优先队列的前端,大的值放在后端。
重载完了以后就只需要将以 d i s dis dis为类型的结构体加入队列,每次取出队首元素就找到了最小的 d [ i ] d[ i ] d[i]。
struct dis
{
int id,d;
dis(int id,int d):id(id),d(d){}
bool operator < (const dis &a)const//重载小于号
{
return d>a.d;
}
};priority_queue<dis>q;
输入 n , m n,m n,m, n n n代表几个点(编号从0开始), m m m代表几条边,接下来 m m m行输入 a , b , c a,b,c a,b,c,代表 a , b a,b a,b之间有一条边,权值 c c c,再输入 s , t s,t s,t,分别代表起点和终点,问起点到终点的最短距离是多少,如果不存在最短距离,输出-1。下面贴出完整的代码:
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
const int inf=0x3f3f3f3f;
const int maxn = 2e4+10;
int d[maxn],n,m,vis[maxn];
int s,t;
struct dis
{
int id,d;
dis(int id,int d):id(id),d(d){}
bool operator < (const dis &a)const
{
return d>a.d;
}
};priority_queue<dis>q;
struct node{int to,val,next;}e[maxn];//链式向前星存图
int head[maxn],cnt=0;
void add(int s,int E,int val)
{
e[++cnt]={E,val,head[s]};
head[s]=cnt;
}
int dij(int start)
{
int flag=0;
q.push(dis(start,0));
d[start]=0;
while(!q.empty()){
dis a=q.top();q.pop();//取出最小点
int now=a.id,minval=a.d;
vis[now]=1;
for(int i=head[now];i!=-1;i=e[i].next){
int to=e[i].to,val=e[i].val;
if(!vis[to]&&minval+val<d[to]){
d[to]=minval+val;
q.push(dis(to,d[to]));
}
}
}
if(d[t]!=inf)return d[t];
else return -1;
}
void init()
{
for(int i=1;i<=n;i++)d[i]=inf;
memset(head,-1,sizeof(head));
memset(e,0,sizeof(e));
memset(vis,0,sizeof(vis));
while(!q.empty())q.pop();
cnt=0;
}
int main()
{
while(scanf("%d%d",&n,&m)!=EOF)
{
init();
for(int i=1;i<=m;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
a++;b++;
add(a,b,c);
add(b,a,c);
}
scanf("%d%d",&s,&t);
s++;t++;
printf("%d\n",dij(s));
}
return 0;
}