洛谷链接
树剖一下,直观上来看,是要在树上对一条链维护一段等差数列。如果维护的是区间和,每次在一段区间加上一段等差数列,这个可以直接在线段树上做不依赖任何科技,但这题的查询形式是最小值,直接做很难打标记进行维护。
将等差数列视作一次函数,考虑用李超树在链上维护一个一次函数。
预处理出每个点的深度值 dep[u]
在路径 s , t s,t s,t 上加入一条 斜率为 a,截距为 b
的直线,在 s,lca(s,t)
路径上点x
加入的数字显然是 b + a * (dep[s] - dep[x]) = -a * dep[x] + b + a * dep[s]
,用李超树将这条线段维护到树链上,路径的另外半段同样的分析方法。
由于要维护的是最小值,在李超树上每个节点要维护中点 mid
处最低的线段,这个直接改板子。由于询问形式去区间询问,考虑再维护一个区间最小值 val[rt]
,在加入一条线段之后每个节点的最小值 = min(当前区间最低势线段两端的取值,min(val[ls],val[rs]))
这题还有一个特殊的地方,就是一次函数放到了树上,进行树剖建线段树后,线段树的下标不是 x x x 而是 d f s dfs dfs 序,连续的一段 d f s dfs dfs 序不一定是一条链,而横坐标(深度值)仅在一条链上满足单调性,在整棵树上不满足单调性,维护的线段的区间在维护过程中不能有误。
线段树上的每个节点的区间被该维护的线段的区间完全覆盖,这样在树上不会存在一条跨链的一次函数,在 pushup 不容易出问题。
李超线段树的复杂度是 n log 2 n n\log^2 n nlog2n,树剖还有一个 log \log log,复杂度为 n log 3 n n \log^3 n nlog3n,但树剖和李超树的常数都非常小,可能在 1s 内通过
代码:
#include
using namespace std;
const int maxn = 1e5 + 10;
const int N = 1e5;
typedef long long ll;
ll inf = 123456789123456789ll;
int n,m,q,op,s,t;
ll a,b;
struct Line { //直线结构体
ll k,b;
Line() {}
Line(ll ki,ll bi) {
k = ki, b = bi;
}
ll calc(ll x) { //计算在 x 点的 y值
return k * x + b;
}
};
struct Graph {
int head[maxn],nxt[maxn << 1],to[maxn << 1],w[maxn << 1];
int cnt;
void init() {
memset(head,-1,sizeof head);
cnt = 0;
}
void add(int u,int v,int wi) {
to[cnt] = v;
w[cnt] = wi;
nxt[cnt] = head[u];
head[u] = cnt++;
}
}G;
int dfn[maxn],dis[maxn],son[maxn],idfn[maxn],sz[maxn],top[maxn],f[maxn],cnt;
ll dep[maxn];
void dfs1(int u,int fa) { //预处理
sz[u] = 1, f[u] = fa; son[u] = 0; dis[u] = dis[fa] + 1;
for (int i = G.head[u]; i + 1; i = G.nxt[i]) {
int v = G.to[i], w = G.w[i];
if (v == fa) continue;
dep[v] = dep[u] + w;
dfs1(v,u);
sz[u] += sz[v];
if (!son[u] || sz[v] > sz[son[u]])
son[u] = v;
}
}
void dfs2(int u,int t) { //轻重链剖分
dfn[u] = ++cnt, idfn[cnt] = u;
top[u] = t;
if (!son[u]) return ;
dfs2(son[u],t);
for (int i = G.head[u]; i + 1; i = G.nxt[i]) {
int v = G.to[i];
if (v == f[u] || v == son[u]) continue;
dfs2(v,v);
}
}
struct seg_tree { //维护 x = k 处最低线段
#define lson rt << 1,l,mid
#define rson rt << 1 | 1,mid + 1,r
int tag[maxn << 2];
ll val[maxn << 2]; //val 维护区间的最小值
Line line[maxn << 2];
void build(int rt,int l,int r) {
line[rt].k = 0; line[rt].b = inf;
val[rt] = inf;
if (l == r) return;
int mid = l + r >> 1;
build(lson); build(rson);
}
void update(int rt,int l,int r,int L,int R,Line t) {
if (L <= l && r <= R) {
int mid = l + r >> 1;
if (line[rt].calc(dep[idfn[l]]) > t.calc(dep[idfn[l]]) && line[rt].calc(dep[idfn[r]]) > t.calc(dep[idfn[r]])) {
line[rt] = t;
} else if (line[rt].calc(dep[idfn[l]]) > t.calc(dep[idfn[l]]) || line[rt].calc(dep[idfn[r]]) > t.calc(dep[idfn[r]])) {
if (line[rt].calc(dep[idfn[mid]]) > t.calc(dep[idfn[mid]])) {
Line tmp = t; t = line[rt]; line[rt] = tmp;
}
if (t.k > line[rt].k) {
update(lson,L,R,t);
} else {
update(rson,L,R,t);
}
}
val[rt] = min(val[rt],min(line[rt].calc(dep[idfn[l]]),line[rt].calc(dep[idfn[r]])));
if (l != r) val[rt] = min(val[rt],min(val[rt << 1],val[rt << 1 | 1]));
} else {
int mid = l + r >> 1;
if (L <= mid) update(lson,L,R,t);
if (mid + 1 <= R) update(rson,L,R,t);
if (l != r) val[rt] = min(val[rt],min(val[rt << 1],val[rt << 1 | 1]));
}
}
ll query(int L,int R,int rt,int l,int r) { //查询区间 L,R 最小值
if (L <= l && r <= R) return val[rt];
ll ans = inf;
ans = min(ans,min(line[rt].calc(dep[idfn[max(L,l)]]),line[rt].calc(dep[idfn[min(R,r)]])));
int mid = l + r >> 1;
if (L <= mid) ans = min(ans,query(L,R,lson));
if (mid + 1 <= R) ans = min(ans,query(L,R,rson));
return ans;
}
}seg;
int getlca(int x,int y) {
while (top[x] != top[y]) {
if (dis[top[x]] < dis[top[y]]) swap(x,y);
x = f[top[x]];
}
if (dis[x] > dis[y]) swap(x,y);
return x;
}
void update(int u,int v,ll a,ll b) {
while (top[u] != top[v]) {
if (dis[top[u]] < dis[top[v]]) swap(u,v);
seg.update(1,1,n,dfn[top[u]],dfn[u],Line(a,b));
u = f[top[u]];
}
if (dis[u] > dis[v]) swap(u,v);
seg.update(1,1,n,dfn[u],dfn[v],Line(a,b));
}
ll qry(int u,int v) {
ll ans = inf;
while (top[u] != top[v]) {
if (dis[top[u]] < dis[top[v]]) swap(u,v);
ans = min(ans,seg.query(dfn[top[u]],dfn[u],1,1,n));
u = f[top[u]];
}
if (dis[u] > dis[v]) swap(u,v);
ans = min(ans,seg.query(dfn[u],dfn[v],1,1,n));
return ans;
}
int main() {
G.init();
scanf("%d%d",&n,&m);
for (int i = 1; i < n; i++) {
int u,v,w; scanf("%d%d%d",&u,&v,&w);
G.add(u,v,w);
G.add(v,u,w);
}
dfs1(1,0); dfs2(1,1);
seg.build(1,1,n);
while (m--) {
scanf("%d%d%d",&op,&s,&t);
if (op == 1) {
scanf("%lld%lld",&a,&b);
int lca = getlca(s,t);
update(s,lca,-a,b + a * dep[s]);
update(t,lca,a,b + a * dep[s] - 2 * a * dep[lca]);
} else {
printf("%lld\n",qry(s,t));
}
}
return 0;
}