看题第一眼反应点分治。。。QAQ但是从来没写过动态点分治不会写。
然后扒到了一个树链剖分的题解,发现还是可做的。
考虑一个朴素的问题,如果没有颜色限制,问所有点到u的距离之和是多少?
两点间距离dist(u,v)=deep(u)+deep(v)-2*deep(lca(u,v))。
因此所有点到u的距离之和=deep(u)*n+Σ(i=1,n)deep(i)-2Σ(i=1,n)deep(lca(u,v))。
因此关键就是求Σ(i=1,n)deep(lca(u,v))。
换句话说,我们需要知道u的每一个祖先t是多少对(u,v)的lca,实际上又因为t是u的祖先,我们只需要知道t是多少个(v)的祖先,且v->t的路径和u->t的路径不重合。
因此我们先暴力枚举每个点i,将i->rt的路径覆盖,送油i->rt的边经过的次数+1。然后Σ(i=1,n)deep(lca(u,v))就是u->rt的路径中,Σlen[e]*times[e],len[e]表示边权,times[e]表示经过的次数。
然后就可以用树链剖分维护了。O((N+M)log^2N)。
考虑颜色限制[l,r],将问题转化为前缀和统计[1,r]-[1,l-1]。
然后我们根据颜色从小到大进行路径覆盖,由于每次只会修改O(log^2N)个点,用主席树维护[1,i]的那颗线段树。然后查询的时候减一减就好了。时空复杂度O(Nlog^2N),似乎会MLE?没关系因为树链剖分好像是做不到log^2N的,只要把数组稍微比Nlog^2N开小一点就好了。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long #define N 300005 #define M 20000005 using namespace std; int n,m,mod,trtot,tot,dfsclk,fst[N],pnt[N],len[N],nxt[N]; int pos[N],son[N],fa[N],anc[N],ls[M],rs[M],cvr[M],sz[N],edg[N],rt[N]; ll sum[M],sumd[N],d[N],val[N]; struct node{ int x,y; }a[N]; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int x,int y,int z){ pnt[++tot]=y; len[tot]=z; nxt[tot]=fst[x]; fst[x]=tot; } bool cmp(node aa,node bb){ return aa.x<bb.x || aa.x==bb.x && aa.y<bb.y; } void dfs(int x){ sz[x]=1; int p; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (fa[x]==y) continue; fa[y]=x; edg[y]=len[p]; d[y]=d[x]+len[p]; dfs(y); sz[x]+=sz[y]; if (sz[y]>sz[son[x]]) son[x]=y; } } void build(int x,int tp){ pos[x]=++dfsclk; val[dfsclk]=edg[x]; anc[x]=tp; if (son[x]) build(son[x],tp); int p; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; if (fa[x]!=y && son[x]!=y) build(y,y); } } void ins(int l,int r,int x,int &y,int u,int v){ y=++trtot; int mid=(l+r)>>1; sum[y]=sum[x]; cvr[y]=cvr[x]; if (l==u && r==v){ ls[y]=ls[x]; rs[y]=rs[x]; cvr[y]++; return; } sum[y]+=val[v]-val[u-1]; if (v<=mid){ rs[y]=rs[x]; ins(l,mid,ls[x],ls[y],u,v); } else if (u>mid){ ls[y]=ls[x]; ins(mid+1,r,rs[x],rs[y],u,v); } else{ ins(l,mid,ls[x],ls[y],u,mid); ins(mid+1,r,rs[x],rs[y],mid+1,v); } } ll qry(int l,int r,int k,int u,int v){ ll t=(ll)(val[v]-val[u-1])*cvr[k]; if (l==u && r==v) return t+sum[k]; int mid=(l+r)>>1; if (v<=mid) return t+qry(l,mid,ls[k],u,v); else if (u>mid) return t+qry(mid+1,r,rs[k],u,v); else return t+qry(l,mid,ls[k],u,mid)+qry(mid+1,r,rs[k],mid+1,v); } int find(int x){ int l=0,r=n+1,mid; while (l+1<r){ mid=(l+r)>>1; if (a[mid].x<=x) l=mid; else r=mid; } return l; } int main(){ n=read(); m=read(); mod=read(); int i; for (i=1; i<=n; i++){ a[i].x=read(); a[i].y=i; } sort(a+1,a+n+1,cmp); for (i=1; i<n; i++){ int x=read(),y=read(),z=read(); add(x,y,z); add(y,x,z); } dfs(1); build(1,1); for (i=1; i<=n; i++){ val[i]+=val[i-1]; sumd[i]=sumd[i-1]+d[a[i].y]; } for (i=1; i<=n; i++){ int x=a[i].y; rt[i]=rt[i-1]; for (; x; x=fa[anc[x]]) ins(1,n,rt[i],rt[i],pos[anc[x]],pos[x]); } ll ans=0; int k,x,y; while (m--){ k=read(); x=(ans+read())%mod; y=(ans+read())%mod; if (x>y) swap(x,y); x=find(x-1); y=find(y); ans=d[k]*(y-x)+sumd[y]-sumd[x]; for (; k; k=fa[anc[k]]) ans-=(qry(1,n,rt[y],pos[anc[k]],pos[k])-qry(1,n,rt[x],pos[anc[k]],pos[k]))<<1; printf("%lld\n",ans); } return 0; }
by lych
2016.2.29