题目描述抽象来看,是指有一个有向图,问一个点经过N条边到另一个点的最短距离(边可重复走)
为了搞这题...去研究了下矩阵乘法...我不是计算机专业~~又看了下他们的离散数学教材...有一个例子是说求两点间经过N条边到达的方案数..Mtrix67的Blog的第八题讲的也是这个问题....
首先看经过N条边方案数的这个问题...也就是理解一下这个过程...用一个邻接矩阵来存图...点 ( i , j ) 代表 i 到 j 有多少条路...最初矩阵A的初始化时( i , j ) 为两点i到j直接的边数...那么A1存的实际就是每两点只经过一条边到达的方案数...那么看一下 A^2 也就是 A*A ... 做矩阵乘法时是 MutiMtrix [ i ] [ j ] = sum ( Mtrix1 [ i ] [ k ] * Matrix2 [ k ] [ j ] ) < k=1..点数> ...那么就是说枚举所有的中间点(k) sum( A 中i到k点的方案数* A中 k 到 j 的方案数) 很明显能求出每两点之间经过两条边到达的方案数...也就是说 A^2 就代表综上...同理可证A^3代表两点之间经过3条边到达的方案数...A^k代表两点之间经过k条边到达的方案数..
理解了这个例子后再来看这道题...这道题虽然也是经过多少多少条边两点到达..但求的是最短距离...求最短距离..又联想方案数的应该和矩阵有关系..很容易能想到Floyd...Floyd在求最短路径时枚举中间点..不断更新两点两点的最短距离...回想一下Floyd的更新的方程
if ( Dist [ i ] [ j ] < Dist [ i ] [ k ] + Dist [ k ] [ j ] ) Dist [ i ] [ j ] = Dist [ i ] [ k ] + Dist [ k ] [ j ]
这个表达式是不是很酷似矩阵乘法的运算式?多了一层判断再更新,把乘号变成了加号...
因为题目所给的两点最多有一条边...令一个邻接矩阵A表示两两点的初始关系...也就是题目所给的两点相连的情况..( i , j ) 是边的权值...那么( i , j )显然是 i 到 j 经过一条边最短路径长度...定义一个矩阵乘法的形式Floyd的更新方式的矩阵运算:
pp muti(pp a,pp b) { pp h; int i,j,k; for (i=1;i<=n;i++) for (j=1;j<=n;j++) h.s[i][j]=oo; for (k=1;k<=n;k++) for (i=1;i<=n;i++) for (j=1;j<=n;j++) if (h.s[i][j]>a.s[i][k]+b.s[k][j]) h.s[i][j]=a.s[i][k]+b.s[k][j]; return h; }
这个矩阵运算的方式除了更新,形式上和矩阵乘法时一样的...所以可以运用矩阵乘法的性质来二分求解...例如如果要求 s 到 e 经过N条边到达的最短距离~~实际上是求 A 矩阵做N次后 ( s, e )的值...这里直接就看成乘法来思考..也就是求A^N这个矩阵...明显的用二分的方法来解决就可以了...Mtrix67以及我前一篇文章关于这个方法已经说得很清楚了..
这道题要注意的一点就是虽然点的标号可能是1-1000...但边最多只有100个...所以点最多也就100来个..所以要把点这里处理下~~把离散的点压成从1开始连续的好处理得多...
Program:
#include<iostream> #define MAXN 106 #define ok printf("Yes %d!!\n",p) #define oo 1000000001 using namespace std; struct pp { int s[MAXN][MAXN]; }a,h; struct p1 { int x,y,k; }line[MAXN]; int n,t,s,e,i,m,x,y,k,point[1005]; bool had[1005]; pp muti(pp a,pp b) { pp h; int i,j,k; for (i=1;i<=n;i++) for (j=1;j<=n;j++) h.s[i][j]=oo; for (k=1;k<=n;k++) for (i=1;i<=n;i++) for (j=1;j<=n;j++) if (h.s[i][j]>a.s[i][k]+b.s[k][j]) h.s[i][j]=a.s[i][k]+b.s[k][j]; return h; } pp find(int p) { pp h; if (p==1) return a; h=find(p/2); h=muti(h,h); if (p%2) h=muti(h,a); return h; } int main() { while (~scanf("%d%d%d%d",&n,&t,&s,&e)) { memset(had,false,sizeof(had)); for (i=1;i<=t;i++) { scanf("%d%d%d",&line[i].k,&line[i].x,&line[i].y); had[line[i].x]=had[line[i].y]=true; } m=n; n=0; for (i=1;i<=1000;i++) if (had[i]) { n++; point[i]=n; } for (y=1;y<=n;y++) for (x=1;x<=n;x++) a.s[x][y]=oo; for (i=1;i<=t;i++) { x=point[line[i].x]; y=point[line[i].y]; k=line[i].k; a.s[x][y]=a.s[y][x]=k; } h=find(m); printf("%d\n",h.s[point[s]][point[e]]); } return 0; }