3653: 谈笑风生
Time Limit: 20 Sec
Memory Limit: 512 MB
Submit: 720
Solved: 277
[ Submit][ Status][ Discuss]
Description
设T 为一棵有根树,我们做如下的定义:
• 设a和b为T 中的两个不同节点。如果a是b的祖先,那么称“a比b不知道
高明到哪里去了”。
• 设a 和 b 为 T 中的两个不同节点。如果 a 与 b 在树上的距离不超过某个给定
常数x,那么称“a 与b 谈笑风生”。
给定一棵n个节点的有根树T,节点的编号为1 到 n,根节点为1号节点。你需
要回答q 个询问,询问给定两个整数p和k,问有多少个有序三元组(a;b;c)满足:
1. a、b和 c为 T 中三个不同的点,且 a为p 号节点;
2. a和b 都比 c不知道高明到哪里去了;
3. a和b 谈笑风生。这里谈笑风生中的常数为给定的 k。
Input
输入文件的第一行含有两个正整数n和q,分别代表有根树的点数与询问的个数。接下来n - 1行,每行描述一条树上的边。每行含有两个整数u和v,代表在节点u和v之间有一条边。
接下来q行,每行描述一个操作。第i行含有两个整数,分别表示第i个询问的p和k。
Output
Sample Input
5 3
1 2
1 3
2 4
4 5
2 2
4 1
2 3
Sample Output
3
1
3
HINT
1<=P<=N
1<=K<=N
N<=300000
Q<=300000
Source
[ Submit][ Status][ Discuss]
题解:主席树+dfs序
(a,b,c)中a是给定的,因为a,b都比c不知道高明到哪里去了,所以c应该在a,b的子树中。
那么对于a,b的相对位置可以分为两种。
(1)b在a到根的路径上,那么从a向上走,距离在k之内的都可以是b点,如果b在a上面,那么c可以是a子树中除a的任意一点,min(deep[a]-1,k)*(size[a]-1)
(2)b在a的子树中(不与a重合),且deep[b]-deep[a]<=k ,所有我们要找寻的合法的b点就是子树中deep<=deep[a]+k的点,那么对应的合法的c点就在b的子树中。我们可以按照dfs序建立主席树,外层是dfs序,内层是权值线段树,权值线段树的权值为deep值,然后加入的数是这个点的size-1.统计答案的时候只需要计算dfs序中对应的a的子树区间中所有deep在[deep[a]+1,deep[a]+k]的点权和。
#include
#include
#include
#include
#include
#define N 600002
#define LL long long
using namespace std;
int n,m,k,deep[N],root[N];
int point[N],next[N],v[N],l[N],r[N],size[N],tot,sz,cnt,q[N];
struct data
{
int l,r;
LL sum;
}tr[N*30];
void add(int x,int y)
{
tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void dfs(int x,int fa)
{
deep[x]=deep[fa]+1;
l[x]=++cnt; q[cnt]=x; size[x]=1;
for (int i=point[x];i;i=next[i])
if (v[i]!=fa) {
dfs(v[i],x);
size[x]+=size[v[i]];
}
r[x]=cnt;
}
void update(int x)
{
int l=tr[x].l; int r=tr[x].r;
tr[x].sum=tr[l].sum+tr[r].sum;
}
void insert(int &i,int l,int r,int pos,int v)
{
tr[++sz]=tr[i]; i=sz;
tr[i].sum+=(LL)v;
if (l==r) return;
int mid=(l+r)/2;
if (pos<=mid) insert(tr[i].l,l,mid,pos,v);
else insert(tr[i].r,mid+1,r,pos,v);
}
LL query(int i,int j,int l,int r,int ll,int rr)
{
if (ll<=l&&r<=rr) return tr[j].sum-tr[i].sum;
int mid=(l+r)/2;
LL ans=0;
if (ll<=mid) ans+=query(tr[i].l,tr[j].l,l,mid,ll,rr);
if (rr>mid) ans+=query(tr[i].r,tr[j].r,mid+1,r,ll,rr);
return ans;
}
int main()
{
freopen("a.in","r",stdin);
scanf("%d%d",&n,&m);
for (int i=1;i