难得有一道从头到尾自己做的题了T_T(←这么弱还好意思说)。
首先考虑一个特定的询问(x,z)应该怎么做:将所有1...x的点到根的路径覆盖一遍,然后就是求z到根的每一条边的覆盖次数的和。(画个图就很明确了)。于是可以将询问(l,r,z)拆成(r,z)和(l-1,z),然后按照第一维排序(注意特判0的情况)。
然后将x从1循环到n,每次将x到根的路径覆盖一遍(在原来的基础上),然后求所有第一维=x的(x,z)的值,更新它对应的答案。
快速覆盖路径,以及快速求覆盖次数的和是经典的在树链剖分后,在dfs序中区间修改和区间查询。然后我写了树状数组维护(真的又短又快辣!!)。
时间复杂度O(Nlog^2N)。
AC代码如下:
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define mod 201314 #define ll long long #define N 100005 using namespace std; int n,m,tot,cnt,dfsclk,fst[N],pnt[N],nxt[N],pos[N],anc[N],son[N],fa[N],sz[N],ans[N]; struct node{ int x,y,z,id; }a[N]; struct bit_node{ int c[N]; void ins(int x,int t){ for (; x<=n; x+=x&-x) c[x]=(c[x]+t)%mod; } int getsum(int x){ int t=0; for (; x; x-=x&-x) t=(t+c[x])%mod; return t; } }bit1,bit2; int read(){ int x=0; char ch=getchar(); while (ch<'0' || ch>'9') ch=getchar(); while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); } return x; } void add(int x,int y){ pnt[++tot]=y; nxt[tot]=fst[x]; fst[x]=tot; } bool cmp(node x,node y){ return x.x<y.x; } void dfs(int x){ sz[x]=1; int p; for (p=fst[x]; p; p=nxt[p]){ int y=pnt[p]; dfs(y); sz[x]+=sz[y]; if (sz[y]>sz[son[x]]) son[x]=y; } } void build(int x,int tp){ pos[x]=++dfsclk; anc[x]=tp; int p; if (son[x]) build(son[x],tp); for (p=fst[x]; p; p=nxt[p]) if (son[x]!=pnt[p]) build(pnt[p],pnt[p]); } void ins(int x,int y){ bit1.ins(x,1); bit1.ins(y+1,-1); bit2.ins(x,x-1); bit2.ins(y+1,-y); } int qry(int x){ return ((ll)bit1.getsum(x)*x%mod-bit2.getsum(x))%mod; } int main(){ n=read(); m=read(); int i; for (i=2; i<=n; i++) add(fa[i]=read()+1,i); for (i=1; i<=m; i++){ a[++cnt].x=read(); a[cnt+1].x=read()+1; a[cnt].y=a[cnt+1].y=read()+1; a[cnt].id=i; a[cnt+1].id=i; a[cnt].z=-1; a[++cnt].z=1; } sort(a+1,a+cnt+1,cmp); dfs(1); build(1,1); int j=1,x; while (!a[j].x && j<=cnt) j++; for (i=1; i<=n; i++){ for (x=i; x; x=fa[anc[x]]) ins(pos[anc[x]],pos[x]); for (; a[j].x==i && j<=cnt; j++) for (x=a[j].y; x; x=fa[anc[x]]) ans[a[j].id]=(ans[a[j].id]+(qry(pos[x])-qry(pos[anc[x]]-1))%mod*a[j].z%mod+mod)%mod; } for (i=1; i<=m; i++) printf("%d\n",ans[i]); return 0; }
by lych
2016.2.29