牛客网暑期ACM多校训练营(第十场)Rikka with Line Graph(floyd)

题目链接:https://www.nowcoder.com/acm/contest/148#question

 

题目大意:给一个图,得出这个图对应的线图,线图里的每个点表示着原图的一条边,如果原图中有两条边有公共点,那么对应在线图中这两条边对应的点之间有连线。线图中的边边权是两个顶点对应边的边权的和,线图是完全图,对线图中对所有点之间求最短路的和。

 

题目思路:首先我们可以知道,假设线图中1-2-3相连,1表示原图中1-2,2表示原图中2-3,3表示原图中3-4,那么线图中1-3的路径就是线图中1-2的边权+2-3的边权=(1-2+2-3)+(2-3+3-4)所以可以发现,其实线图中两个点之间的最短路值就是两个点对应的边权值+中间最短路权值*2。两个点对应的边权值对答案的影响很好求,因为是完全图,每个点会与其他n-1个点相连,由于会重复,所以一共有n*(n-1)/2条边,减去自己就是会与n*(n-1)/2-1条边相连,所以就让每条边的权值*(n*(n-1)/2-1)即可,同时对于a-b c-d两条边,这两边之间的最短路是a-c a-d b-c b-d中的一种,如果枚举四个点会很麻烦,这里我们就只用枚举两个点,也就是对于一条边来说其他点到这条边的距离。由于是完全图,所以每条边都会与其他边有交集。同时我们可以发现,如果当前跟a-b最近的点是k,那么所有跟k相连的边到a-b的最近距离都是k到a-b的最近距离,所以可以推断其实最近距离就是k到a-b距离乘以大于k的点的数量即可。在这里需要用到排序操作,时间复杂度达到n^3logn,所以我们可以用一个小技巧,就是先预处理出每一个点到其他所有点的距离(n^2logn),然后算跟a-b距离最近的点的时候直接归并二者的到其他点距离就行。这里要注意已经用过的点就不能再用了,所以就需要用vis标记一下。

 

以下是代码:

n^3logn解法:

#include
using namespace std;
#define inf 0x3f3f3f3f
#define MAXN 505
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define ll long long
#define MOD 998244353
int map1[MAXN][MAXN];
ll dist[MAXN][MAXN],a[MAXN];
int main(){
    int t,n;
    scanf("%d",&t);
    while(t--){
        scanf("%d",&n);
        rep(i,1,n){
            rep(j,1,n){
                scanf("%d",&dist[i][j]);
                map1[i][j]=dist[i][j];
            }
        }
        rep(k,1,n){
            rep(i,1,n){
                rep(j,1,n){
                   dist[i][j]=min(dist[i][j],(ll)dist[i][k]+dist[k][j]);
                }
            }
        }
        ll ans=0,num=(ll)n*(n-1)/2-1;
        rep(i,1,n){
            rep(j,i+1,n){
                rep(k,1,n){
                    a[k]=min(dist[j][k],dist[i][k]);
                }
                sort(a+1,a+n+1);
                rep(k,1,n){
                    ans=((a[k]*(n-k)%MOD)+ans)%MOD;
                }
                ans=((map1[i][j]*num%MOD)+ans)%MOD;
            }
        }
        ans=(ans+MOD)%MOD;
        printf("%lld\n",ans);
    }
    return 0;
}

n^3解法:

#include
using namespace std;
#define inf 0x3f3f3f3f
#define MAXN 505
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define ll long long
#define MOD 998244353
int map1[MAXN][MAXN],dist[MAXN][MAXN],pd[MAXN<<1];
pairp[MAXN][MAXN],p2[MAXN<<1];
int main()
{
    int t,n;
    scanf("%d",&t);
    while(t--){
        scanf("%d",&n);
        ll ans=0;
        rep(i,1,n){
            rep(j,1,n){
                scanf("%d",&dist[i][j]);
                map1[i][j]=dist[i][j];
            }
        }
        rep(k,1,n){
            rep(i,1,n){
                rep(j,1,n){
                    dist[i][j]=min(dist[i][j],dist[i][k]+dist[k][j]);
                }
            }
        }
        rep(i,1,n){
            rep(j,1,n){
                p[i][j]=pair(dist[i][j],j);
            }
            sort(p[i]+1,p[i]+n+1);
        }
        int vis[MAXN];
        rep(i,1,n){
            rep(j,i+1,n){
                int num=0;
                memset(vis,0,sizeof(vis));
                merge(p[i]+1,p[i]+n+1,p[j]+1,p[j]+n+1,p2);
                rep(k,0,2*n-1){
                    if(vis[p2[k].second])continue;
                    ans=(ans+(ll)p2[k].first*(n-(++num)))%MOD;
                    vis[p2[k].second]=1;
                }
            }
        }
        ll bian=(ll)n*(n-1)/2-1;
        rep(i,1,n){
            rep(j,i+1,n){
                ans=(ans+map1[i][j]*bian)%MOD;
            }
        }
        ans=(ans+MOD)%MOD;
        printf("%lld\n",ans);
    }
    return 0;
}

 

你可能感兴趣的:(最短路)