树链剖分 完全模板(子树查改+树链查改)

题目链接
以洛谷 P3384（树链剖分模板题）为原题，输入格式与四种操作见代码中的 main 函数。
下面是模板代码（虚线注释块之外的部分是通用模板；虚线注释块之内的是与本题相关、需要按题目取消注释或改写的部分）：

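/*--------------------

1. mchange(x, y, v)  path update:    add v to every vertex on the path x-y
2. mask(x, y)        path query:     sum of the values on the path x-y
3. achange(x, v)     subtree update: add v to every vertex in x's subtree
4. aask(x)           subtree query:  sum of the values in x's subtree
(all sums are taken modulo mod)

--------------------*/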
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
typedef long long ll;
using namespace std;
const int N = 100000 + 5;   // array capacity: maximum vertex count plus slack

// Adjacency list in forward-star form: head edge per vertex, next/target per edge.
int he[N];
int ne[N << 1];
int ver[N << 1];

// Per-vertex decomposition data: depth, heavy child, parent, (unused) id,
// chain top, DFS order index, subtree size.
int dep[N], son[N], fa[N], id[N], tp[N], dfn[N], siz[N];

int n, m;
int cnt;        // DFS order counter advanced by dfs2
int tot = 1;    // index of the last edge added; the first edge gets index 2

ll su[N];       // initial weight of each vertex, indexed by vertex
ll w[N];        // the same weights rearranged into DFS order for build()
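/*----------------
Problem-specific globals: tree root and modulus, read in main().
Note: ask() uses `mod` outside this comment, so `mod` must be declared.

int root;
ll mod;

----------------*/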
// Segment-tree node covering positions [l, r] of the DFS order.
struct Node
{
    int l;      // left end of the covered range
    int r;      // right end of the covered range
    ll sum;     // sum of the values in [l, r]
    ll add;     // pending lazy add not yet pushed to the children
};
Node tr[N << 2];    // 4 * N nodes suffice for a segment tree over N positions
// Append the directed edge x -> y to x's forward-star edge list.
void add(int x, int y)
{
    ++tot;
    ver[tot] = y;
    ne[tot] = he[x];
    he[x] = tot;
}
int dfs1(int u, int f)
{
    fa[u] = f;
    dep[u] = dep[f]+1;
    siz[u] = 1;
    int mx = -1;
    for (int i = he[u]; i; i = ne[i])
    {
        int y = ver[i];
        if (y == f) continue;
        siz[u] += dfs1(y, u);
        if (siz[y] > mx){mx = siz[y]; son[u] = y;}
    }
    return siz[u];
}
void dfs2(int u, int t)
{
    dfn[u] = ++cnt;
    tp[u] = t;
    w[cnt] = su[u];
    if (!son[u])
        return;
    dfs2(son[u], t);
    for (int i = he[u]; i; i = ne[i])
    {
        int v = ver[i];
        if (v == fa[u] || v == son[u]) continue;
        dfs2(v,v);
    }
}
// Apply "add v to every position" to segment-tree node p:
// its sum grows by v * (length of its range) and v is added to its lazy tag.
// NOTE(review): the body is commented out (problem-specific part), so as
// written this function does nothing. Uncomment it, together with the `mod`
// global, before using the tree.
inline void update(int p, ll v)
{
    /*Apply add v to node p (problem-specific: kept modulo mod)---------

    tr[p].sum = (tr[p].sum + v * (tr[p].r - tr[p].l + 1)%mod)%mod;
    tr[p].add = (tr[p].add + v) % mod;

    ------------------------------------------*/
}
// Recompute node p's sum from its children p<<1 and p<<1|1.
// NOTE(review): the body is commented out (problem-specific part), so as
// written this function does nothing.
inline void pushup(int p)
{
    /*pushup: sum of both children, modulo mod--------------

    tr[p].sum = (tr[p<<1].sum + tr[p<<1|1].sum) % mod;

    ------------------------------------------*/
}
// Push node p's pending lazy add down to both children via update(), then clear it.
// NOTE(review): the body is commented out (problem-specific part), so as
// written this function does nothing.
void spread(int p)
{
    /*pushdown (here l / r are the child node indices, not range ends)-----

    if (!tr[p].add) return;
    int l = p<<1, r = p<<1|1;
    update(l, tr[p].add);
    update(r, tr[p].add);
    tr[p].add = 0;

    ------------------------------------------*/
}
int mx = 0;  // largest node index touched by build(); not read elsewhere in this listing (debug aid?)

// Build the segment-tree node p covering positions [l, r] from w[l..r].
// Leaves take their value directly; inner nodes are combined by pushup().
void build(int p, int l, int r)
{
    if (p > mx)
        mx = p;
    tr[p].l = l;
    tr[p].r = r;
    tr[p].add = 0;
    if (l == r)
    {
        tr[p].sum = w[l];
        return;
    }
    const int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    pushup(p);
}
// Range add: add v to every position in [l, r] that lies inside node p's range.
void change(int p, int l, int r, ll v)
{
    const bool covered = l <= tr[p].l && tr[p].r <= r;
    if (covered)
    {
        update(p, v);   // whole node inside the query: tag it and stop
        return;
    }
    spread(p);          // push the pending tag down before touching the children
    const int mid = (tr[p].l + tr[p].r) >> 1;
    const int lc = p << 1;
    const int rc = p << 1 | 1;
    if (l <= mid)
        change(lc, l, r, v);
    if (mid < r)
        change(rc, l, r, v);
    pushup(p);
}
// Range query: sum (modulo mod) of the positions in [l, r] within node p's range.
ll ask(int p, int l, int r)
{
    if (tr[p].l >= l && tr[p].r <= r)
        return tr[p].sum % mod;     // node fully inside the query
    spread(p);                      // make the children's sums current first
    const int mid = (tr[p].l + tr[p].r) >> 1;
    ll total = 0;
    if (l <= mid)
        total = (total + ask(p << 1, l, r)) % mod;
    if (mid < r)
        total = (total + ask(p << 1 | 1, l, r)) % mod;
    return total % mod;
}
// Path update: add v to every vertex on the path between x and y.
// Repeatedly lifts the endpoint whose chain top is deeper, updating that
// chain segment, until both endpoints share a heavy chain.
void mchange(int x, int y, ll v)
{
    while (tp[x] != tp[y])
    {
        // bind to the endpoint whose chain top is deeper (x on ties)
        int &deeper = (dep[tp[x]] < dep[tp[y]]) ? y : x;
        change(1, dfn[tp[deeper]], dfn[deeper], v);
        deeper = fa[tp[deeper]];
    }
    // Same chain now: the shallower endpoint has the smaller dfn.
    int upper = x;
    int lower = y;
    if (dep[upper] > dep[lower])
    {
        upper = y;
        lower = x;
    }
    change(1, dfn[upper], dfn[lower], v);
}
// Path query: sum (modulo mod) of the values on the path between x and y.
// NOTE(review): both accumulation statements are commented out
// (problem-specific part), so as written this always returns 0.
ll mask(int x, int y)
{
    ll sum = 0;
    // Climb chain by chain until x and y lie on the same heavy chain.
    while(tp[x] != tp[y])
    {
        // Always lift the endpoint whose chain top is deeper.
        if (dep[tp[x]] < dep[tp[y]]) swap(x, y);
        // Add the segment [top of x's chain, x].
        /*-----------------------

        sum = (sum + ask(1, dfn[tp[x]], dfn[x]))%mod;

        -----------------------*/
        x = fa[tp[x]];
    }
    // Same chain now: make x the shallower endpoint.
    if (dep[x] > dep[y]) swap(x, y);
    // Add the remaining segment [x, y] on the common chain.
    /*--------------------------

    sum = (sum + ask(1, dfn[x], dfn[y]))%mod;

    --------------------------*/
    return sum;
}
// Subtree query: x's subtree occupies the contiguous dfn range
// [dfn[x], dfn[x] + siz[x] - 1], so this is a single range query.
ll aask(int x)
{
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    return ask(1, first, last);
}
// Subtree update: add v to every vertex of x's subtree, which occupies
// the contiguous dfn range [dfn[x], dfn[x] + siz[x] - 1].
void achange(int x, ll v)
{
    const int first = dfn[x];
    const int last = first + siz[x] - 1;
    change(1, first, last, v);
}
/*--Debug--------------------
void print()
{
    for (int i = 1; i <= n; i++)
        cout << ask(1, dfn[i], dfn[i]) << " ";
    cout << endl;
}
--------------------------*/

/*-----------------------------
Problem-specific driver. Input: n m root mod, then the n vertex weights
(reduced modulo mod on input), then n-1 undirected edges, then m operations:
  1 x y k  -> mchange(x, y, k)   path add
  2 x y    -> mask(x, y)         path sum, printed modulo mod
  3 x k    -> achange(x, k)      subtree add
  4 x      -> aask(x)            subtree sum, printed modulo mod

int main()
{
    scanf("%d%d%d%lld", &n, &m, &root, &mod);
    for (int i = 1; i <= n; i++)
        scanf("%lld", &su[i]), su[i]%=mod;
    for (int i = 1; i < n; i++)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);
        add(y, x);
    }
    ------------------------*/
    // Template part: heavy-light preprocessing (dfs1, dfs2), then build the
    // segment tree over the DFS-ordered weights w[1..n].
    // NOTE(review): main's opening line is inside the comment above, so these
    // statements are valid C++ only after the dashed comment markers around
    // main are removed.
    dfs1(root, 0);
    dfs2(root, root);
    build(1, 1, n);
    /*-----------------------
    while(m--)
    {
        int op, x, y;
        ll k;
        scanf("%d", &op);
        if (op == 1)
        {
            scanf("%d%d%lld", &x, &y, &k);
            mchange(x, y, k);
        }
        else if (op == 2)
        {
            scanf("%d%d", &x, &y);
            printf("%lld\n", mask(x, y)%mod);
        }
        else if (op == 3)
        {
            scanf("%d%lld", &x, &k);
            achange(x, k);
        }
        else
        {
            scanf("%d", &x);
            printf("%lld\n", aask(x)%mod);
        }
    }
    return 0;
}
-------------------------*/


你可能感兴趣的:(数据结构)