题目描述：戳这里（原链接在转载时丢失；从代码看，题目应为洛谷 P3384【模板】树链剖分）。
题解:
其实树链剖分的重点就在于轻重链的划分，这篇文章写得很好（原文链接已丢失）。
然而我的线段树写得全是问题，调了半天才改对（2333）。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity for tree nodes; nodes are numbered 1..n, so n must be at most maxn-1.
const int maxn=100005;
// n nodes, m operations, the chosen root, the modulus tt, and a counter that main
// first uses for edges and then resets for the DFS-order numbering.
int n,m,root,tt,tot;
// Linked-list adjacency: lnk[x] is x's first edge, son[e] is the edge's target,
// nxt[e] is the next edge of the same source. Each undirected edge is stored twice.
int lnk[maxn],son[2*maxn],nxt[2*maxn];
int a[maxn];        // initial value of each tree node
// Heavy-light decomposition arrays, all indexed by tree node unless noted.
int fa[maxn];       // parent (0 for the root)
int top[maxn];      // head of the heavy chain the node lies on
int siz[maxn];      // subtree size
int wson[maxn];     // heavy child (0 when the node is a leaf)
int id[maxn];       // position of the node in the second DFS order
int dep[maxn];      // depth, with the root at depth 1
int rk[maxn];       // indexed by DFS position: rk[id[x]] == x
// A segment-tree node covering the DFS-order positions [l, r].
struct dyt{
    int l,r;    // covered interval of positions
    int sum;    // sum of the values in [l, r], kept reduced modulo tt
    int tag;    // lazy addition already applied to sum but not yet to the children
};
// Segment tree in 1-based heap layout: node x has children 2*x and 2*x+1.
dyt tree[4*maxn];
// Append the directed edge x -> y at the front of x's adjacency list.
void add(int x,int y){
    ++tot;
    son[tot]=y;
    nxt[tot]=lnk[x];
    lnk[x]=tot;
}
// First DFS over the subtree of x, whose parent is las (main passes 0 for the root).
// Records parent, depth and subtree size. Chooses the heavy child wson[x] as the child
// with the largest subtree; on ties the child met first in adjacency order is kept.
void dfs(int x,int las){
    fa[x]=las;
    dep[x]=dep[las]+1;
    siz[x]=1;
    for (int e=lnk[x];e!=0;e=nxt[e]) {
        int v=son[e];
        if (v==las) continue;   // the edge back to the parent
        dfs(v,x);
        siz[x]+=siz[v];
        if (siz[wson[x]]<siz[v]) wson[x]=v;
    }
}
// Second DFS: number nodes in preorder with the heavy child visited first. As a result,
// every heavy chain and every subtree occupies a contiguous range of positions.
// las is the head of the chain containing x. The heavy child continues that chain;
// every light child starts a new chain headed by itself.
void dfs1(int x,int las){
    top[x]=las;
    id[x]=++tot;
    rk[id[x]]=x;
    if (wson[x]>0) dfs1(wson[x],las);
    for (int e=lnk[x];e;e=nxt[e]) {
        int v=son[e];
        if (v==fa[x]||v==wson[x]) continue;   // parent edge, or already visited heavy child
        dfs1(v,v);
    }
}
// Build segment-tree node x over positions [l, r]. Leaf l stores the value of tree
// node rk[l]. tt is read from input and only bounded by the int range, so the sum
// of two residues is formed in long long: in int it overflows once tt > INT_MAX/2.
void buildtree(int x,int l,int r){
    tree[x].l=l,tree[x].r=r;
    if (l==r) {tree[x].sum=a[rk[l]]; return;}
    int mid=(l+r)>>1;
    buildtree(2*x,l,mid); buildtree(2*x+1,mid+1,r);
    tree[x].sum=(int)(((long long)tree[2*x].sum+tree[2*x+1].sum)%tt);
}
// Push node x's pending lazy addition to its two children, then clear it.
// All products and sums are formed in long long. tag < tt and the interval length
// can reach n (~1e5), so tag*len overflows int long before tt gets near INT_MAX.
void pushdown(int x){
    int tag=tree[x].tag; tree[x].tag=0;
    if (tag==0||tree[x].l==tree[x].r) return;   // nothing to push, or a leaf
    for (int c=2*x;c<=2*x+1;c++) {
        long long len=tree[c].r-tree[c].l+1;
        tree[c].sum=(int)((tree[c].sum+tag*len%tt)%tt);
        tree[c].tag=(int)(((long long)tree[c].tag+tag)%tt);
    }
}
// Add z to every position in [l, r]. [l, r] must lie inside node x's interval;
// the range is split at mid as it descends. All callers in this file pass z already
// reduced modulo tt. Arithmetic is done in long long: z*len and the sums of residues
// can exceed INT_MAX.
void insert(int x,int l,int r,int z){
    pushdown(x);
    if (tree[x].l==l&&tree[x].r==r) {
        long long len=tree[x].r-tree[x].l+1;
        tree[x].sum=(int)((tree[x].sum+z*len%tt)%tt);
        tree[x].tag=(int)(((long long)tree[x].tag+z)%tt);   // children updated lazily
        return;
    }
    int mid=(tree[x].l+tree[x].r)>>1;
    if (r<=mid) insert(2*x,l,r,z);
    else if (l>mid) insert(2*x+1,l,r,z);
    else {insert(2*x,l,mid,z); insert(2*x+1,mid+1,r,z);}
    tree[x].sum=(int)(((long long)tree[2*x].sum+tree[2*x+1].sum)%tt);
}
// Return the sum of positions [l, r] modulo tt. [l, r] must lie inside node x's
// interval. The two half-results are added in long long: each is a residue below tt,
// and their int sum overflows when tt > INT_MAX/2.
int query(int x,int l,int r){
    pushdown(x);
    if (tree[x].l==l&&tree[x].r==r) return tree[x].sum;
    int mid=(tree[x].l+tree[x].r)>>1;
    if (r<=mid) return query(2*x,l,r);
    if (l>mid) return query(2*x+1,l,r);
    return (int)(((long long)query(2*x,l,mid)+query(2*x+1,mid+1,r))%tt);
}
void put(int x,int y,int z){
int fax=top[x],fay=top[y];
while (fax!=fay) {
if (dep[fax]>=dep[fay]) {insert(1,id[fax],id[x],z); x=fa[fax],fax=top[x];}
else {insert(1,id[fay],id[y],z); y=fa[fay],fay=top[y];}
}
if(id[x]<=id[y]) insert(1,id[x],id[y],z); else insert(1,id[y],id[x],z);
}
int get(int x,int y){
int ret=0,fax=top[x],fay=top[y];
while (fax!=fay) {
if (dep[fax]>=dep[fay]) {ret=(ret+query(1,id[fax],id[x]))%tt; x=fa[fax],fax=top[x];}
else {ret=(ret+query(1,id[fay],id[y]))%tt; y=fa[fay],fay=top[y];}
}
if(id[x]<=id[y]) ret=(ret+query(1,id[x],id[y]))%tt; else ret=(ret+query(1,id[y],id[x]))%tt;
return ret;
}
int main(){
scanf("%d %d %d %d\n",&n,&m,&root,&tt);
for (int i=1;i<=n;i++) scanf("%d",&a[i]),a[i]%=tt;
for (int i=1;iint x,y; scanf("%d %d",&x,&y);
add(x,y); add(y,x);
}
tot=0; dfs(root,0); dfs1(root,root); buildtree(1,1,n);
for (int i=1;i<=m;i++) {
int p,x,y,z;
scanf("%d %d",&p,&x);
if (p!=4) scanf("%d",&y); if (p==1) scanf("%d",&z);
if (p==1) put(x,y,z%tt);
if (p==2) printf("%d\n",get(x,y));
if (p==3) insert(1,id[x],id[x]+siz[x]-1,y%tt);
if (p==4) printf("%d\n",query(1,id[x],id[x]+siz[x]-1));
}
return 0;
}