链接:http://www.lydsy.com/JudgeOnline/problem.php?id=3653
题意:中文题。
分析:很明显这题可以分为两种情况:(1)b是a的祖先,那么有min(de[a]-1,y)种选择,c就随便选一个a子树中的点就行了。(2)b是a的子孙,c是b子树中的点,且b距离a小于等于k。第一种情况没压力,第二种情况才是关键。子树问题,我们优先想到dfs序,然后问题变成了在子树a的区间中深度在deep[a]+1~deep[a]+k中的所有点的size之和。用可持久化线段树即可。
代码:
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<math.h>
#include<cstdio>
#include<vector>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
const int N=300100;
const int MAX=151;
const int mod=100000000;
const int MOD1=1000000007;
const int MOD2=1000000009;
const double EPS=0.00000001;
typedef long long ll;
const ll MOD=998244353;
const ll INF=10000000010;
typedef double db;
typedef unsigned long long ull;
int tot,u[N],v[2*N],pre[2*N];
void add(int x,int y) {
v[tot]=y;pre[tot]=u[x];u[x]=tot++;
}
int k,a[N],in[N],out[N],de[N],siz[N];
void dfs(int x,int y) {
k++;in[x]=k;de[x]=de[y]+1;
a[k]=x;siz[x]=0;
for (int i=u[x];i!=-1;i=pre[i])
if (v[i]!=y) { dfs(v[i],x);siz[x]+=siz[v[i]]+1; }
out[x]=k;
}
ll sum[20*N];
int ls[20*N],rs[20*N],root[20*N];
void updata(int l,int r,int x,int &y,int z,int w) {
y=++k;sum[y]=sum[x]+w;
if (l==r) return ;
ls[y]=ls[x];rs[y]=rs[x];
int mid=(l+r)>>1;
if (z<=mid) updata(l,mid,ls[x],ls[y],z,w);
else updata(mid+1,r,rs[x],rs[y],z,w);
}
ll getsum(int l,int r,int x,int y,int z,int w) {
if (l==z&&r==w) return sum[y]-sum[x];
int mid=(l+r)>>1;
if (w<=mid) return getsum(l,mid,ls[x],ls[y],z,w);
else if (z>mid) return getsum(mid+1,r,rs[x],rs[y],z,w);
else return getsum(l,mid,ls[x],ls[y],z,mid)+getsum(mid+1,r,rs[x],rs[y],mid+1,w);
}
int main()
{
int i,n,m,x,y;
ll ans;
scanf("%d%d", &n, &m);
tot=0;memset(u,-1,sizeof(u));
for (i=1;i<n;i++) {
scanf("%d%d", &x, &y);
add(x,y);add(y,x);
}
de[1]=k=0;dfs(1,1);k=0;
for (i=1;i<=n;i++) updata(1,n,root[i-1],root[i],de[a[i]],siz[a[i]]);
while (m--) {
scanf("%d%d", &x, &y);
ans=(ll)siz[x]*min(de[x]-1,y);
if (de[x]!=n) ans+=getsum(1,n,root[in[x]-1],root[out[x]],de[x]+1,min(n,de[x]+y));
printf("%lld\n", ans);
}
return 0;
}