树的分治学习

 

poj 1741 http://poj.org/problem?id=1741

题意:求树上距离小于k的点对的对数;

分析:每条路经要么过根,要么不过根,对于不过根的路径我们递归同理可以求出,

对于过根的路径,我们dfs一遍记录下所有其他节点到跟的距离,然后sort()一下可以o(n)求出

其小于k的pair,  但我们要减去其中来自同一子树的pair;

利用重心分治,最多logn次,所有时间o(n * logn * logn);

重心的定义:去掉该点后,其子树节点个数的最大值最小;

step1:对于当前树,找到其重心,统计pair;

step2:把重心当根,分别递归其子树;

findsz(),findC()为求重心;

 1 #include<cstdio>

 2 #include<cstring>

 3 #include<algorithm>

 4 #include<cmath>

 5 #include<vector>

 6 #include<cstdlib>

 7 #define MP make_pair

 8 using namespace std;

 9 typedef pair<int,int> pii;

10 const int N = 10000+10;

11 vector<pii> g[N];

12 int n,KK;

13 

14 int cn[N];

15 int vis[N];

16 int findsz(int u,int fa) {

17     int ret = 1;

18     int sz = g[u].size();

19     for (int i = 0; i < sz; i++) {

20         int c = g[u][i].first;

21         if (c == fa || vis[c]) continue;

22         ret += findsz(c,u);

23     }

24     return ret;

25 }

26 void findC(int u,int fa,int &k,int &mark,int nn) {

27     int mx = 0;

28     int sz = g[u].size();

29     cn[u] = 1;

30     for (int i = 0; i < sz; i++) {

31         int c = g[u][i].first;

32         if (c == fa || vis[c]) continue;

33         findC(c,u,k,mark,nn);

34         cn[u] += cn[c];

35         if (cn[c] > mx) mx = cn[c];

36     }

37     if (nn - cn[u] > mx) mx = nn - cn[u];

38     if (mark == -1 || mx < mark) mark = mx,k = u;

39 }

40 int num[N];

41 int cnt;

42 void findnum(int u,int fa,int dep) {

43     num[cnt++] = dep;

44     int sz = g[u].size();

45     for (int i = 0; i < sz; i++) {

46         int c = g[u][i].first;

47         if (vis[c] || c == fa) continue;

48         findnum(c,u,dep+g[u][i].second);

49     }

50 }

51 int calc(int k,int w) {

52     cnt = 0;

53     findnum(k,0,w);

54     sort(num,num+cnt);

55     int r = cnt-1;

56     int ret = 0;

57     for (int i = 0; i < r; i++) {

58         while (num[i] + num[r] > KK && i < r) r--;

59         ret += r - i;

60     }

61     return ret;

62 }

63 int ans;

64 void dfs(int u,int w) {

65     int nn = findsz(u,0);

66     int k = 0, mark = -1;

67     findC(u,0,k,mark,nn);

68     int sz = g[k].size();

69     vis[k] = 1;

70     ans += calc(k,0);

71 

72     for (int i = 0; i < sz; i++) {

73         int c = g[k][i].first;

74         if (vis[c]) continue;

75         ans -= calc(c,g[k][i].second);

76         dfs(c,g[k][i].second);

77     }

78 }

79 void solve(){

80     memset(vis,0,sizeof(vis));

81     ans = 0;

82     dfs(1,0);

83     printf("%d\n",ans);

84 }

85 int main(){

86     while (~scanf("%d%d",&n,&KK),n+KK) {

87         for (int i = 0; i <= n; i++) g[i].clear();

88         for (int i = 0; i < n-1; i++) {

89             int u,v,w; scanf("%d%d%d",&u,&v,&w);

90             g[u].push_back(MP(v,w));

91             g[v].push_back(MP(u,w));

92         }

93         solve();

94     }

95 

96     return 0;

97 }
View Code

 

 

你可能感兴趣的:(学习)