牛客15541 Counting On A Tree Again

题目链接

一、题意

给一棵n个点的树,根节点是1,每个点有点权a_i。q个询问,每个询问给出两个数x和k,问满足下列条件的(i,j)二元组个数。

(1)a_i=x

(2)\left | a_i-a_j \right | \leqslant k

(3)ji的祖先

题目中a_i \leqslant n,但数据是1 \leqslant a_i \leqslant 2 \cdot 10^4

多组测例。T未知。

数据范围:1 \leqslant n , a_i \leqslant 2 \cdot 10^4 , 1 \leqslant q \leqslant 10^5,1 \leqslant x,k \leqslant 2 \cdot 10^4

二、题解

这题想了半天nlogn做法,未果,没想到是big-small。

cnt[i]表示权值是i的节点个数。

1.  cnt[a[u]]\leqslant \sqrt{n} ,遍历到u节点时

(1)t是权值是a[u]的询问。树状数组查询[x[t] - k[t] , x[t] + k[t]]的区间和,注意树状数组查询时的边界。

(2)树状数组下标是a[u]的数加1

(3)遍历儿子

(4)树状数组下标是a[u]的数减1

这其实是纯暴力,对小于根号的部分的每个询问都直接统计父亲的个数。

2.  cnt[a[u]] > \sqrt{n} ,siz[z]表示以z节点为根的子树的权值是a[u]的个数。

预处理出siz[z]后,树状数组中下标是a[z]的数加siz[z]。然后区间询问就好了。

时间复杂度:\dpi{150} O(q*\sqrt{n}*logn)

三、感受

这道题很不错,很好的考察了big-small的运用,但是题目数据范围出锅,让我连wa一个多小时。

四、代码

#include
#define pb push_back
#define fi first
#define se second
#define sz(x)  (int)x.size()
#define cl(x)  x.clear()
#define all(x)  x.begin() , x.end()
#define rep(i , x , n)  for(int i = x ; i <= n ; i ++)
#define per(i , n , x)  for(int i = n ; i >= x ; i --)
#define mem0(x)  memset(x , 0 , sizeof(x))
#define mem_1(x)  memset(x , -1 , sizeof(x))
#define mem_inf(x)  memset(x , 0x3f , sizeof(x))
#define debug(x)  cerr << '*' << x << '\n'
#define ddebug(x , y)  cerr << '*' << x << ' ' << y << '\n'
#define ios std::ios::sync_with_stdio(false) , cin.tie(0)
using namespace std ;
typedef long long ll ;
typedef long double ld ;
typedef pair pii ;
typedef pair pll ;
const int mod = 998244353 ;
const int maxn = 1e5 + 10 ;
const int inf = 0x3f3f3f3f ;
const double eps = 1e-6 ; 
mt19937  rnd(chrono::high_resolution_clock::now().time_since_epoch().count()) ; 
int n , Q ;
int a[maxn] ;
vector g[maxn] ;
int x[maxn] , k[maxn] ;
int cnt[maxn] ;
ll ans[maxn] ;
int lim ;
bool vis[maxn] ;
vector q[maxn] ;
struct BIT
{
    ll tree[maxn] ; //开 1 倍空间
    int n ;
    void init()
    {
        n = 20000 ;
        mem0(tree) ;
    }
    int lowbit(int k)
    {
        return k & -k ;
    }
    void add(int x , ll k)  // a[x] += k
    {
        while(x <= n)  //维护的是 [1 , n] 的序列
        {
            tree[x] += k ;
            x += lowbit(x) ;
        }
    }
    ll sum(int x)  // sum[l , r] = sum(r) - sum(l - 1)
    {
        ll ans = 0 ;
        while(x != 0)
        {
            ans += tree[x] ;
            x -= lowbit(x) ;
        }
        return ans ;
    }
    ll query(int l , int r)
    {
        l = max(l , 1) ;
        r = min(r , 20000) ;
        return sum(r) - sum(l - 1) ;
    }
} bit ;
void dfs1(int fa , int u)
{
    for(auto t : q[a[u]])  ans[t] += bit.query(x[t] - k[t] , x[t] + k[t]) ;
    bit.add(a[u] , 1) ;
    for(auto v : g[u])
    {
        if(v == fa)  continue ;
        dfs1(u , v) ;
    }
    bit.add(a[u] , -1) ;
}
ll siz[maxn] ;
void dfs2(int fa , int u , int col , int f)
{
    siz[u] = (a[u] == col) ;
    for(auto v : g[u])
    {
        if(v == fa)  continue ;
        dfs2(u , v , col , f) ;
        siz[u] += siz[v] ;
    }
    bit.add(a[u] , f * siz[u]) ;
}
int main()
{
    ios ;
    int T ;
    bit.init() ;
    cin >> T ;
    while(T --)
    {
        cin >> n ;
        rep(i , 1 , n)  cl(g[i]) ;
        rep(i , 1 , n - 1)
        {
            int u , v ;
            cin >> u >> v ;
            g[u].pb(v) , g[v].pb(u) ;
        }
        rep(i , 1 , n)  cin >> a[i] , cnt[a[i]] ++ , assert(a[i] >= 1 && a[i] <= 20000) ;
        lim = sqrt(n) ;
        int Q ;
        cin >> Q ;
        rep(i , 1 , Q)  cin >> x[i] >> k[i] , ans[i] = 0 ;
        //small
        rep(i , 1 , Q)  if(cnt[x[i]] < lim)  q[x[i]].pb(i) ;
        dfs1(1 , 1) ;
        
        //big
        rep(i , 1 , Q)  cl(q[x[i]]) ;
        rep(i , 1 , Q)  if(cnt[x[i]] >= lim)  q[x[i]].pb(i) ;
        rep(i , 1 , Q)
            if(!vis[x[i]] && cnt[x[i]] >= lim)
            {
                vis[x[i]] = 1 ;
                dfs2(1 , 1 , x[i] , 1) ;
                for(auto t : q[x[i]])  ans[t] = bit.query(x[t] - k[t] , x[t] + k[t]) - cnt[x[i]] ;
                dfs2(1 , 1 , x[i] , -1) ;
            }
        rep(i , 1 , Q)  cout << ans[i] << '\n' ;
        rep(i , 1 , Q)  cl(q[x[i]]) ;
        rep(i , 1 , Q)  vis[x[i]] = 0 ;
        rep(i , 1 , n)  cnt[a[i]] = 0 ;
    }
    return 0 ;
}

 

你可能感兴趣的:(#,big-small)