【poj1741】Tree 点分治

Description

Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

Source

LouTiancheng@POJ

问你树上有多少点对满足距离小于等于k。

树分治裸题。答案分为经过根节点的路径和不经过根节点的路径。每次我们只计算经过根节点的路径条数。

可以通过统计所有点到根节点的距离计算。存到d数组里,问题成了计算d[i] + d[j] <= k的 < i,j >个数。

排序( O(nlogn) )后可以在O(n)的时间内算出,具体看代码(getans函数)。得到答案ans1。

容易想到,ans1有不合法的路径,那就是这条路径起点终点在同一个子树内。这时我们应该减去这些方案数。解决方法就是,每次计算某个儿子为根的子树时,计算完毕后减去这棵子树的答案即可。

分治nlogn,加上sort,总复杂度是O(nlogn)

我代码在poj上老是TLE,刷了快半页了…挖个坑,以后再说。

代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;

const int SZ = 10010;

int head[SZ],nxt[SZ << 1],n,k,tot = 0,dist[SZ];

struct edge{
    int t,d;
}l[SZ << 1];

void build(int f,int t,int d)
{
    l[++ tot].t = t;
    l[tot].d = d;
    nxt[tot] = head[f];
    head[f] = tot;
}

int ans = 0,maxn,s,t,root;

bool rt[SZ];

int find(int u,int fa)
{
    int sz = 1;
    int now = 0;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v] && v != fa)
        {
            int son = find(v,u);
            sz += son;
            now = max(now,son);
        }
    }
    now = max(now,n - sz);
    if(now < maxn) maxn = now,root = u;
    return sz;
} 

void dfsdist(int u,int fa,int d)
{
    dist[++ t] = d;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v] && v != fa)
            dfsdist(v,u,d + l[i].d);
    }
}

int getans(int s,int t)
{
    sort(dist + s,dist + t + 1);
    int ans = 0;
    int r = t;
    for(int i = s;i <= t;i ++)
    {
        while(dist[i] + dist[r] > k && r > i) r --;
        ans += r - i;
        if(r == i) break; 
    }
    return ans;
}

void dfs(int x,int fa)
{
    maxn = n;
    find(x,fa);
    int u = root;
    s = 1,t = 0;
    rt[u] = 1;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v])
        {
            s = t + 1;
            dfsdist(v,u,l[i].d);
            ans -= getans(s,t);
        }
    }
    dist[++ t] = 0;
    ans += getans(1,t);
    for(int i = head[u];i;i = nxt[i])
        if(!rt[l[i].t]) dfs(l[i].t,u);
}

void init()
{
    memset(head,0,sizeof(head));
    memset(rt,0,sizeof(rt));
    ans = tot = 0;
}

void scanf(int &n)
{
    n = 0;
    char a = getchar();
    bool flag = 0;
    while(a < '0' || a > '9') { if(a == '-') flag = 1; a = getchar(); } 
    while(a >= '0' && a <= '9') n = (n << 3) + (n << 1) + a - '0',a = getchar();
    if(flag) n = -n;
}

int main()
{
    freopen("in.txt","r",stdin); 
    freopen("out.txt","w",stdout);      
    while(233)
    {
        init();
        scanf(n); scanf(k);
        if(!n && !k) break;
        for(int i = 1,a,b,c;i < n;i ++)
        {
            scanf(a); scanf(b); scanf(c);
            build(a,b,c);
            build(b,a,c);
        }
        dfs(1,0);
        printf("%d\n",ans);
    }
    return 0;
}

———-以上是12.28号的事情———-

———-以下是12.29号的事情———-

填坑了…树分治导致我树的重心打错,导致昨天T了三个题……

树的重心需要用到当前树的总点数,我直接用的n,所以重心找错了…

所以说要在找重心之前先计算一下当前树的大小,然后就可以AC了。

代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;

const int SZ = 10010;

int head[SZ],nxt[SZ << 1],n,k,tot = 0,dist[SZ];

struct edge{
    int t,d;
}l[SZ << 1];

void build(int f,int t,int d)
{
    l[++ tot].t = t;
    l[tot].d = d;
    nxt[tot] = head[f];
    head[f] = tot;
}

int ans = 0,maxn,s,t,root;

bool rt[SZ];

int find(int u,int fa,int n)
{
    int sz = 1;
    int now = 0;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v] && v != fa)
        {
            int son = find(v,u,n);
            sz += son;
            now = max(now,son);
        }
    }
    now = max(now,n - sz);
    if(now < maxn) maxn = now,root = u;
    return sz;
} 

void dfsdist(int u,int fa,int d)
{
    dist[++ t] = d;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v] && v != fa)
            dfsdist(v,u,d + l[i].d);
    }
}

int getans(int s,int t)
{
    sort(dist + s,dist + t + 1);
    int ans = 0;
    int r = t;
    for(int i = s;i <= t;i ++)
    {
        while(dist[i] + dist[r] > k && r > i) r --;
        ans += r - i;
        if(r == i) break; 
    }
    return ans;
}

int dfssz(int u,int fa)
{
    int sz = 1;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v] && v != fa)
            sz += dfssz(v,u);
    }
    return sz;
} 

void dfs(int x,int fa)
{
    int sz = dfssz(x,fa);
    maxn = n;
    find(x,fa,sz);
    int u = root;
    s = 1,t = 0;
    rt[u] = 1;
    for(int i = head[u];i;i = nxt[i])
    {
        int v = l[i].t;
        if(!rt[v])
        {
            s = t + 1;
            dfsdist(v,u,l[i].d);
            ans -= getans(s,t);
        }
    }
    dist[++ t] = 0;
    ans += getans(1,t);
    for(int i = head[u];i;i = nxt[i])
        if(!rt[l[i].t]) dfs(l[i].t,u);
}

void init()
{
    memset(head,0,sizeof(head));
    memset(rt,0,sizeof(rt));
    ans = tot = 0;
}

void scanf(int &n)
{
    n = 0;
    char a = getchar();
    bool flag = 0;
    while(a < '0' || a > '9') { if(a == '-') flag = 1; a = getchar(); } 
    while(a >= '0' && a <= '9') n = (n << 3) + (n << 1) + a - '0',a = getchar();
    if(flag) n = -n;
}

int main()
{   
// freopen("in.txt","r",stdin); 
// freopen("out.txt","w",stdout); 
    while(233)
    {
        init();
        scanf(n); scanf(k);
        if(!n && !k) break;
        for(int i = 1,a,b,c;i < n;i ++)
        {
            scanf(a); scanf(b); scanf(c);
            build(a,b,c);
            build(b,a,c);
        }
        dfs(1,0);
        printf("%d\n",ans);
    }
    return 0;
}

你可能感兴趣的:(poj)