树链剖分:P3384 【模板】树链剖分

题目描述：见洛谷 P3384 【模板】树链剖分（原文此处为题目链接）
题解:
其实树剖的重点就在于轻重链的划分，这篇文章写得很好（原文此处附有文章链接）
然而我线段树写得全是问题,改了半天2333

代码如下:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Array capacity for vertices (problem limit is 1e5, plus a little slack).
const int maxn=100005;

// Problem parameters: vertex count, operation count, root vertex, modulus.
int n,m,root,tt;
// Counter shared by two phases: edge index while building the adjacency
// list, then (after main resets it) the DFS clock used by dfs1.
int tot;
// Forward-star adjacency list: lnk[v] = first edge of v, son[e] = target of
// edge e, nxt[e] = next edge of the same source. Sized 2*maxn because every
// undirected edge is stored in both directions.
int lnk[maxn];
int son[2*maxn];
int nxt[2*maxn];
// Initial vertex weights.
int a[maxn];
// Per-vertex data for heavy-light decomposition:
//   fa   - parent, dep - depth, siz - subtree size, wson - heavy child (0 if none)
//   top  - head of the heavy chain containing the vertex
//   id   - DFS position of the vertex, rk - inverse of id (position -> vertex)
int fa[maxn],dep[maxn],siz[maxn],wson[maxn];
int top[maxn],id[maxn],rk[maxn];

// Segment-tree node over DFS positions [l, r]: sum is the range sum mod tt,
// tag is the pending addition not yet pushed to the children.
struct dyt{
    int l,r;
    int sum,tag;
}tree[4*maxn];
// Push the directed edge x -> y onto the front of x's adjacency list.
void add(int x,int y){
    ++tot;
    son[tot]=y;
    nxt[tot]=lnk[x];
    lnk[x]=tot;
}
// First DFS pass from x, whose parent is las (0 for the root).
// Fills fa, dep and siz, and picks wson[x] as the child with the strictly
// largest subtree (first such child wins ties; stays 0 for a leaf).
void dfs(int x,int las){
    fa[x]=las;
    dep[x]=dep[las]+1;
    siz[x]=1;
    for (int e=lnk[x];e!=0;e=nxt[e]) {
        int v=son[e];
        if (v==las) continue;          // skip the edge back to the parent
        dfs(v,x);
        siz[x]+=siz[v];
        if (siz[wson[x]]<siz[v]) wson[x]=v;
    }
}
// Second DFS pass: x lies on the chain whose head is las. Assigns DFS
// positions (id/rk) so that every heavy chain occupies a contiguous range:
// the heavy child is visited first and continues the chain, and each light
// child starts a new chain headed by itself.
void dfs1(int x,int las){
    top[x]=las;
    id[x]=++tot;
    rk[id[x]]=x;
    if (wson[x]>0) dfs1(wson[x],las);
    for (int e=lnk[x];e;e=nxt[e]) {
        int v=son[e];
        bool light=(v!=fa[x])&&(v!=wson[x]);
        if (!light) continue;
        dfs1(v,v);
    }
}
// Build segment-tree node x over DFS positions [l, r]. The leaf at position p
// holds a[rk[p]], the weight of the vertex placed at that position.
void buildtree(int x,int l,int r){
    tree[x].l=l,tree[x].r=r;
    if (l==r) {tree[x].sum=a[rk[l]]; return;}
    int mid=(l+r)>>1;
    buildtree(2*x,l,mid); buildtree(2*x+1,mid+1,r);
    // Add in 64 bits: two residues below tt can exceed INT_MAX when tt > 2^30.
    tree[x].sum=(int)(((long long)tree[2*x].sum+tree[2*x+1].sum)%tt);
}
// Push node x's pending lazy addition down to its two children and clear it.
// Each child's sum grows by tag * (child length) and the child's own tag
// accumulates the value. Leaves have no children, so their tag is dropped.
void pushdown(int x){
    int tag=tree[x].tag; tree[x].tag=0;
    if (tag==0||tree[x].l==tree[x].r) return;   // nothing to propagate
    for (int c=2*x;c<=2*x+1;c++) {
        long long len=tree[c].r-tree[c].l+1;
        // 64-bit arithmetic: tag (< tt) times a length up to n overflows int.
        tree[c].sum=(int)((tree[c].sum+tag*len%tt)%tt);
        tree[c].tag=(int)(((long long)tree[c].tag+tag)%tt);
    }
}
// Add z to every DFS position in [l, r]. The range must lie inside node x's
// range. Callers in this file pass z already reduced modulo tt.
// Fully covered nodes absorb the addition into sum and a lazy tag.
void insert(int x,int l,int r,int z){
    pushdown(x);
    if (tree[x].l==l&&tree[x].r==r) {
        long long len=r-l+1;
        // z * len is computed in 64 bits: with z < tt and len up to n it overflows int.
        tree[x].sum=(int)((tree[x].sum+z*len%tt)%tt);
        tree[x].tag=(int)(((long long)tree[x].tag+z)%tt);
        return;
    }
    int mid=(tree[x].l+tree[x].r)>>1;
    if (r<=mid) insert(2*x,l,r,z);
    else if (l>mid) insert(2*x+1,l,r,z);
    else {insert(2*x,l,mid,z); insert(2*x+1,mid+1,r,z);}
    tree[x].sum=(int)(((long long)tree[2*x].sum+tree[2*x+1].sum)%tt);
}
// Return the sum (mod tt) of DFS positions [l, r], which must lie inside node
// x's range. Pending lazy tags are pushed down along the way.
int query(int x,int l,int r){
    pushdown(x);
    if (tree[x].l==l&&tree[x].r==r) return tree[x].sum;
    int mid=(tree[x].l+tree[x].r)>>1;
    if (r<=mid) return query(2*x,l,r);
    if (l>mid) return query(2*x+1,l,r);
    // Combine in 64 bits: two residues below tt can exceed INT_MAX for large tt.
    return (int)(((long long)query(2*x,l,mid)+query(2*x+1,mid+1,r))%tt);
}
void put(int x,int y,int z){
    int fax=top[x],fay=top[y];
    while (fax!=fay) {
        if (dep[fax]>=dep[fay]) {insert(1,id[fax],id[x],z); x=fa[fax],fax=top[x];}
        else {insert(1,id[fay],id[y],z); y=fa[fay],fay=top[y];}
    }
    if(id[x]<=id[y]) insert(1,id[x],id[y],z); else insert(1,id[y],id[x],z);
}
int get(int x,int y){
    int ret=0,fax=top[x],fay=top[y];
    while (fax!=fay) {
        if (dep[fax]>=dep[fay]) {ret=(ret+query(1,id[fax],id[x]))%tt; x=fa[fax],fax=top[x];}
        else {ret=(ret+query(1,id[fay],id[y]))%tt; y=fa[fay],fay=top[y];}
    }
    if(id[x]<=id[y]) ret=(ret+query(1,id[x],id[y]))%tt; else ret=(ret+query(1,id[y],id[x]))%tt;
    return ret;
}
int main(){
    scanf("%d %d %d %d\n",&n,&m,&root,&tt);
    for (int i=1;i<=n;i++) scanf("%d",&a[i]),a[i]%=tt;
    for (int i=1;iint x,y; scanf("%d %d",&x,&y);
        add(x,y); add(y,x);
    }
    tot=0; dfs(root,0); dfs1(root,root); buildtree(1,1,n);
    for (int i=1;i<=m;i++) {
        int p,x,y,z;
        scanf("%d %d",&p,&x);
        if (p!=4) scanf("%d",&y); if (p==1) scanf("%d",&z);
        if (p==1) put(x,y,z%tt);
        if (p==2) printf("%d\n",get(x,y));
        if (p==3) insert(1,id[x],id[x]+siz[x]-1,y%tt);
        if (p==4) printf("%d\n",query(1,id[x],id[x]+siz[x]-1));
    }
    return 0;
}

你可能感兴趣的:(题解,洛谷题解,知识整理)