题目描述
给出一张 $n$ 个点 $m$ 条边的有向图,边权为非负整数。求满足路径长度小于等于 $1$ 到 $n$ 最短路 $+k$ 的 $1$ 到 $n$ 的路径条数模 $p$ ,如果有无数条则输出 $-1$ 。
输入
第一行包含一个整数 $T$ , 代表数据组数。
接下来 $T$ 组数据,对于每组数据: 第一行包含四个整数 $N,M,K,P$ ,每两个整数之间用一个空格隔开。
接下来 $M$ 行,每行三个整数 $a_i,b_i,c_i$ ,代表编号为 $a_i,b_i$ 的点之间有一条权值为 $c_i$ 的有向边,每两个整数之间用一个空格隔开。
输出
输出文件包含 $T$ 行,每行一个整数表示答案。
样例输入
2
5 7 2 10
1 2 1
2 4 0
4 5 2
2 3 2
3 4 1
3 5 2
1 5 3
2 2 0 10
1 2 0
2 1 0
样例输出
3
-1
题解
最短路+拓扑排序+dp
首先使用堆优化Dijkstra或Spfa(不知道本题是否会卡)求出1到所有点的最短路。
由于对于所有边 $(x,y,z)$ 满足 $dis[x]+z\ge dis[y]$ ,因此超过最短路的部分不会减少。
那么我们设 $f[i][j]$ 表示到达点 $i$ 时经过的路径总长度为 $dis[i]+j\ (j \le k)$ 的方案数。那么这相当于一个新的分层图,只会在同层或向上层转移,不会像下层转移。
这就转化为图上求路径条数。首先初始化 $f[1][0]=0 $ ,跑拓扑排序的同时进行转移。
如果一个点被排到了,那么 $f$ 值即为路径条数。
如果一个点没有被排到,则说明有环连接到它,即路径条数为 $\infty$。
因此把所有 $f[n][0...k]$ 统计一下即可。
时间复杂度 $O(T(m\log n+mk))$
考场上一眼看出题解,然而卡了两个小时的常数才勉强卡进去...
考场原代码(去掉了文件操作):
#include
#include
#include
#include
#include
#define N 100010
#define M 200010
#define R register
#define pos(x , y) (x + (y) * n)
using namespace std;
typedef pair pr;
priority_queue q;
int head[N] , to[M] , len[M] , next[M] , cnt , dis[N] , vis[N];
int ll[N * 51] , rr[N * 51] , tt[M * 51] , que[N * 51] , ind[N * 51] , f[N * 51];
inline void add(int x , int y , int z)
{
to[++cnt] = y , len[cnt] = z , next[cnt] = head[x] , head[x] = cnt;
}
inline char nc()
{
static char buf[100000] , *p1 , *p2;
return p1 == p2 && (p2 = (p1 = buf) + fread(buf , 1 , 100000 , stdin) , p1 == p2) ? EOF : *p1 ++ ;
}
inline int read()
{
int ret = 0; char ch = nc();
while(!isdigit(ch)) ch = nc();
while(isdigit(ch)) ret = ((ret + (ret << 2)) << 1) + (ch ^ '0') , ch = nc();
return ret;
}
int main()
{
int T = read();
while(T -- )
{
memset(head , 0 , sizeof(head));
memset(vis , 0 , sizeof(vis));
memset(ind , 0 , sizeof(ind));
memset(f , 0 , sizeof(f));
cnt = 0;
int n = read() , m = read() , k = read() , z , ans = 0 , flag = 1;
R int p = read() , i , j , x , y , l = 1 , r = 0;
for(i = 1 ; i <= m ; ++i) x = read() , y = read() , z = read() , add(x , y , z);
memset(dis , 0x3f , sizeof(dis));
dis[1] = 0 , q.push(pr(0 , 1));
while(!q.empty())
{
x = q.top().second , q.pop();
if(vis[x]) continue;
vis[x] = 1;
for(i = head[x] ; i ; i = next[i])
if(dis[to[i]] > dis[x] + len[i])
dis[to[i]] = dis[x] + len[i] , q.push(pr(-dis[to[i]] , to[i]));
}
cnt = 0;
for(x = 1 ; x <= n ; ++x)
{
for(j = 0 ; j <= k ; ++j)
{
ll[pos(x , j)] = cnt + 1;
for(i = head[x] ; i ; i = next[i])
if(j + dis[x] + len[i] - dis[to[i]] <= k)
++ind[tt[++cnt] = pos(to[i] , j + dis[x] + len[i] - dis[to[i]])];
rr[pos(x , j)] = cnt;
}
}
f[1] = 1;
for(x = 1 ; x <= pos(n , k) ; ++x)
if(!ind[x])
que[++r] = x;
while(l <= r)
{
x = que[l ++ ];
for(i = ll[x] ; i <= rr[x] ; ++i)
{
y = tt[i];
f[y] += f[x] , ind[y] -- ;
if(f[y] >= p) f[y] -= p;
if(!ind[y]) que[++r] = y;
}
}
for(i = 0 ; i <= k ; ++i)
{
if(ind[pos(n , i)]) flag = 0;
ans = (ans + f[pos(n , i)]) % p;
}
if(flag) printf("%d\n" , ans);
else puts("-1");
}
return 0;
}