树(树链剖分)

   
   
   
   
   
   
   
   

解:

傻逼的我考试的时候去写这道题,没调出来,结果爆零了-_-
其实思想是很好懂的,一眼是一个树剖,然后陷入了无尽的推式子中。其实没写出来还有很大一部分原因是因为我数学太菜了,写到一半发现推错,又重新推,又推错……

不过写了好几道树剖,好歹有一道会的了。

首先我们可以很容易想到一种预处理的方法:记录每个点 u u 为根到子树里所有路径的和 (unqrt(u)) ( u n q r t ( u ) ) &平方和 (qrt(u)) ( q r t ( u ) )
我们用 ans(v,u) a n s ( v , u ) 表示 v v u u 子树内所有点的平方和。
然后思考一个子树外点 v v 到这棵子树内所有点的平方和:
ans(v,u)=qrt(u)+2unqrt(u)dis(u,v)+size(u)dis(u,v)2 a n s ( v , u ) = q r t ( u ) + 2 ∗ u n q r t ( u ) ∗ d i s ( u , v ) + s i z e ( u ) ∗ d i s ( u , v ) 2
如果 v v u u 的子树内我们可以以一个较短的时间求答案,那是不是就可以做这道题了?

考虑容斥:我们把 v v 一步一步往上跳求答案。(画个图)
树(树链剖分)_第1张图片
我们现在已经求得 ans(v,v) a n s ( v , v ) ,要求 ans(v,w) a n s ( v , w )
先假设 v v w w 外算一遍,显然在 v v 子树内点的贡献是错误的,我们需要把它更正一下,由于已经求了 ans(v,v) a n s ( v , v ) ,我们只需要在当前算的 ans(v,w) a n s ( v , w ) 中的 v v 子树的贡献扣掉就好了。
具体式子就长成这样:
ans(v,w)=qrt(w)+2unqrt(w)x+size(w)x2qrt(v)2unqrt(v)2xsize(v)(2x)2+ans(v,v) a n s ( v , w ) = q r t ( w ) + 2 ∗ u n q r t ( w ) ∗ x + s i z e ( w ) ∗ x 2 − q r t ( v ) − 2 ∗ u n q r t ( v ) ∗ 2 x − s i z e ( v ) ∗ ( 2 x ) 2 + a n s ( v , v )
然后一步一步往上跳,递归求解就行。
然而这样做太慢了,考虑优化:跳链很容易想到树剖,更何况这里还需要维护链上信息。(似乎还可以做单点修改,不过可能太毒瘤了。)
现在需要考虑的是如何快速跳过一条重链?
假设我们跳过了一个点 s s ,那它对答案的贡献是多少?
根据上面的式子,我们可以发现,点s的贡献在跳到s的时候加了一堆,跳到fa(s)的时候又减了一堆,我们把这两堆加起来(先画个图):
树(树链剖分)_第2张图片
=qrt(s)+2unqrt(s)dis(v,s)+size(s)dis(v,s)2qrt(s)2unqrt(s)(2x+dis(v,s))2size(s)(2x+dis(v,s))2 贡 献 = q r t ( s ) + 2 ∗ u n q r t ( s ) ∗ d i s ( v , s ) + s i z e ( s ) ∗ d i s ( v , s ) 2 − q r t ( s ) − 2 ∗ u n q r t ( s ) ∗ ( 2 x + d i s ( v , s ) ) 2 − s i z e ( s ) ∗ ( 2 x + d i s ( v , s ) ) 2

=4xunqrt(s)size(s)(4x2+4xdis(v,s)) 贡 献 = − 4 x ∗ u n q r t ( s ) − s i z e ( s ) ∗ ( 4 x 2 + 4 x ∗ d i s ( v , s ) )

把所有含有 dis(u,s) d i s ( u , s ) 的式子提出来,然后我们预处理一下 4xunqrt(s)+size(s)4x2 4 x ∗ u n q r t ( s ) + s i z e ( s ) ∗ 4 x 2 size(s)4xdis(u,s) s i z e ( s ) ∗ 4 x ∗ d i s ( u , s )
把这两个东西在树上做一个前缀和,跳链就可以做到一次跳一条重链了。

如何统计答案?
我们维护一下跳链的数值,然后最后跳到顶由于没有了减掉的那一堆,我们把它加回来,最后就是 ans(v,u) a n s ( v , u ) 了。

回答询问时,讨论一下 v v 是否在 u u 中。

哇~5KB代码:

#include
#include
#include
using namespace std;
struct lxy{
    int to,next;
    long long len;
}b[200005];

long long const mod=1000000007;

int n,x,y,z,cnt,q,ui,vi;
int head[100005];
int wson[100005];
int size[100005];
int dep[100005];
long long fro[100005];
bool vis[100005];
int fa[100005];
int tp[100005];
long long qrt[100005];
long long unqrt[100005];
long long rc[100005];
long long xi[100005];
long long cooold[100005];
long long dis[100005];


void add(int op,int ed,int len)
{
    b[++cnt].next=head[op];
    b[cnt].len=len;
    b[cnt].to=ed;
    head[op]=cnt;
}

void dfs2(int u,int las)
{
    tp[u]=las;
    vis[u]=1;
    if(wson[u]!=0) dfs2(wson[u],las);
    for(int i=head[u];i!=-1;i=b[i].next)
      if(vis[b[i].to]==0&&b[i].to!=wson[u])
        dfs2(b[i].to,b[i].to);
    vis[u]=0;
}

void dfs1(int u,int dp)
{
    dep[u]=dp;
    int weigh=0;
    size[u]=1;vis[u]=1;
    for(int i=head[u];i!=-1;i=b[i].next)
      if(vis[b[i].to]==0)
      {
        fro[b[i].to]=fro[u]+b[i].len;
        fa[b[i].to]=u;
        xi[b[i].to]=b[i].len;
        dfs1(b[i].to,dp+1);
        size[u]+=size[b[i].to];
        if(size[b[i].to]>weigh)
          weigh=size[b[i].to],wson[u]=b[i].to;
      }
    vis[u]=0;
}

void dfs3(int u)
{
    vis[u]=1;
    for(int i=head[u];i!=-1;i=b[i].next)
      if(vis[b[i].to]==0)
      {
        dfs3(b[i].to);
        unqrt[u]=(unqrt[u]+unqrt[b[i].to]+(size[b[i].to]*b[i].len)%mod)%mod;
        qrt[u]=(qrt[b[i].to]+qrt[u]+2*unqrt[b[i].to]*b[i].len%mod+size[b[i].to]*b[i].len%mod*b[i].len)%mod;
      }
    vis[u]=0;
}

long long dfs4(int u)
{
    vis[u]=1;
    long long ret=0;
    for(int i=head[u];i!=-1;i=b[i].next)
      if(vis[b[i].to]==0)
      {
        if(b[i].to==wson[u])
        {
            ret=dfs4(wson[u]);
            ret=(ret+b[i].len)%mod;
            rc[u]=(4*unqrt[u]*xi[u]%mod+size[u]*(4*ret*xi[u]%mod+4*xi[u]*xi[u]%mod)%mod+rc[wson[u]])%mod;
            cooold[u]=(4*xi[u]*size[u]+cooold[wson[u]])%mod;
            dis[u]=ret;
        }
        else dfs4(b[i].to);
      }
    if(wson[u]==0)
    {
        rc[u]=(4*xi[u]*xi[u])%mod;
        cooold[u]=(4*xi[u]*size[u])%mod;
    }
    vis[u]=0;
    return ret;
}

int main()
{
    memset(head,-1,sizeof(head));
    scanf("%d",&n);
    for(int i=1;i"%d%d%d",&x,&y,&z);
        add(x,y,z);add(y,x,z);
    }
    dfs1(1,1);dfs2(1,1);dfs3(1);dfs4(1);
    scanf("%d",&q);
    for(int i=1;i<=q;i++)
    {
      scanf("%d%d",&x,&y);
      ui=x,vi=y;
      long long ans=0,len=0,road=0,ret=0;int lca;
      while(tp[x]!=tp[y])
      {
        if(dep[tp[x]]>=dep[tp[y]])
        {
            ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
            len=(len+fro[x]-fro[fa[tp[x]]])%mod;
            x=tp[x];
            x=fa[x];
        }
        else if(dep[tp[x]]y]])
        {
          road=(road+fro[y]-fro[fa[tp[y]]])%mod;
          y=fa[tp[y]];
        }
      }
      if(dep[y]>=dep[x])
      {
        road=(road+fro[y]-fro[x])%mod;
        ans=(ans+rc[x]-rc[wson[x]]+(len-dis[x])*(cooold[x]-cooold[wson[x]]))%mod;
        lca=x,y=x;
      }
      else{
        ans=(ans+rc[y]-rc[wson[x]]+(len-dis[x])*(cooold[y]-cooold[wson[x]]))%mod;
        len=(len+fro[x]-fro[y])%mod;
        x=y;lca=y;
      }
      if(lca!=vi)
      {
        ret=(qrt[vi]+2*unqrt[vi]*(len+road)%mod+size[vi]*(len+road)%mod*(len+road)%mod)%mod;
        len=(len+fro[x]-fro[fa[x]]);
        x=fa[x];
        while(tp[x]!=0) 
        {
            ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
            len=(len+fro[x]-fro[fa[tp[x]]])%mod;
            x=fa[tp[x]];
        }
        ans=(qrt[1]+2*unqrt[1]%mod*len+size[1]*len%mod*len-ans)%mod;
        ret=2*ret-ans;
        ret=ret%mod;
        if(ret<0) ret=(ret+mod)%mod;
        printf("%lld\n",ret);
        continue;
      }
      if(lca==vi)
      {
        ret=(qrt[lca]+2*unqrt[lca]*(len+2*xi[lca])%mod+size[lca]*(len+2*xi[lca])%mod*(len+2*xi[lca])-ans)%mod;
        len=(len+fro[x]-fro[fa[x]]);
        x=fa[x];
        while(tp[x]!=0) 
        {
            ans=(ans+rc[tp[x]]-rc[wson[x]]+(len-dis[x])*(cooold[tp[x]]-cooold[wson[x]]))%mod;
            len=(len+fro[x]-fro[fa[tp[x]]])%mod;
            x=fa[tp[x]];
        }
        ans=(qrt[1]+2*unqrt[1]%mod*len+size[1]*len%mod*len-ans)%mod;
        ret=2*ret-ans;
        ret=ret%mod;
        if(ret<0) ret=(ret+mod)%mod;
        printf("%lld\n",ret);
        continue;
      }
    }
}

你可能感兴趣的:(数据结构)