CF 293 E Close Vertices (树的分治+树状数组)

转载请注明出处,谢谢http://blog.csdn.net/ACM_cxlove?viewmode=contents    by---cxlove

题目:给出一棵带边权的树,问有多少对点 (u, v) 满足:它们之间路径的权值和不大于 w,且路径长度(边数)不大于 l。

http://codeforces.com/contest/293/problem/E

这道题与男人八题中的 POJ 1741 Tree 很相似,只是多了一个限制(路径长度不大于 l)。

同样还是点分治:对当前分治中心(根)下的每个点,考虑二元组(到根的路径权值和, 到根的路径长度)。

将二元组按第一维排序之后,可以用双指针(two pointers)找出权值和不大于 w 的点对,同时用树状数组维护路径长度,统计长度和不大于 l 的点对。

也就是说,第一个条件用双指针处理,第二个条件用树状数组维护。最后还要减去位于同一棵子树内的点对(它们之间的路径并不经过当前根),并单独加上以根本身为端点的路径。

 

#include <iostream>

#include <cstdio>

#include <cstring>

#include <algorithm>

#include <vector>

#define lson step << 1

#define rson step << 1 | 1

#define pb(a) push_back(a)

#define mp(a,b) make_pair(a , b)

#define lowbit(x) (x & (-x))

#pragma comment(linker, "/STACK:1024000000,1024000000")    

using namespace std;

typedef long long LL;   // wide enough for the pair count (up to ~n^2/2)

const int N = 100005;   // capacity: maximum vertex count plus slack

// One directed half of an undirected edge, stored in the static pool e[].
struct Edge {
    int v;      // target vertex
    int w;      // edge weight
    int next;   // index of the next edge leaving the same vertex, -1 = end
};
Edge e[N << 1];         // two directed entries per undirected edge

// Input: n vertices, l = max number of edges on a path, w = max path weight.
int n , l , w;
// tot = next free slot in e[]; start[u] = first edge of u's list (-1 = none).
int tot , start[N];

int del[N] = {0};       // del[u] = 1 once u has been removed as a centroid
// Subtree sizes, filled by calsize(). NOTE(review): with `using namespace std;`
// an unqualified `size` is ambiguous with C++17 std::size; refer to it as ::size.
int size[N];

LL ans = 0LL;           // number of qualifying vertex pairs found so far

void _add (int u , int v , int w) {

    e[tot].v = v ; e[tot].next = start[u];

    e[tot].w = w;

    start[u] = tot ++;

}

// Insert the undirected edge {u, v} of weight w as two directed edges,
// u -> v first and then v -> u.
void add (int u , int v , int w) {
    const int from[2] = {u , v};
    const int to[2]   = {v , u};
    for (int k = 0 ; k < 2 ; ++k)
        _add (from[k] , to[k] , w);
}

void calsize (int u , int pre) {

    size[u] = 1;

    for (int i = start[u] ; i != -1 ; i = e[i].next) {

        int v = e[i].v;

        if (v == pre || del[v]) continue;

        calsize (v , u);

        size[u] += size[v];

    }

}

int totalsize , maxsize , rootidx;

void dfs (int u , int pre) {

    int mx = totalsize - size[u];

    for (int i = start[u] ; i != -1 ; i = e[i].next) {

        int v = e[i].v;

        if (v == pre || del[v]) continue;

        mx = max (mx , size[v]);

        dfs (v , u);

    }

    if (mx < maxsize) maxsize = mx , rootidx = u;

}

int search (int r) {

    calsize (r , -1);

    totalsize = size[r];

    maxsize = 1 << 30;

    dfs (r , -1);

    return rootidx;

}

vector<pair<int,int> > sub[N] , all;

int idx , dist[N] , cnt[N];

void gao (int u , int pre) {

    all.pb(mp(dist[u] , cnt[u]));

    sub[idx].pb(mp(dist[u] , cnt[u]));

    for (int i = start[u] ; i != -1 ; i = e[i].next) {

        int v = e[i].v , w = e[i].w;

        if (v == pre || del[v]) continue;

        dist[v] = dist[u] + w;

        cnt[v] = cnt[u] + 1;

        gao (v , u);

    }

}

// Fenwick tree (binary indexed tree) over positions 1..up.
// s[] holds the tree; `up` is its current upper bound (set by fuck()).
int s[N] , up;

// Add `val` at position x. x must be >= 1: the low bit of 0 is 0, so x = 0
// would never advance.
void add (int x , int val) {
    while (x <= up) {
        s[x] += val;
        x += x & -x;
    }
}

// Prefix sum of positions 1..x; returns 0 when x <= 0.
int ask (int x) {
    int total = 0;
    for (; x > 0 ; x &= x - 1)   // clearing the low bit == subtracting it
        total += s[x];
    return total;
}

// Count unordered pairs {a, b}, a != b, of entries of v such that
//     v[a].first + v[b].first <= w   and   v[a].second + v[b].second <= l,
// where w and l are the global limits. v must be sorted by .first ascending,
// and every .second must be >= 1 (Fenwick positions start at 1).
//
// Two pointers on .first: j only moves left as i moves right. The Fenwick tree
// over .second holds exactly the entries i+1..j when the query for i runs.
// v[i] is removed *before* its query, so an entry is never paired with itself.
LL fuck (const vector<pair<int , int> > &v) {
    const int m = static_cast<int>(v.size());

    up = 0;
    for (int i = 0 ; i < m ; i ++)
        up = max (up , v[i].second);
    for (int i = 1 ; i <= up ; i ++)
        s[i] = 0;
    for (int i = 0 ; i < m ; i ++)
        add (v[i].second , 1);

    LL ret = 0;
    for (int i = 0 , j = m - 1 ; i < m ; i ++) {
        // Drop partners that are too heavy. Widen to 64 bits so the sum of two
        // large distances cannot overflow int.
        while (j >= i && static_cast<LL>(v[i].first) + v[j].first > w) {
            add (v[j].second , -1);
            j --;
        }
        // Later i have larger .first, so no later index has a partner either.
        if (j < i) break;
        add (v[i].second , -1);                    // exclude v[i] itself
        ret += ask (min (up , l - v[i].second));   // partners among i+1..j
    }
    return ret;
}

// Centroid-decomposition driver. Counts, into ans, the qualifying pairs whose
// path passes through the centroid of the component containing `root`, then
// recurses into each remaining piece.
void solve (int root) {
    const int c = search (root);
    del[c] = 1;
    if (totalsize == 1) return;   // a lone vertex contributes no pairs

    // Gather (dist, cnt) of every vertex, per child subtree and in total.
    idx = 0;
    all.clear();
    for (int k = start[c] ; k != -1 ; k = e[k].next) {
        const int child = e[k].v;
        if (del[child]) continue;
        sub[idx].clear();
        dist[child] = e[k].w;     // edge weight, not the global limit w
        cnt[child] = 1;
        gao (child , -1);
        sort (sub[idx].begin() , sub[idx].end());
        ++idx;
    }
    sort (all.begin() , all.end());

    // Pairs in different subtrees = pairs among all - pairs inside one subtree
    // (the latter do not pass through c).
    ans += fuck (all);
    for (int t = 0 ; t < idx ; ++t) {
        // Paths with the centroid itself as one endpoint.
        for (size_t q = 0 ; q < sub[t].size() ; ++q)
            if (sub[t][q].first <= w && sub[t][q].second <= l)
                ++ans;
        ans -= fuck (sub[t]);
    }

    for (int k = start[c] ; k != -1 ; k = e[k].next)
        if (!del[e[k].v])
            solve (e[k].v);
}

// Read n, l, w and the n-1 edges (vertex i+1 is joined to p with weight d),
// run the centroid decomposition from vertex 1 and print the pair count.
// Returns 1 if the input is truncated or malformed.
int main () {
    // freopen ("input.txt" , "r" , stdin);
    // freopen ("output.txt" , "w" , stdout);
    tot = 0;
    memset (start , -1 , sizeof(start));   // -1 marks an empty adjacency list
    if (scanf ("%d %d %d" , &n , &l , &w) != 3) return 1;
    for (int i = 1 ; i < n ; i ++) {
        int p , d;
        if (scanf ("%d %d" , &p , &d) != 2) return 1;
        add (i + 1 , p , d);
    }
    solve (1);
    // %lld is the standard (C99 / C++11) conversion for long long. The former
    // %I64d is a Microsoft CRT extension that other C libraries (e.g. glibc)
    // do not treat as a 64-bit conversion.
    printf ("%lld\n" , ans);
    return 0;
}


 

 

你可能感兴趣的:(close)