题目如上所述。方法:倍增求 LCA,把每条路径拆成“起点→LCA”的上行段和“LCA→终点”的下行段,再用 DFS 序 + 可持久化线段树(主席树)分别统计。
附代码(注:作者本人提交时并未 AC,仅供参考):
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <utility>
using namespace std;
// Array bound for per-node / per-player arrays.
// NOTE(review): presumably chosen above the problem's n, m limits — confirm.
const int maxn = 4e5+5;
// n: number of tree nodes, m: number of players (both read in init()).
// siz: adjacency-list edge counter, cnt: DFS preorder counter,
// tot: segment-tree node pool counter, x/y: scratch edge endpoints used while reading.
int n,m,siz,cnt,tot,x,y;
// fa: parent in the rooted tree; num: DFS preorder index; d: depth (the root has depth 1);
// g[v][k]: 2^k-th ancestor of v (0 once past the root); w: per-node value read from input,
// compared against the number of edges a player has travelled (the observation time);
// ans: per-node answer; root[i]: persistent segment-tree root after inserting the first i
// players in the current sort order; size: subtree size.
// NOTE(review): with `using namespace std` under C++17, an unqualified `size` can be ambiguous
// with std::size — refer to this array as ::size.
int fa[maxn],num[maxn],d[maxn],g[maxn][20],w[maxn],ans[maxn],root[maxn],size[maxn];
// Adjacency list as linked edge lists; edge arrays have 2 * maxn slots because every tree edge
// is inserted in both directions.
int first[maxn],nxt[maxn*2],to[maxn*2];
// One player: runs from s to t; l is the path length in edges.
struct node
{
int s,t,l;
}e[maxn];
// Persistent segment-tree node pool: lc/rc are child indices into T (0 = empty subtree),
// s is the number of inserted values in this node's range.
struct Tree
{
int lc,rc,s;
}T[maxn*30];
// Reads the next decimal integer from stdin into xx.
// Non-digit characters before the number are skipped; a '-' among them makes the result negative.
// If EOF is reached before any digit, xx is left as 0 instead of looping forever.
template <class T> inline void read(T &xx)
{
    xx = 0;
    T flag = 1;
    // Keep getchar()'s int result: truncating it to char makes EOF indistinguishable
    // from a normal character and turned the skip loop below into an infinite loop.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        if (ch == '-') flag = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        xx = xx * 10 + (ch - '0');
        ch = getchar();
    }
    xx *= flag;
}
// Adds the directed edge x -> y to the adjacency list, placing it at the head of x's list.
void insert(int x, int y)
{
    ++siz;
    to[siz] = y;
    nxt[siz] = first[x];
    first[x] = siz;
}
void dfs1(int x)
{
size[x] = 1, num[x] = ++cnt;
for(int k = 0; g[x][k]; k++)
g[x][k+1] = g[g[x][k]][k];
for(int i = first[x]; i; i=nxt[i])
if(!fa[to[i]] && to[i] != fa[x])
{
fa[to[i]] = g[to[i]][0] = x;
d[to[i]] = d[x] + 1;
dfs1(to[i]);
size[x] += size[to[i]];
}
}
// Persistent insert of value x into the range [l, r]. It builds a new version whose root is
// stored into v, copying the root-to-leaf path of the old version vv and incrementing the
// count of every copied node. The old version is left untouched.
void add(int &v, int vv, int l, int r, int x)
{
    int *slot = &v;   // where the index of the next freshly copied node must be stored
    int prev = vv;    // corresponding node in the old version (0 = empty)
    for (;;)
    {
        *slot = ++tot;
        T[*slot] = T[prev];
        T[*slot].s++;
        if (l == r)
            return;
        const int mid = (l + r) >> 1;
        if (x <= mid)
        {
            slot = &T[*slot].lc;
            prev = T[prev].lc;
            r = mid;
        }
        else
        {
            slot = &T[*slot].rc;
            prev = T[prev].rc;
            l = mid + 1;
        }
    }
}
// Counts the values in [x, y] present in version v but not in version vv.
// The two versions are walked in lockstep over the node range [l, r].
int query(int v, int vv, int l, int r, int x, int y)
{
    // Node range fully inside the query: the answer is the count difference of the versions.
    if (x <= l && r <= y)
        return T[v].s - T[vv].s;
    const int mid = (l + r) >> 1;
    const int left = (x <= mid) ? query(T[v].lc, T[vv].lc, l, mid, x, y) : 0;
    const int right = (y > mid) ? query(T[v].rc, T[vv].rc, mid + 1, r, x, y) : 0;
    return left + right;
}
// Lowest common ancestor of x and y via the binary-lifting table g and depths d.
int lca(int x, int y)
{
    if (d[x] < d[y])
        std::swap(x, y);
    // Raise the deeper node x up to y's depth.
    for (int j = 18; j >= 0; --j)
    {
        if (d[g[x][j]] >= d[y])
            x = g[x][j];
    }
    if (x == y)
        return x;
    // Raise both while their ancestors differ; they then sit just below the LCA.
    for (int j = 18; j >= 0; --j)
    {
        if (g[x][j] != g[y][j])
        {
            x = g[x][j];
            y = g[y][j];
        }
    }
    return g[x][0];
}
// Orders players by (depth of s, DFS preorder index of s).
bool cmp(const node &a, const node &b)
{
    if (d[a.s] != d[b.s])
        return d[a.s] < d[b.s];
    return num[a.s] < num[b.s];
}
bool cmp2(const node &a, const node &b)
{
int d1 = d[a.s]-a.l, d2 = d[b.s]-b.l;
return d1 < d2 || d1 == d2 && num[a.s] < num[b.s];
}
int find1(int dep, int st)
{
int l = 0, r = m+1;
while(l < r-1)
{
int mid = l+r >> 1;
if(d[e[mid].s] < dep || d[e[mid].s] == dep && num[e[mid].s] < st) l = mid;
else r = mid;
}
return l+1;
}
int find2(int dep, int en)
{
int l = 0, r = m+1;
while(l < r-1)
{
int mid = l+r >> 1;
if(d[e[mid].s] > dep || d[e[mid].s] == dep && num[e[mid].s] > en)
r = mid;
else l = mid;
}
return r-1;
}
// On e[1..m] sorted by cmp2, returns the first index whose (d[s] - l, num[s]) is >= (dep, st).
// Returns m + 1 if there is none.
int findl(int dep, int st)
{
    int lo = 1, hi = m + 1;   // the answer always lies in [lo, hi]
    while (lo < hi)
    {
        const int mid = (lo + hi) / 2;
        const int key = d[e[mid].s] - e[mid].l;
        const bool before = key < dep || (key == dep && num[e[mid].s] < st);
        if (before)
            lo = mid + 1;
        else
            hi = mid;
    }
    return lo;
}
int findr(int dep, int en)
{
int l = 0, r = m+1;
while(l < r-1)
{
int mid = l+r >> 1;
if(d[e[mid].s]-e[mid].l > dep || d[e[mid].s]-e[mid].l == dep && num[e[mid].s] > en)
r = mid;
else l = mid;
}
return r-1;
}
void init()
{
read(n); read(m);
for(int i = 1; i < n; i++)
{
read(x); read(y);
insert(x, y); insert(y, x);
}
for(int i = 1; i <= n; i++) read(w[i]);
dfs1(d[1] = 1);
for(int i = 1; i <= m; i++)
{
read(e[i].s); read(e[i].t);
int z = lca(e[i].s, e[i].t);
e[i].l = d[e[i].s] - d[z] + d[e[i].t] - d[z];
if(w[z] == d[e[i].s]-d[z]) ans[z]++;
}
}
int main()
{
freopen("running.in","r",stdin);
freopen("running.out","w",stdout);
init();
sort(e+1, e+m+1, cmp);
for(int i = 1; i <= m; i++)
add(root[i], root[i - 1], 1, n, num[e[i].t]);
for(int i = 1; i <= n; i++)
{
int l = find1(d[i]+w[i], num[i]), r = find2(d[i]+w[i], num[i]+size[i]-1);
if(l <= r)
{
if(num[i] > 1)
ans[i] += query(root[r], root[l-1], 1, n, 1, num[i]-1);
if(num[i]+size[i] <= n)
ans[i] += query(root[r], root[l - 1], 1, n , num[i]+size[i], n);
}
}
for(int i = 1; i <= n; i++)
w[i] = n-w[i];
for(int i = 1; i <= m; i++)
swap(e[i].s, e[i].t);
sort(e+1, e+m+1, cmp2);
for(int i = 1; i <= tot; i++)
T[i].lc = T[i].rc = T[i].s = 0;
for(int i = 1; i <= m; i++)
add(root[i], root[i-1], 1 , n, num[e[i].t]);
for(int i = 1; i <= n; i++)
{
int l = findl(d[i]+w[i]-n, num[i]), r = findr(d[i]+w[i]-n, num[i]+size[i]-1);
if(l <= r)
{
if(num[i] > 1)
ans[i] += query(root[r], root[l-1], 1, n, 1, num[i]-1);
if(num[i] + size[i] <= n)
ans[i] += query(root[r], root[l-1], 1, n, num[i]+size[i], n);
}
}
for (int i = 1; i <= n; i++)
cout << ans[i] << ' ';
return 0;
}