楼教主男人八题 POJ 1741(树分治(我自然是看题解搞懂的))

题意就是求树上距离小于等于K的点对有多少个
n2的算法肯定不行,因为1W个点
这就需要分治。可以看09年漆子超的论文 http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###
本题用到的是关于点的分治。
一个重要的问题是,为了防止退化,所以每次都要找到树的重心然后分治下去, 所谓重心,就是删掉此结点后,剩下的结点最多的树结点个数最小
每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心
找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K,这里采用的方法就是把所有的距离存在一个数组里,进行快速排序,这是nlogn的,然后用一个经典的相向搜索O(n)时间内解决。 但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。

最后的复杂度是n logn logn    其中每次快排是nlogn 而递归的深度为logn

我的理解是:比如现在有一棵树,从点1开始遍历整棵树,算出每个节点为根所在的子树中点的数量

然后再遍历树,找出重心,重心的定义如上,也可以看论文(有证明),然后算出从重心到每个节点的距离(不计算已经当过根的节点),快排之后用相向搜索找两段dis相加小于等于k的个数(具体见代码)

不过这样找的会有重复,所以每次我们找的符合题目的对数都是经过重心root的对数,所以要减去不经过重心的个数

原理还是可以理解的,代码比较复杂,确实挺难写,仍需努力,早日成为真男人QAQ

#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>

using namespace std;
#define   MAX       10005
#define   MAXN      2000005
#define   lson      l,m,rt<<1
#define   rson      m+1,r,rt<<1|1
#define   lrt       rt<<1
#define   rrt       rt<<1|1
#define   mid       int m=(r+l)>>1
#define   LL        long long
#define   ull       unsigned long long
#define   mem0(x)   memset(x,0,sizeof(x))
#define   mem1(x)   memset(x,-1,sizeof(x))
#define   meminf(x) memset(x,INF,sizeof(x))
#define   lowbit(x) (x&-x)

const LL     mod   = 1000000;
const int    prime = 999983;
const int    INF   = 0x3f3f3f3f;
const int    INFF  = 1e9;
const double pi    = 3.141592653589793;
const double inf   = 1e18;
const double eps   = 1e-10;

struct Edge{
    int v,cost,next;
}edge[MAX*2];
int head[MAX];
int maxn[MAX];
int siz[MAX];
int vis[MAX];
int dis[MAX];
int tot;
int root;
int mi;
int n,k;
int ans;
int num;

void add_edge(int a,int b,int c){
    edge[tot]=(Edge){b,c,head[a]};
    head[a]=tot++;
}

void dfssize(int u,int fa){
    siz[u]=1;
    maxn[u]=0;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v!=fa&&!vis[v]){
            dfssize(v,u);
            siz[u]+=siz[v];
            maxn[u]=max(maxn[u],siz[v]);
        }
    }
}

void dfsroot(int r,int u,int fa){
    if(siz[r]-siz[u]>maxn[u]) maxn[u]=siz[r]-siz[u];
    if(maxn[u]<mi){
        mi=maxn[u];
        root=u;
    }
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v!=fa&&!vis[v]) dfsroot(r,v,u);
    }
}

void dfsdis(int u,int fa,int d){
    dis[num++]=d;
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v!=fa&&!vis[v]) dfsdis(v,u,d+edge[i].cost);
    }
}

int calc(int u,int d){
    int ret=0;
    num=0;
    dfsdis(u,-1,d);
    sort(dis,dis+num);
    int i=0,j=num-1;
    while(i<j){
        while(dis[i]+dis[j]>k&&i<j) j--;
        ret+=j-i;
        i++;
    }
    return ret;
}

void dfs(int u){
    mi=n;
    dfssize(u,-1);
    dfsroot(u,u,-1);
    ans+=calc(root,0);
    vis[root]=1;
    for(int i=head[root];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(!vis[v]){
            ans-=calc(v,edge[i].cost);//不经过重心的对数,从子节点v开始计算,距离是cost
            dfs(v);
        }
    }
}

int main(){
    while(scanf("%d%d",&n,&k)){
        if(!n&&!k) break;
        mem1(head);
        mem0(vis);
        ans=0;
        tot=0;
        for(int i=1;i<n;i++){
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
            add_edge(a,b,c);
            add_edge(b,a,c);
        }
        dfs(1);
        printf("%d\n",ans);
    }
    return 0;
}


你可能感兴趣的:(ACM)