POJ 1987 Distance Statistics 牛题 树的分治

POJ 1987 Distance Statistics 牛题 树的分治

这题很牛逼,是楼教主的《男人七题》的其中一道。
求:一棵树内最短距离小于K的点对数量
后来看了解题报告,原来树也是可以分治的。

分:
选取一条边,将一棵树分成两半,这两半的节点数要尽量相等。
首先,统计个每个节点的下面有多少个节点
然后,就知道每个边切断之后的情况了。选择最优的即可。

治:
分成两半之后,统计经过该条边的点对线段中,长度小于K的数目。

Amber 大牛论文的精辟描述如下:
 

Divide and Conquer.

Each iteration, we should choose an edge (u, v) and divide the tree into two parts disjoined by the edge. Due to avoid from degenerating, that partition edge should be chosen to divide two parts as equally as possible. Then we should merge two parts and count the valid pairs between them. It can be implemented by two sorted list that denotes the distances between u and the posterities of u and the distances between u and the posterities of v respectively. And like merge sort, use two scan line l, r in two list and maintain the property d(u, l) + d(v, r) <= k.


可见这位大牛的英文水平实在牛逼,英文说得比中文说得还清楚,赞一个。

按照这个思路,很费劲地写出了代码。还好,在1987上面还是勉强上榜啦!250ms那个就是我啦,哈哈。
但是在楼教主的题目1741 上面还是 TLE了。

后来找了一份能过1741的代码,在 http://hi.baidu.com/shingray/blog/item/221362b079afc55d082302f0.html
一个大牛的博客上~
发现它的思路不是选择一条边来把树分成两份。
而是选择一个点来把树分成数份,然后计算经过该点的线段数目。
这样速度就快了,大牛的代码在1741上面只跑了170多ms。
将这份代码放到1987上面,也能跑到260ms。
所以这种方法还是很牛逼的!


我的垃圾代码(POJ 1987):
#include  < stdio.h >
#include 
< stdlib.h >

#define  MAX_VETXS 65536*2
#define  MAX_EDGES (MAX_VETXS - 1)

#if  0
#define  dbp printf
#else
#define  dbp()
#endif

struct  edge_node  {
    
int w, i;
    
struct edge_node *next, *prev;
}
;

struct  edge_node edges[MAX_EDGES], map[MAX_VETXS];
int  edges_cnt;
int  N, K, ans;

int  cmp( const   void   * a,  const   void   * b)
{
    
return *(int *)a - *(int *)b;
}


inline 
int  max( int  a,  int  b)
{
    
return a > b ? a : b;
}


#define  list_foreach(_head, _t)    \
    
for  (_t  =  (_head) -> next; _t  !=  _head; _t  =  (_t) -> next)

inline 
void  list_init( struct  edge_node  * t)
{
    t
->next = t->prev = t;
}


inline 
void  list_add( struct  edge_node  * head,  struct  edge_node  * t)
{
    head
->prev->next = t;
    t
->prev = head->prev;
    head
->prev = t;
    t
->next = head;
}


inline 
void  list_del( struct  edge_node  * t)
{
    t
->prev->next = t->next;
    t
->next->prev = t->prev;
}


inline 
void  list_rev( struct  edge_node  * t)
{
    t
->next->prev = t;
    t
->prev->next = t;
}


inline 
void  edge_add( int  a,  int  b,  int  w)
{
    
struct edge_node *= &edges[edges_cnt++];

    t
->= b;
    t
->= w;
    list_add(
&map[a], t);
}


struct  part_info  {
    
int u, v, e, cnt_v;
}
;

inline 
void  divide( int  i,  int   * arr,  int   * len,  int  cnt,  struct  part_info  * pi)
{
    
static struct {
        
int i, e, depth, cnt, stat, root;
    }
 stk[MAX_VETXS], *sp, *top;
    
static int vis[MAX_VETXS], tm, best, val;
    
int *orig = arr;
    
struct edge_node *e;
    
    best 
= cnt;
    tm
++;
    top 
= stk + 1;
    top
->= i;
    top
->depth = top->cnt = top->stat = top->root = 0;
    vis[i] 
= tm;
    
while (top > stk) {
        sp 
= top;
        
if (sp->stat) {
            stk[sp
->root].cnt += sp->cnt;
            
if (arr && sp->depth <= K)
                
*arr++ = sp->depth;
            val 
= max(sp->cnt, cnt - sp->cnt);
            
if (val < best) {
                best 
= val;
                pi
->= stk[sp->root].i;
                pi
->= sp->i;
                pi
->= sp->e;
                pi
->cnt_v = sp->cnt;
            }

            top
--;
            
continue;
        }

        sp
->stat++;
        list_foreach(
&map[sp->i], e) {
            
if (vis[e->i] == tm)
                
continue;
            vis[e
->i] = tm;
            top
++;
            top
->= e->i;
            top
->= e - edges;
            top
->depth = sp->depth + e->w;
            top
->cnt = 1;
            top
->stat = 0;
            top
->root = sp - stk;
        }

    }


    
if (len)
        
*len = arr - orig;
}


void  conquer( struct  part_info  * pi,  int  cnt)
{
    
struct part_info pl, pr;
    
static int arr_l[MAX_VETXS], arr_r[MAX_VETXS], len_l, len_r, l, r;

    
if (cnt <= 1)
        
return ;

    list_del(
&edges[pi->e]);
    list_del(
&edges[pi->^ 1]);
    
    divide(pi
->u, arr_l, &len_l, cnt - pi->cnt_v, &pl);
    divide(pi
->v, arr_r, &len_r, pi->cnt_v, &pr);
    
    qsort(arr_l, len_l, 
sizeof(arr_l[0]), cmp);
    qsort(arr_r, len_r, 
sizeof(arr_r[0]), cmp);

    r 
= len_r - 1;
    
for (l = 0; l < len_l; l++{
        
while (r >= 0 && arr_l[l] + arr_r[r] + edges[pi->e].w > K)
            r
--;
        ans 
+= r + 1;
    }

    
    conquer(
&pl, cnt - pi->cnt_v);
    conquer(
&pr, pi->cnt_v);
    
    list_rev(
&edges[pi->e]);
    list_rev(
&edges[pi->^ 1]);
}


inline 
void  solve_v2()
{
    
struct part_info pi;

    divide(
1, NULL, NULL, N, &pi);
    conquer(
&pi, N);
}


int  main()
{
    
int i, a, b, w, m;
    
char str[16];

    freopen(
"e:\\test\\in.txt""r", stdin);

    scanf(
"%d%d"&N, &m);
    edges_cnt 
= 0;
    
for (i = 1; i <= N; i++)
        list_init(
&map[i]);
    
for (i = 0; i < m; i++{
        scanf(
"%d%d%d%s"&a, &b, &w, str);
        edge_add(a, b, w);
        edge_add(b, a, w);
    }

    scanf(
"%d"&K);
    ans 
= 0;
    solve_v2();
    printf(
"%d\n", ans);

    
return 0;
}



大牛的代码(POJ 1987):
#include  < algorithm >
#include 
< cstdio >
#include 
< cstring >
#include 
< limits >
#include 
< queue >
#include 
< vector >
using   namespace  std;

const   int  MAX_N  =   65536 * 2 ;
bool  flag[MAX_N];
int  k, n, ret, v[MAX_N];
queue
< pair < int int >   >  q;

struct  edge {int v, w; edge *next; }   * e[MAX_N], data[MAX_N * 2 - 2 ],  * it;
void  insert( int  u,  int  v,  int  w)
{
   
*it = (edge){v, w, e[u]}; e[u] = it++;
   
*it = (edge){u, w, e[v]}; e[v] = it++;
}


int  count( int   * first,  int   * last)
{
   
int ret = 0;
   sort(first, last
--);
   
while (first < last)
       
if (*first+*last <= k) ret += last-first++;
       
else --last;
   
return ret;
}


int  best_size, center;
int  centerOfGravity( int  root,  int  pred)
{
   
int max_sub = 0, size = 1;
   
for (edge *it = e[root]; it; it = it->next)
       
if (it->!= pred && flag[it->v])
       
{
           
int t = centerOfGravity(it->v, root);
           size 
+= t;
           
if (t > max_sub) max_sub = t;
       }

   
if (q.front().second-q.front().first-max_sub > max_sub)
       max_sub 
= q.front().second-q.front().first-max_sub;
   
if (max_sub < best_size)
       best_size 
= max_sub, center = root;
   
return size;
}


int  dists[MAX_N], len;
void  find( int  root,  int  pred,  int  dist)
{
   v[len] 
= root;
   dists[len
++= dist;
   
int last = len;
   
for (edge *it = e[root]; it; it = it->next)
       
if (it->!= pred && flag[it->v])
       
{
           find(it
->v, root, dist+it->w);
           
if (pred == -1)
           
{
               q.push(make_pair(last, len));
               ret 
-= count(dists+last, dists+len);
               last 
= len;
           }

       }

}


int  main()
{
    
int m;
    
char str[16];
   scanf(
"%d%d"&n, &m);
   
{
       it 
= data;
       memset(e, 
0sizeof(e[0])*n);
       
for (int i = n; --i; )
       
{
           
int u, v, w;
           scanf(
"%d%d%d%s"&u, &v, &w, str);
           
--u; --v;
           insert(u, v, w);
       }

       scanf(
"%d"&k);

       ret 
= 0;
       
for (int i = 0; i < n; ++i)
           v[i] 
= i;
       
for (q.push(make_pair(0, n)); !q.empty(); q.pop())
       
{
           
if (q.front().first == q.front().second-1continue;
           
for (int i = q.front().first; i < q.front().second; ++i)
               flag[v[i]] 
= true;

           best_size 
= numeric_limits<int>::max();
           centerOfGravity(v[q.front().first], 
-1);

           len 
= q.front().first;
           find(center, 
-10);
           ret 
+= count(dists+q.front().first, dists+q.front().second);

           
for (int i = q.front().first; i < q.front().second; ++i)
               flag[v[i]] 
= false;
       }

       printf(
"%d\n", ret);
   }

}

你可能感兴趣的:(POJ 1987 Distance Statistics 牛题 树的分治)