bzoj3653: 谈笑风生

链接:http://www.lydsy.com/JudgeOnline/problem.php?id=3653

题意:中文题。

分析:很明显这题可以分为两种情况:(1)b是a的祖先,那么有min(de[a]-1,y)种选择,c就随便选一个a子树中的点就行了。(2)b是a的子孙,c是b子树中的点,且b距离a小于等于k。第一种情况没压力,第二种情况才是关键。子树问题,我们优先想到dfs序,然后问题变成了在子树a的区间中深度在deep[a]+1~deep[a]+k中的所有点的size之和。用可持久化线段树即可。

代码:

#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<cstdio>
#include<vector>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int N=300100;
const int MAX=151;
const int mod=100000000;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
typedef long long ll;
const ll MOD=998244353;
const ll INF=10000000010;
typedef double db;
typedef unsigned long long ull;
int tot,u[N],v[2*N],pre[2*N];
void add(int x,int y) {
    v[tot]=y;pre[tot]=u[x];u[x]=tot++;
}
int k,a[N],in[N],out[N],de[N],siz[N];
void dfs(int x,int y) {
    k++;in[x]=k;de[x]=de[y]+1;
    a[k]=x;siz[x]=0;
    for (int i=u[x];i!=-1;i=pre[i])
    if (v[i]!=y) { dfs(v[i],x);siz[x]+=siz[v[i]]+1; }
    out[x]=k;
}
ll sum[20*N];
int ls[20*N],rs[20*N],root[20*N];
void updata(int l,int r,int x,int &y,int z,int w) {
    y=++k;sum[y]=sum[x]+w;
    if (l==r) return ;
    ls[y]=ls[x];rs[y]=rs[x];
    int mid=(l+r)>>1;
    if (z<=mid) updata(l,mid,ls[x],ls[y],z,w);
    else updata(mid+1,r,rs[x],rs[y],z,w);
}
ll getsum(int l,int r,int x,int y,int z,int w) {
    if (l==z&&r==w) return sum[y]-sum[x];
    int mid=(l+r)>>1;
    if (w<=mid) return getsum(l,mid,ls[x],ls[y],z,w);
    else if (z>mid) return getsum(mid+1,r,rs[x],rs[y],z,w);
        else return getsum(l,mid,ls[x],ls[y],z,mid)+getsum(mid+1,r,rs[x],rs[y],mid+1,w);
}
int main()
{
    int i,n,m,x,y;
    ll ans;
    scanf("%d%d", &n, &m);
    tot=0;memset(u,-1,sizeof(u));
    for (i=1;i<n;i++) {
        scanf("%d%d", &x, &y);
        add(x,y);add(y,x);
    }
    de[1]=k=0;dfs(1,1);k=0;
    for (i=1;i<=n;i++) updata(1,n,root[i-1],root[i],de[a[i]],siz[a[i]]);
    while (m--) {
        scanf("%d%d", &x, &y);
        ans=(ll)siz[x]*min(de[x]-1,y);
        if (de[x]!=n) ans+=getsum(1,n,root[in[x]-1],root[out[x]],de[x]+1,min(n,de[x]+y));
        printf("%lld\n", ans);
    }
    return 0;
}


你可能感兴趣的:(bzoj3653: 谈笑风生)