线段树 ---- D. Power Tree(离线dfs序+线段树维护树上多条路径和的技巧)

题目链接


题目大意:

一开始给你只有一个点 1 1 1的树,有 q q q次询问。每次询问有两种操作

  • 1    p    v 1\;p\;v 1pv 就是把最小的没加入的点,加入这个树,它的父亲是 p p p,权值是 v v v
  • 2    u 2\;u 2u 就是询问你u的 S t r e n g t h ( S u ) Strength(S_{u}) Strength(Su)是多少?

S u S_u Su的直接定义是一个集合这个集合包括这个点里面所有的直接儿子 S t r e n g t h ( S s o n ) Strength(S_{son}) Strength(Sson),和当前点的权值组成的集合

S t r e n g t h ( S ) = ∣ S ∣ ⋅ ∑ d ∈ S d Strength(S)=|S|\cdot\sum_{d\in S}d Strength(S)=SdSd


解题思路:

  1. 首先我们知道对于每个点的值 u u u是这么算的就是它所有子树里面的点,一直往上跳直到点 u u u,每次跳到一个点我们就乘以(跳到点的直接儿子个数+1),那么就是这个点对 u u u节点的答案贡献。那么所有的子树的点贡献加起来就是 u u u的答案了

  2. 但是这么搞很多条路径直接到不同的点很难维护,我们可以假设全部询问都是根节点

  3. 我们可以算出每一个节点对根节点的贡献因子mi。
    d ( i ) d(i) d(i)是节点 i i i儿子节点数+1.
    m ( i ) m(i) m(i)就是从 i i i到根节点的简单路径上的 d ( i ) d(i) d(i)的乘积。

  4. 当新加入一个节点 u u u到节点 p p p时:我们观察可以发现,仅仅改变了 p p p p p p的子树的 m ( i ) m(i) m(i).

  5. 显然是 m ( i ) = m ( i ) ∗ ( d ( p ) + 1 ) / d ( p ) . m ( u ) = m ( p ) m(i)=m(i)*(d(p)+1)/d(p). m(u)=m(p) m(i)=m(i)(d(p)+1)/d(p).m(u)=m(p)

  6. 我们需要一个数据结构可以 对一个区间进行乘,更新一个值,以及求一个区间的总和。

  7. 我们用线段树维护dfn序。
    那么一个节点的子树就是一段连续的区间。
    我们求一个节点的贡献时。显然就是 s u m sum sum同时除以他们的公共路径的 m ( i ) m(i) m(i)
    由于乘法是一个浮点数。我们先累乘所有的分子,再累乘所有的分母,最后我们用乘法逆元
    可以求出取模。

  8. 时间复杂度 O ( q × ( l o g n + l o g ( m o d ) ) ) O(q\times(logn+log(mod))) O(q×(logn+log(mod)))


AC code

#include 
#define mid ((l + r) >> 1)
#define Lson rt << 1, l , mid
#define Rson rt << 1|1, mid + 1, r
#define ms(a,al) memset(a,al,sizeof(a))
#define log2(a) log(a)/log(2)
#define lowbit(x) ((-x) & x)
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define INF 0x3f3f3f3f
#define LLF 0x3f3f3f3f3f3f3f3f
#define f first
#define s second
#define endl '\n'
using namespace std;
const int N = 2e6 + 10, mod = 1e9 + 7;
const int maxn = 500010;
const long double eps = 1e-5;
const int EPS = 500 * 500;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef pair<double,double> PDD;
template<typename T> void read(T &x) {
   x = 0;char ch = getchar();ll f = 1;
   while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
   while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args) {
   read(first);
   read(args...);   
}

int n, m, now = 1;
int d[maxn];
vector<int> G[maxn];
struct Q {
    int op, u, v;
};
vector<Q> e;
int L[maxn], R[maxn], tot, fa[maxn];
void dfs(int u) {// 求每个子树的区间[L[u],R[u]]
    L[u] = ++ tot;
    for(auto it : G[u]) dfs(it);
    R[u] = tot;
}

inline ll qim(ll a,ll b) {//快速幂
    ll res = 1;
    while(b) {
        if(b&1) res = res * a % mod;
        b >>= 1;
        a = a * a % mod;
    }
    return res;
}

struct Segtree {
    ll tr[maxn << 2];// 维护区间加
    ll tag[maxn << 2];// 维护乘法
    void pushdown(int rt) {
        int mul = tag[rt];
        tag[rt<<1] = (tag[rt<<1]*mul) % mod;
        tag[rt<<1|1] = (tag[rt<<1|1]*mul) % mod;
        tr[rt<<1] = (tr[rt<<1]*mul) % mod;
        tr[rt<<1|1] = (tr[rt<<1|1]*mul) % mod;
        tag[rt] = 1;
    }
    void pushup(int rt) {
        tr[rt] = (tr[rt<<1]+tr[rt<<1|1]) % mod;
    }
    void build(int rt, int l, int r) {
        tag[rt] = 1;
        tr[rt] = 0;
        if(l == r) return;
        build(Lson);
        build(Rson);
    }
    void segmul(int rt, int l, int r, int L, int R, int val) {
        if(L <= l && R >= r) {
            tag[rt] = (tag[rt] * val) % mod;
            tr[rt] = (tr[rt] * val) % mod;
            return;
        }
        pushdown(rt);
        if(L <= mid) segmul(Lson,L,R,val);
        if(R > mid) segmul(Rson,L,R,val);
        pushup(rt);
    }
    void update(int rt, int l, int r, int pos, int val) {
        if(l == r) {
            tr[rt] = (val * tag[rt]) % mod;// 这里新加入一个点相当于把它激活了以前是0,不考虑贡献,要让它乘上路径上的d(i), 也就是tag[rt]
            return;
        }
        pushdown(rt);
        if(pos <= mid) update(Lson,pos,val);
        else update(Rson,pos,val);
        pushup(rt);
    }
    ll ask(int rt, int l, int r, int L, int R) {
        if(L <= l && R >= r) return tr[rt];// 询问区间加
        pushdown(rt);
        ll res = 0;
        if(L <= mid) res = (res + ask(Lson,L,R)) % mod;
        if(R > mid) res = (res + ask(Rson,L,R)) % mod;
        return res;
    }
    ll quary(int rt, int l, int r, int pos) {
        if(l == r) return tag[rt];// 询问路径上面的d(i)的乘积
        pushdown(rt);
        if(pos <= mid) return quary(Lson,pos);
        else return quary(Rson,pos);
    }
}sgt;

int main() {
    IOS;
    read(n,m);
    for(int i = 1; i <= m; ++ i) {
        int op, p, v;
        read(op,p);
        if(op==1) {
            read(v);
            now ++;
            G[p].push_back(now);
            e.push_back({op,p,v});
        } else {
            e.push_back({op,p,-1});
        }
    }
    dfs(1);
    sgt.build(1,1,tot);
    now = 1;
    d[1] = 1;
    sgt.update(1,1,tot,1,n);
    for(int i = 0; i < m; ++ i) {
        int op = e[i].op;
        if(op == 1) {
            int p = e[i].u, v = e[i].v;
            d[++now] = 1;
            fa[now] = p; 
            int sub = qim(d[p],mod-2);
            d[p] ++;
            sgt.segmul(1,1,tot,L[p],R[p],d[p]);// mu[p]*(d[p]+1)/d[p]
            sgt.segmul(1,1,tot,L[p],R[p],sub);
            sgt.update(1,1,tot,L[now],v);// 把单前点在线段树里面激活
        }  else {
            int v = e[i].u;
            ll tmp=1;
            if(v!=1) tmp=qim(sgt.quary(1,1,tot,L[fa[v]]),mod-2);// 除以路径上面的乘积
            cout << sgt.ask(1,1,tot,L[v],R[v]) * tmp % mod << "\n";
        }
    }
}

你可能感兴趣的:(#,各种线段树,数据结构,算法)