N个点,对于每个点i,都有一条连向i+1的有向边,另外有M条其他的有向边(姑且称作小道边),有Q个询问 (u,v) 求 u 到 v 的最短路。
这题真的是思考了好久,也参考了很多大牛的博客,终于把这题给过了。
这题让我理解到了,原来树状数组也可以维护最大值和最小值,真的是长见识了。
以下解析,可能是这题网络上最详细的讲解了。
刚一看感觉是最短路的题,但是再一想,有 n−1 条有向边,你可以随意用但是那些小道边,且只可以用一次,所以就是对于特定的问题,根据那些小道边对于每个点的距离进行更新。
为了防止读者在阅读的过程中出现混乱,先定义如下变量
首先可以预处理出前缀和 sum 数组,这样方便询问 (u,v) 之间的距离。
对于 u<v 和 u>v 的这两种情况,则需要进行分类讨论。
(1) 对于 u<v 的这种情况,即 u−>v 的情况
可以首先可以得到的两点间的距离为 dist(u,v)=sum[v]−sum[u]
那么要求 u−>v 最小的解,就是求一些小道边可以节省最大路程是多少的,因为那些小道只可以用一次,所以对这个结果产生影响的只有小道边 (u′,v′) ,其中 u<=u′<v′<=v
可以对于所有询问,和所有小道边进行离散处理,并一块进行排序。
可以对于 (u,v) 这条边 u 大的排前,如果 u 一样的那么 v 小的排前。
这样就可以保证在查询到 (u,v) 之前,所有的 (u′,v′) 且 u<=u′<v′<=v 都已经插入好,
同时,对于小道边 (u′,v′) , u′<u 或者 u′=u且v′>v 的这样的边还没插入。
然后对于一个小道边,则对1 到 u’ 进行更新,更新的值为
dist(u′,v′)−cost(u′,v′)=(sum[v′]−sum[u′])−cost(v′,u′) ,
代表这条边对于这个区间内可以节省的最大的值,
对于每次询问,则直接在线段树中查找在u的位置的最大值,那么总距离就最小,结果为
dist(u,v)−query(v)=sum[v]−sum[u]−query(v)
(2) 对于 u>v 这种情况,即 v−>u 的情况
类似于上面的分析,对于询问 (u,v) ,
因为只有一个小道可以选择,那么对其产生影响的边只有 u′ 和 v′ ,其中 v′<=v<u<=u′ ,
首先进行同 (1) 的一样的排序,这样保证了在查询到 (u,v) 之前,所有 (u′,v′) v′<=v<u<=u′ 的都已经插入好,
同时, (u′,v′) v<v′<u′<u 那些情况不能插入。
所以对于每次询问 v−>u ,插入每条小道边 v′−>u′ ,之后的边长为
dist(v′,u′)+cost(v′,u′)−dist(v,u)=(sum[u′]−sum[v′])+cost(v′,u′)−(sum[u]−sum[v])
由于 dist(v,u) 是个定值,所以在询问的时候 dist(v,u) 是不起作用的,相当于一个常数,我们需要的只是 dist(v′,u′)+cost(v′,u′) 最小,这样才能使得总路程最小,所以更新的时候只需要维护这个就好了。
然后对于一个小道边,则对1 到 v’ 进行更新,更新的值为 dist(v′,u′)+cost(v′,u′)
对于每次询问,则直接在线段树中查找在v的位置的最小值,那么总距离就最小,
结果为 query(v)−dist(u,v)
至此圆满结束。
注意虽然题目说最终结果不会超过int,但是中间过程可能会爆int,所以要开long long。
my code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll INF = (1LL<<63)-1;
const int N = 100005;
const int M = 400005;
int n, m, q;
struct Edge {
int kind, id;
int u, v;
ll cost;
} arr[M];
int top;
ll ans[M], sum[N];
struct BIT {
ll C[N];
void clear(ll val) { fill(C, C+N, val); }
inline int lowbit(int x) { return x & (-x); }
ll query(int x) {
ll ret = INF;
while(x > 0) {
ret = min(ret, C[x]);
x -= lowbit(x);
}
return ret;
}
void update(int x, ll val) {
while(x <= n) {
C[x] = min(C[x], val);
x += lowbit(x);
}
}
} bit;
bool cmp(Edge a, Edge b) {
if(a.u != b.u)
return a.u > b.u;
else {
if(a.kind != b.kind)
return a.kind < b.kind;
return a.v < b.v;
}
}
void prepare() {
int u, v;
ll cost;
sum[0] = sum[1] = 0;
for(int i = 2; i <= n; i++) {
scanf("%lld", &sum[i]);
sum[i] += sum[i-1];
}
top = 0;
for(int i = 0; i < m; i++) {
scanf("%d%d%lld", &u, &v, &cost);
arr[top].kind = 0;
arr[top].u = u, arr[top].v = v, arr[top].cost = cost;
top++;
}
scanf("%d", &q);
for(int i = 0; i < q; i++) {
scanf("%d%d", &u, &v);
arr[top].kind = 1;
arr[top].u = u, arr[top].v = v, arr[top].id = i;
top++;
}
sort(arr, arr+top, cmp);
}
void solve() {
int u, v, kind, id;
ll cost;
memset(ans, 0, sizeof(ans));
bit.clear(0);
for(int i = 0; i < top; i++) {
u = arr[i].u, v = arr[i].v, kind = arr[i].kind;
if(u < v && kind == 0) {
cost = arr[i].cost;
bit.update(v, cost - (sum[v] - sum[u]));
}else if(u < v && kind == 1) {
id = arr[i].id;
ans[id] = bit.query(v) + (sum[v] - sum[u]);
}
}
bit.clear(INF);
for(int i = 0; i < top; i++) {
u = arr[i].u, v = arr[i].v, kind = arr[i].kind;
if(u > v && kind == 0) {
cost = arr[i].cost;
bit.update(v, (sum[u] - sum[v]) + cost);
}else if(u > v && kind == 1) {
id = arr[i].id;
ans[id] = bit.query(v) - (sum[u] - sum[v]);
}
}
}
void output() {
for(int i = 0; i < q; i++)
printf("%lld\n", ans[i]);
}
int main() {
while(~scanf("%d%d", &n, &m)) {
prepare();
solve();
output();
}
return 0;
}