今天做了一道点分治的题目,所以就去网上学了一下。
相信大家都听说过“分治”吧,分治就是“分而治之”一般是把n分成2份,然后再对每一份进行相同的操作,最后合并起来。
而点分治,一般情况下是在一棵树上面进行分治,和普通的分治大同小异。
先看一道例题
【题意】
给定一个有N个点(编号1,2,…,N)的树,每条边都有一个权值(不超过1000)。
树上两个节点x与y之间的路径长度就是路径上各条边的权值之和。
求长度不超过K的路径有多少条。
poj 1741 – tree
【输入格式】
输入包含多组测试样例。
每组测试样例的第一行包含两个正整数N和K。
接下来N-1行,每行包含三个正整数u,v,l,表示节点u与v之间存在一条边,且边的权值为l。
当输入样例N=0,K=0时,表示输入终止,且该样例无需处理。
【输出格式】
每个测试样例输出一个结果。
每个结果占一行。
【数据范围】
N≤10000,K≤10^7
【输入样例】
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
【输出样例】
8
先是考虑暴力的解法,直接枚举每一个点,求lca,然后判断是否<=k,可以得10分
仔细想想,对于一个点来说,长度<=k的路径无非就只有两种:
1.经过该点
2.不经过该点
所以我们采取点分治的做法,先选1作为根节点,求出经过1且<=k的路径条数,然后删掉1,再对1的每个儿子进行重复操作
那么我们怎么求出路径的条数呢
设d[i]为i到根节点root的距离,其中d[root]=0
搜索求出d,对d从小到大排序
然后我们令指针l = 1, r = tp (tp为d中,点的个数)
如果 d[l] + d[r] <= k ,结果就加上r-l+1,然后l++
否则r- -
但是这样会出现一个问题,比如son是root的一个儿子,在son的子树中,有很多个加起来满足<=k的,而这些都不经过u,怎么办???
容斥原理!!!
对于每个儿子,令它的d不变,统计一次<=k的个数,从ans里面减去这个数就好了。
这样可以取得70分
为什么只有70呢
举个例子吧,如果树退化成链的话,这个做法就和暴力没有区别,时间复杂度 O ( n 2 l o g n ) O(n^2logn) O(n2logn)
怎么办呢,因为这是一棵无根树,每一个点成为根节点都不会影响结果,所以我们每次都从当前的树种求出它的重心作为root,这样的话,无论树怎么样,都只有 l o g n logn logn层,时间复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
参考代码
#include
#include
#include
using namespace std;
const int N = 10006;
int n, m, k, all, ans, root;
//m是为了好看,all表示当前的树中节点总数,root表示根节点
int tot, Head[N], ver[N<<1], Leng[N<<1], Next[N<<1];
//储存边
int max_part[N], size[N], d[N];
//max_part[i]表示删去i以后剩余的最大部分,size表示子树大小,d表示到根节点的距离
int tp, sta[N];
//储存每个点到根节点的距离
bool vis[N];
//判断这个点是否做过做过根节点(是否被删除)
void Add(int u, int v, int w) {
ver[++tot] = v;
Leng[tot] = w;
Next[tot] = Head[u];
Head[u] = tot;
}
void Dfs1(int u, int fa) { //求重心
max_part[u] = 0, size[u] = 1;
for (int i = Head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fa || vis[v]) continue;
Dfs1(v, u);
size[u] += size[v];
max_part[u] = max(max_part[u], size[v]);
}
max_part[u] = max(max_part[u], all - max_part[u]);
if (max_part[u] < max_part[root]) root = u;
}
void Dfs2(int u, int fa) { //记录深度
sta[++tp] = d[u];
for (int i = Head[u]; i; i = Next[i]) {
int v = ver[i];
if (v == fa || vis[v]) continue;
d[v] = d[u] + Leng[i];
Dfs2(v, u);
}
}
int Calc(int u, int now) { //统计总数
d[u] = now, tp = 0;
Dfs2(u, 0);
sort(sta + 1, sta + tp + 1);
int sum = 0;
for (int l = 1, r = tp; l < r; )
if (sta[l] + sta[r] <= k)
sum += r - l, l++;
else r--;
return sum;
}
void Solve(int u) { //以重心u为根节点,进行点分治
vis[u] = 1;
ans += Calc(u, 0);
for (int i = Head[u]; i; i = Next[i]) {
int v = ver[i];
if (vis[v]) continue;
ans -= Calc(v, Leng[i]);
all = size[v];
root = 0;
Dfs1(v, u);
Solve(root);
}
}
int main() {
while(~scanf("%d%d", &n, &k) && n) {
tot = 0, memset(Head, 0, sizeof(Head));
memset(vis, 0, sizeof(vis));
for (int i = 1, x, y, z; i < n; i++) {
scanf("%d%d%d", &x, &y, &z);
Add(x, y, z);
Add(y, x, z);
}
all = n, max_part[0] = n, root = 0;
ans = 0;
Dfs1(1, 0);
Solve(root);
printf("%d\n", ans);
}
return 0;
}