给一棵n个点的树,根节点是1,每个点有点权。q个询问,每个询问给出两个数x和k,问满足下列条件的(i,j)二元组个数。
(1)
(2)
(3)是的祖先
题目中,但数据是
多组测例。T未知。
数据范围:
这题想了半天nlogn做法,未果,没想到是big-small。
cnt[i]表示权值是i的节点个数。
1. ,遍历到u节点时
(1)t是权值是a[u]的询问。树状数组查询的区间和,注意树状数组查询时的边界。
(2)树状数组下标是a[u]的数加1
(3)遍历儿子
(4)树状数组下标是a[u]的数减1
这其实是纯暴力,对小于根号的部分的每个询问都直接统计父亲的个数。
2. ,siz[z]表示以z节点为根的子树的权值是a[u]的个数。
预处理出siz[z]后,树状数组中下标是a[z]的数加siz[z]。然后区间询问就好了。
时间复杂度:
这道题很不错,很好的考察了big-small的运用,但是题目数据范围出锅,让我连wa一个多小时。
#include
#define pb push_back
#define fi first
#define se second
#define sz(x) (int)x.size()
#define cl(x) x.clear()
#define all(x) x.begin() , x.end()
#define rep(i , x , n) for(int i = x ; i <= n ; i ++)
#define per(i , n , x) for(int i = n ; i >= x ; i --)
#define mem0(x) memset(x , 0 , sizeof(x))
#define mem_1(x) memset(x , -1 , sizeof(x))
#define mem_inf(x) memset(x , 0x3f , sizeof(x))
#define debug(x) cerr << '*' << x << '\n'
#define ddebug(x , y) cerr << '*' << x << ' ' << y << '\n'
#define ios std::ios::sync_with_stdio(false) , cin.tie(0)
using namespace std ;
typedef long long ll ;
typedef long double ld ;
typedef pair pii ;
typedef pair pll ;
const int mod = 998244353 ;
const int maxn = 1e5 + 10 ;
const int inf = 0x3f3f3f3f ;
const double eps = 1e-6 ;
mt19937 rnd(chrono::high_resolution_clock::now().time_since_epoch().count()) ;
int n , Q ;
int a[maxn] ;
vector g[maxn] ;
int x[maxn] , k[maxn] ;
int cnt[maxn] ;
ll ans[maxn] ;
int lim ;
bool vis[maxn] ;
vector q[maxn] ;
struct BIT
{
ll tree[maxn] ; //开 1 倍空间
int n ;
void init()
{
n = 20000 ;
mem0(tree) ;
}
int lowbit(int k)
{
return k & -k ;
}
void add(int x , ll k) // a[x] += k
{
while(x <= n) //维护的是 [1 , n] 的序列
{
tree[x] += k ;
x += lowbit(x) ;
}
}
ll sum(int x) // sum[l , r] = sum(r) - sum(l - 1)
{
ll ans = 0 ;
while(x != 0)
{
ans += tree[x] ;
x -= lowbit(x) ;
}
return ans ;
}
ll query(int l , int r)
{
l = max(l , 1) ;
r = min(r , 20000) ;
return sum(r) - sum(l - 1) ;
}
} bit ;
void dfs1(int fa , int u)
{
for(auto t : q[a[u]]) ans[t] += bit.query(x[t] - k[t] , x[t] + k[t]) ;
bit.add(a[u] , 1) ;
for(auto v : g[u])
{
if(v == fa) continue ;
dfs1(u , v) ;
}
bit.add(a[u] , -1) ;
}
ll siz[maxn] ;
void dfs2(int fa , int u , int col , int f)
{
siz[u] = (a[u] == col) ;
for(auto v : g[u])
{
if(v == fa) continue ;
dfs2(u , v , col , f) ;
siz[u] += siz[v] ;
}
bit.add(a[u] , f * siz[u]) ;
}
int main()
{
ios ;
int T ;
bit.init() ;
cin >> T ;
while(T --)
{
cin >> n ;
rep(i , 1 , n) cl(g[i]) ;
rep(i , 1 , n - 1)
{
int u , v ;
cin >> u >> v ;
g[u].pb(v) , g[v].pb(u) ;
}
rep(i , 1 , n) cin >> a[i] , cnt[a[i]] ++ , assert(a[i] >= 1 && a[i] <= 20000) ;
lim = sqrt(n) ;
int Q ;
cin >> Q ;
rep(i , 1 , Q) cin >> x[i] >> k[i] , ans[i] = 0 ;
//small
rep(i , 1 , Q) if(cnt[x[i]] < lim) q[x[i]].pb(i) ;
dfs1(1 , 1) ;
//big
rep(i , 1 , Q) cl(q[x[i]]) ;
rep(i , 1 , Q) if(cnt[x[i]] >= lim) q[x[i]].pb(i) ;
rep(i , 1 , Q)
if(!vis[x[i]] && cnt[x[i]] >= lim)
{
vis[x[i]] = 1 ;
dfs2(1 , 1 , x[i] , 1) ;
for(auto t : q[x[i]]) ans[t] = bit.query(x[t] - k[t] , x[t] + k[t]) - cnt[x[i]] ;
dfs2(1 , 1 , x[i] , -1) ;
}
rep(i , 1 , Q) cout << ans[i] << '\n' ;
rep(i , 1 , Q) cl(q[x[i]]) ;
rep(i , 1 , Q) vis[x[i]] = 0 ;
rep(i , 1 , n) cnt[a[i]] = 0 ;
}
return 0 ;
}