题目传送门
树链剖分(重链剖分)模板题,记录一下模板:支持路径加、路径求和、子树加、子树求和,结果均对 MOD 取模。
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
// Capacity of every per-vertex array; vertices are addressed from index 1.
#define N 100005
using namespace std;

// NOTE(review): the global names `size` and `end` below share their spelling
// with std::size / std::end; with the using-directive above, unqualified uses
// can be ambiguous on C++11/17 toolchains once an iterator/container header is
// included. Referring to them as ::size / ::end avoids that.

// Values read by main(): vertex count, operation count, root vertex, modulus.
int n;
int m;
int Root;
int MOD;

// Index of the most recently written entry of Edg (Init sets it to -1).
int cur;
// head_p[v]: index in Edg of the first adjacency entry of v, -1 when none.
int head_p[N];
// Timestamp counter advanced by dfs2.
int Tim;

// num[v]: initial value of vertex v as read from the input.
int num[N];
// Written by dfs1: parent, depth, heavy child (-1 if none), subtree size.
int fa[N];
int dep[N];
int son[N];
int size[N];
// Written by dfs2: head of v's chain, first and last timestamp of v's subtree,
// and the inverse map timestamp -> vertex.
int top[N];
int start[N];
int end[N];
int fw[N];

// One adjacency-list entry: `obj` is the neighbour, `next` the index of the
// following entry for the same vertex (-1 terminates the list).
struct Tadj {
    int next;
    int obj;
};
// Two entries per undirected edge.
Tadj Edg[N << 1];

// Segment-tree node: `sum` is the range sum, `add` the pending lazy increment.
struct Tnode {
    int sum;
    int add;
};
// Node i has children 2i and 2i+1, hence four times the leaf count.
Tnode tree[N << 2];
void Init(){
cur = -1;
Tim = 0;
memset(head_p, -1, sizeof(head_p));
memset(son, -1, sizeof(son));
}
void Insert(int a, int b){
cur ++;
Edg[cur].next = head_p[a];
Edg[cur].obj = b;
head_p[a] = cur;
}
void dfs1(int root, int dad){
dep[root] = dep[dad] + 1;
fa[root] = dad;
size[root] = 1;
for(int i = head_p[root]; ~ i; i = Edg[i].next){
int v = Edg[i].obj;
if(v == dad) continue;
dfs1(v, root);
size[root] += size[v];
if(son[root] == -1 || size[v] > size[son[root]]) son[root] = v;
}
}
void dfs2(int root, int tp){
Tim ++;
start[root] = Tim;
fw[Tim] = root;
top[root] = tp;
if(~ son[root]) dfs2(son[root], tp);
for(int i = head_p[root]; ~ i; i = Edg[i].next){
int v = Edg[i].obj;
if(v == fa[root] || v == son[root]) continue;
dfs2(v, v);
}
end[root] = Tim;
}
// Build the segment tree over timestamps [L, R]; leaf t holds num[fw[t]] % MOD.
//   root - index of the current node in `tree`
//   L, R - inclusive timestamp range covered by that node
void build(int root, int L, int R){
    if(L == R){
        tree[root].sum = num[fw[L]] % MOD;
        return;
    }
    int mid = (L + R) >> 1, Lson = root << 1, Rson = root << 1 | 1;
    build(Lson, L, mid);
    build(Rson, mid+1, R);
    // Both child sums may be close to MOD, so add them in 64 bits: an int
    // addition overflows once MOD exceeds 2^30.
    long long total = static_cast<long long>(tree[Lson].sum) + tree[Rson].sum;
    tree[root].sum = static_cast<int>(total % MOD);
}
// Push the pending increment of node `root` (covering [L, R]) to its children.
// Each child's sum grows by tag * (child length) and its own tag grows by tag,
// all modulo MOD; the tag on `root` is then cleared.
void down(int root, int L, int R){
    const long long tag = tree[root].add;
    if(tag == 0) return;
    int mid = (L + R) >> 1, Lson = root << 1, Rson = root << 1 | 1;
    // All arithmetic is done in 64 bits: tag * length can reach about
    // MOD * 1e5, and the additions can exceed INT_MAX when MOD is large.
    tree[Lson].add = static_cast<int>((tree[Lson].add + tag) % MOD);
    tree[Rson].add = static_cast<int>((tree[Rson].add + tag) % MOD);
    tree[Lson].sum = static_cast<int>((tree[Lson].sum + tag * (mid - L + 1) % MOD) % MOD);
    tree[Rson].sum = static_cast<int>((tree[Rson].sum + tag * (R - mid) % MOD) % MOD);
    tree[root].add = 0;
}
// Add `val` to every position in [x, y] (modulo MOD).
//   root, L, R - current node and the inclusive range it covers
//   x, y       - inclusive target range
//   val        - increment; main passes it already reduced modulo MOD
void update(int root, int L, int R, int x, int y, int val){
    if(x > R || y < L) return;                  // no overlap
    if(x <= L && y >= R){
        // Fully covered: apply here and leave a lazy tag for the children.
        // 64-bit arithmetic, because length * val and the additions can
        // exceed the int range when MOD is large.
        const long long len = R - L + 1;
        tree[root].sum = static_cast<int>((tree[root].sum + len * val % MOD) % MOD);
        tree[root].add = static_cast<int>((static_cast<long long>(tree[root].add) + val) % MOD);
        return;
    }
    int mid = (L + R) >> 1, Lson = root << 1, Rson = root << 1 | 1;
    down(root, L, R);                           // children must be current before descending
    update(Lson, L, mid, x, y, val);
    update(Rson, mid+1, R, x, y, val);
    long long total = static_cast<long long>(tree[Lson].sum) + tree[Rson].sum;
    tree[root].sum = static_cast<int>(total % MOD);
}
// Return the sum of positions [x, y] modulo MOD.
//   root, L, R - current node and the inclusive range it covers
//   x, y       - inclusive query range
// Returns 0 when the two ranges do not overlap.
int query(int root, int L, int R, int x, int y){
    if(x > R || y < L) return 0;
    if(x <= L && y >= R) return tree[root].sum;
    int mid = (L + R) >> 1, Lson = root << 1, Rson = root << 1 | 1;
    down(root, L, R);                           // apply pending increments first
    long long left = query(Lson, L, mid, x, y);
    long long right = query(Rson, mid+1, R, x, y);
    // Added in 64 bits: two partial sums near MOD overflow an int addition.
    return static_cast<int>((left + right) % MOD);
}
void work1(int x, int y, int val){
while(top[x] != top[y]){
if(dep[top[x]] > dep[top[y]]) swap(x, y);
update(1, 1, n, start[top[y]], start[y], val);
y = fa[top[y]];
}
if(dep[x] > dep[y]) swap(x, y);
update(1, 1, n, start[x], start[y], val);
}
// Return the sum (modulo MOD) of the values on the tree path between x and y.
int work2(int x, int y){
    // 64-bit accumulator: adding two residues near MOD overflows an int
    // when MOD is above 2^30.
    long long ans = 0;
    while(top[x] != top[y]){
        // Work on the endpoint whose chain head lies deeper.
        if(dep[top[x]] > dep[top[y]]) swap(x, y);
        ans = (ans + query(1, 1, n, start[top[y]], start[y])) % MOD;
        y = fa[top[y]];                         // hop above that chain
    }
    // Same chain now: the shallower endpoint has the smaller timestamp.
    if(dep[x] > dep[y]) swap(x, y);
    ans = (ans + query(1, 1, n, start[x], start[y])) % MOD;
    return static_cast<int>(ans);
}
int main(){
scanf("%d%d%d%d", &n, &m, &Root, &MOD);
Init();
for(int i = 1; i <= n; i++)
scanf("%d", &num[i]);
int a, b;
for(int i = 1; i < n; i++){
scanf("%d%d", &a, &b);
Insert(a, b);
Insert(b, a);
}
dfs1(Root, 0);
dfs2(Root, Root);
build(1, 1, n);
int op, k;
for(int i = 1; i <= m; i++){
scanf("%d", &op);
if(op == 1){
scanf("%d%d%d", &a, &b, &k);
work1(a, b, k % MOD);
}
else if(op == 2){
scanf("%d%d", &a, &b);
printf("%d\n", work2(a, b));
}
else if(op == 3){
scanf("%d%d", &a, &b);
update(1, 1, n, start[a], end[a], b % MOD);
}
else{
scanf("%d", &a);
printf("%d\n", query(1, 1, n, start[a], end[a]));
}
}
return 0;
}
Smile.