bzoj 3626: [LNOI2014]LCA(树链剖分+离线+差分)

3626: [LNOI2014]LCA

Time Limit: 10 Sec   Memory Limit: 128 MB
Submit: 1512   Solved: 563
[ Submit][ Status][ Discuss]

Description

给出一个n个节点的有根树(编号为0到n-1,根节点为0)。一个点的深度定义为这个节点到根的距离+1。
设dep[i]表示点i的深度,LCA(i,j)表示i与j的最近公共祖先。
有q次询问,每次询问给出l r z,求sigma_{l<=i<=r}dep[LCA(i,z)]。
(即,求在[l,r]区间内的每个节点i与z的最近公共祖先的深度之和)

Input

第一行2个整数n q。
接下来n-1行,分别表示点1到点n-1的父节点编号。
接下来q行,每行3个整数l r z。

Output

输出q行,每行表示一个询问的答案。每个答案对201314取模输出

Sample Input

5 2
0
0
1
1
1 4 3
1 4 2

Sample Output

8
5

HINT

共5组数据,n与q的规模分别为10000,20000,30000,40000,50000。


Source

数据已加强 by saffah

[ Submit][ Status][ Discuss] 

题解:树链剖分+离线+差分

先想一想暴力的做法,我们要求sigma_{l<=i<=r}dep[LCA(i,z)],那么容易想到的就是将z到根路径上的点都打上标记,那么对于l,r中的点只需要找到第一标记,第一个标记出现的位置就是lca所在的位置,因为我们把路径上的点都标记过了,所以如果我们想要求deep[lca(i,z)]的话,其实就把i到根路径上的标记统计一下就是答案,因为一个点上面有多少个标记(包括这个点本身),其实他的deep值就是多少。那么我们在此基础上在转变一下思路,我们这次不标记z到根的路径,而改成标记[l,r]到根的路径,将路径上点的点权+1,最后统计z 到根路径上的标记数即为答案。那么我们肯定不能来回更改,怎么才能保证1-n的点只更改一次呢?  差分!把每个询问拆成两个,端点分别为l-1,r,然后将所有询问的端点从小到大排序,然后依次更改1-n的答案,这样就能保证在计算x时之前的[1,x]区间中的点到根的路径已经更新过了,那么计算当前的答案即可。对于每一个询问答案就是query(r)-query(l-1),不解释。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 500003
#define mod 201314
using namespace std;
int n,m,num;
int deep[N],point[N],size[N],next[N],v[N],tot,sz,top;
int cur[N],fa[N],pos[N],belong[N],st[N],son[N],use[N];
int tr[N*4],delta[N*4],ans1[N],ans2[N];
struct data{
	int id,x,f,t;
}q[N];
int cmp(data a,data b)
{
	return a.x<b.x||a.x==b.x&&a.f<b.f;
}
void add(int x,int y)
{
	tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y;
	tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs()
{
	top=0;
	for (int i=1;i<=n;i++)  cur[i]=point[i];
	st[++top]=1; size[1]=1; deep[1]=1;
	while (top)
	{
		int x=st[top];
	    if (v[cur[x]]==fa[x]) cur[x]=next[cur[x]];
	    if (!cur[x]){
	    	top--;
	    	if (fa[x]){
	    		size[fa[x]]+=size[x];
	    		if (size[x]>size[son[fa[x]]]) son[fa[x]]=x;
	    	}
	    	continue;
	    }
	    int t=v[cur[x]];
	    deep[t]=deep[x]+1; size[t]=1; fa[t]=x;
	    st[++top]=t; cur[x]=next[cur[x]];
	}
}
void dfs2()
{
	top=0;
	for (int i=1;i<=n;i++) cur[i]=point[i];
	st[++top]=1; belong[1]=1; pos[1]=++sz;
	while (top)
	{
		int x=st[top];
		if (!use[x])
		{
			use[x]=1;
			if (son[x]){
				pos[son[x]]=++sz; belong[son[x]]=belong[x];
				st[++top]=son[x];
			}
			continue;
		}
		while (cur[x]&&(v[cur[x]]==fa[x]||v[cur[x]]==son[x])) cur[x]=next[cur[x]];
		if (!cur[x]){
			--top; continue;
		}
		int t=v[cur[x]];
		belong[t]=t; pos[t]=++sz; st[++top]=t;
		cur[x]=next[cur[x]];
	}
}
void update(int now)
{
	tr[now]=tr[now<<1]+tr[now<<1|1];
}
void pushdown(int x,int l,int r)
{
	if (!delta[x]) return;
	int mid=(l+r)/2;
	delta[x<<1]+=delta[x]; delta[x<<1|1]+=delta[x];
	tr[x<<1]+=(mid-l+1)*delta[x];
	tr[x<<1|1]+=(r-mid)*delta[x];
	delta[x]=0;
}
void query(int now,int l,int r,int ll,int rr)
{
	if (l>=ll&&r<=rr){
		tr[now]+=(r-l+1);
		delta[now]++;
		return;
	}
	pushdown(now,l,r);
	int mid=(l+r)/2;
	if (ll<=mid) query(now<<1,l,mid,ll,rr);
	if (rr>mid) query(now<<1|1,mid+1,r,ll,rr);
	update(now);
}

void solve1(int x,int y)
{
	while (belong[x]!=belong[y]){
		if (deep[belong[x]]<deep[belong[y]]) swap(x,y);
		query(1,1,n,pos[belong[x]],pos[x]);
		x=fa[belong[x]];
	}
	if (deep[x]>deep[y]) swap(x,y);
	query(1,1,n,pos[x],pos[y]);
}
int qjsum(int now,int l,int r,int ll,int rr)
{
	if (l>=ll&&r<=rr) return tr[now];
	pushdown(now,l,r);
	int mid=(l+r)/2;  int ans=0;
	if (ll<=mid) ans+=qjsum(now<<1,l,mid,ll,rr);
	if (rr>mid) ans+=qjsum(now<<1|1,mid+1,r,ll,rr);
	return ans;
}
int solve(int x,int y)
{
	int ans=0;
	while (belong[x]!=belong[y]){
		if (deep[belong[x]]<deep[belong[y]])  swap(x,y);
		ans+=qjsum(1,1,n,pos[belong[x]],pos[x])%mod;
		x=fa[belong[x]];
	}
	if (deep[x]>deep[y]) swap(x,y);
	ans+=qjsum(1,1,n,pos[x],pos[y])%mod;
	return ans;
}
int main()
{
	scanf("%d%d",&n,&m);
	for (int i=2;i<=n;i++)
	{
		int x; scanf("%d",&x); x++;
		add(i,x);
	}
    dfs();  dfs2();
    for (int i=1;i<=m;i++)
     {
     	int x,y,z; scanf("%d%d%d",&x,&y,&z); x++; y++; z++;
     	++num; q[num].x=x-1; q[num].id=i; q[num].f=0;  q[num].t=z;
		++num; q[num].x=y; q[num].id=i; q[num].f=1;    q[num].t=z;
     }
    sort(q+1,q+num+1,cmp);
    int j=1;
    for (int i=1;i<=n;i++)
     {
     	solve1(1,i);                   
     	while (q[j].x==0)  j++;
     	while (q[j].x==i){
     	int x=q[j].id;
     	if (q[j].f)  ans1[x]=solve(1,q[j].t)%mod;
     	else ans2[x]=solve(1,q[j].t)%mod;
     	j++;
        }
     }
    for (int i=1;i<=m;i++)
     printf("%d\n",((ans1[i]-ans2[i])%mod+mod)%mod);
}


你可能感兴趣的:(bzoj 3626: [LNOI2014]LCA(树链剖分+离线+差分))