poj1741 Tree

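树的点分治。感觉理解得不够深刻，等想好再多写点。
(Tree divide-and-conquer / centroid decomposition: counts vertex pairs whose path length is at most K.)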

#include <time.h>

#include <algorithm>
#include <bitset>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <utility>
#include <vector>
using namespace std;

// Shared type aliases and limits.
using ll = long long;
using pi = pair<int,int>;   // adjacency entry: (neighbour, edge weight)
const int MAXN = 10050;     // capacity of every vertex-indexed array below
const int INF = 0x3f3f3f3f; // initial "no candidate yet" value for the centroid search

// Per-test-case input and result.
int N;                      // number of vertices (main reads N-1 edges)
int K;                      // distance bound used when counting pairs
ll ans;                     // running pair count printed for each test case
vector<pi> mp[MAXN];        // mp[u] holds (v, w) for every edge u-v of weight w
int vis[MAXN];              // set to 1 once a vertex has been processed by work()
/* ---- Centroid search state (written by getRoot / findRoot) ---- */
int all;          // size of the component being searched; set by getRoot
int num;          // smallest dp[] score seen so far in the current search
int center;       // vertex that achieved `num`
int dp[MAXN];     // dp[v]: score findRoot assigns to v; the search keeps the minimum
int nodes[MAXN];  // nodes[v]: subtree size of v from the latest DFS (findRoot or getDp)
// Depth-first pass over the component containing x, treating vertices with
// vis[] set as removed. It fills:
//   nodes[v] - size of v's subtree (rooted at the search start),
//   dp[v]    - size of the largest piece left if v were deleted: either the
//              biggest child subtree or everything outside v's subtree.
// Whenever dp[v] beats the running minimum `num`, v becomes `center`, so after
// the pass `center` is a centroid of the component.
// Requires `all` to hold the component size and `num` to be reset (getRoot does both).
// pre is the vertex we came from, or -1 at the start.
void findRoot(int x,int pre) {
    nodes[x] = 1;
    dp[x] = 0;
    for (const auto& edge : mp[x]) {
        int y = edge.first;
        if (y == pre || vis[y]) continue;
        findRoot(y, x);
        nodes[x] += nodes[y];
        dp[x] = max(dp[x], nodes[y]);  // largest child subtree
    }
    // The piece "above" x (the rest of the component) has all - nodes[x]
    // vertices. Using all - dp[x] here is not a component size. It can select
    // a non-centroid and unbalance the divide-and-conquer recursion.
    dp[x] = max(dp[x], all - nodes[x]);
    if (dp[x] < num) {
        num = dp[x];
        center = x;
    }
}
// Searches the not-yet-removed component that contains `root` (of size `sn`)
// and returns the vertex with the smallest dp[] score computed by findRoot.
// If no vertex beats INF, `root` itself is returned.
int getRoot(int root,int sn) {
    center = root;   // fallback answer
    all = sn;        // component size that findRoot measures against
    num = INF;       // reset the running minimum
    findRoot(root, -1);
    return center;
}
/* ---- Tree divide-and-conquer (point divide) state ---- */
vector<int> Dep;  // distances collected by getDp for the component being counted
int dep[MAXN];    // dep[v]: distance of v from the current centre, filled by getDp
// DFS from x over the remaining component (vertices with vis[] set are not
// entered). It appends dep[] of every reached vertex to Dep and propagates
// dep[child] = dep[parent] + edge weight, starting from the existing dep[x].
// It also rewrites nodes[] with subtree sizes rooted at x.
// pre is the vertex we came from, or -1 at the start.
void getDp(int x,int pre) {
    Dep.push_back(dep[x]);
    int subtree = 1;
    for (const auto& edge : mp[x]) {
        const int to = edge.first;
        if (to == pre || vis[to]) continue;
        dep[to] = dep[x] + edge.second;
        getDp(to, x);
        subtree += nodes[to];
    }
    nodes[x] = subtree;
}
// Counts unordered pairs of distinct vertices {u, v} in x's remaining
// component with dep[u] + dep[v] <= K. dep[] is refilled by getDp starting
// from dep[x]:
//   tag != 0 -> dep[x] is reset to 0, so the values are distances from x;
//   tag == 0 -> dep[x] keeps its current value (work() calls it this way on a
//               neighbour of the centre to count pairs it must subtract).
// Side effects: overwrites Dep, dep[] and nodes[] for the component.
ll Cal(int x,int tag) {
    Dep.clear();
    if (tag != 0) dep[x] = 0;
    getDp(x, -1);
    sort(Dep.begin(), Dep.end());

    ll pairs = 0;
    int lo = 0;
    int hi = static_cast<int>(Dep.size()) - 1;
    // Two pointers on the sorted list. If Dep[lo] + Dep[hi] fits the bound,
    // then Dep[lo] also fits with every element between lo and hi.
    while (lo < hi) {
        if (Dep[lo] + Dep[hi] <= K) {
            pairs += hi - lo;
            ++lo;
        } else {
            --hi;
        }
    }
    return pairs;
}
// Divide-and-conquer driver for the component centred at x.
// 1. Mark x as removed and add every pair in the component whose distances
//    from x sum to at most K.
// 2. For each still-present neighbour, subtract the pairs lying entirely
//    inside that neighbour's subtree (their path does not really pass
//    through x).
// 3. Recurse on that subtree, starting from the vertex getRoot selects.
// The subtree size comes from nodes[], which the preceding Cal() just refreshed.
void work(int x) {
    vis[x] = 1;
    ans += Cal(x, 1);
    for (const auto& edge : mp[x]) {
        const int child = edge.first;
        if (vis[child]) continue;
        ans -= Cal(child, 0);
        const int childSize = nodes[child];
        work(getRoot(child, childSize));
    }
}
int main(){
    while(~scanf("%d %d",&N,&K)) {
        if(N==0 && K==0) break;
        memset(vis,0,sizeof(vis));
        for(int i = 1; i <= N; ++i) mp[i].clear();
        for(int i = 1; i < N; ++i) {
            int a,b,c; scanf("%d %d %d",&a,&b,&c);
            mp[a].push_back({b,c}); mp[b].push_back({a,c});
        }
        ans = 0;
        work(getRoot(1,N));
        printf("%lld\n",ans);
    }
    return 0;
}

你可能感兴趣的:(poj1741 Tree)