2015 Multi-University Training Contest 1 - 1009 Annoying problem

 Annoying problem

Problem's Link:  http://acm.hdu.edu.cn/showproblem.php?pid=5296  


 

Mean: 

给你一个有根树和一个节点集合S,初始S为空,有q个操作,每个操作要么从树中选一个结点加入到S中(不删除树中节点),要么从S集合中删除一个结点。你需要从树中选一些边组成集合E,E中的边能够是S中的节点连通。对于每一个操作,输出操作后E中所有边的边权之和。

analyse:

首先是构图,把树看作一个无向图,使用邻接表存图。

处理出从1号结点的dfs序存储起来。

添加点u操作:查找S集合中的点与添加的点dfs序在前面且编号最大的点,以及dfs序在后面且编号最小的点,设这两个点是x,y。

那么增加的花费是:dis[u]-dis[lca[(x,u)] - dis [lca(y,u)] + dis[lca(x,y) ]; 其中dis代表该点到根节点的距离。

 

对于删除点u操作:先把点从集合中删除,然后再计算减少花费,计算公式和增加的计算方法一样。

也是看了题解才撸出来的。

Time complexity: O(N)

 

Source code: 

/*
* this code is made by crazyacking
* Verdict: Accepted
* Submission Date: 2015-07-22-11.22
* Time: 0MS
* Memory: 137KB
*/
#include <queue>
#include <cstdio>
#include <set>
#include <string>
#include <stack>
#include <cmath>
#include <climits>
#include <map>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
#define  LL long long
#define  ULL unsigned long long
#define rep(i,n) for(int i=0;i<n;++i)
using namespace std;

const int N = 100010,D=20;
int st[N], ori[N], dfs_clock;
vector<pair<int, int> > G[N];
int shortcut[D][N], dep[N];
int *fa;
int get_lca( int a, int b )
{
      if( dep[a] < dep[b] )
            swap( a, b );
      for( int i = D - 1; ~i; --i )
      {
            if( dep[a] - dep[b] >> i & 1 )
                  a = shortcut[i][a];
      }
      if( a == b ) return a;
      for( int i = D - 1; ~i; --i )
      {
            if( shortcut[i][a] != shortcut[i][b] )
            {
                  a = shortcut[i][a];
                  b = shortcut[i][b];
            }
      }
      return fa[a];
}

int dis[N];
void dfs( int u, int father )
{
      st[u] = ++dfs_clock;
      ori[dfs_clock] = u;
      vector<pair<int, int> > :: iterator it;
      for( it = G[u].begin(); it != G[u].end(); ++it )
      {
            int v = ( *it ).first;
            int w = ( *it ).second;
            if( v == father )continue;
            fa[v] = u;
            dep[v] = dep[u] + 1;
            dis[v] = dis[u] + w;
            dfs( v, u );
      }
}


void prepare( int n )
{
      for( int j = 1; j < D; ++j )
      {
            rep( i, n )
            {
                  int &res = shortcut[j][i];
                  res = shortcut[j - 1][i];
                  if( ~res ) res = shortcut[j - 1][res];
            }
      }
}


set<int> nodes;
int get_dis( int a, int b )
{
      return dis[a] + dis[b] - 2 * dis[get_lca( a, b )];
}


int add( int u )
{
      if( !nodes.empty() )
      {
            set<int>::iterator low, high;
            low = nodes.lower_bound( st[u] );
            high = low;
            if( low == nodes.end() || low == nodes.begin() )
            {
                  low = nodes.begin();
                  high = nodes.end();
                  high--;
            }
            else low--;
            int x = ori[*low];
            int y = ori[*high];
            int res = get_dis( x, u ) + get_dis( y, u ) - get_dis( x, y );
            return res;
      }
      return 0;
}

int main()
{
      ios_base::sync_with_stdio( false );
      cin.tie( 0 );
      int T, n, q, u, v, w;
      scanf( "%d", &T );
      rep( cas, T )
      {
            scanf( "%d %d", &n, &q );
            rep( i, n ) G[i].clear();
            dfs_clock = 0;
            rep( i, n - 1 )
            {
                  scanf( "%d %d %d", &u, &v, &w );
                  u--;
                  v--;
                  G[u].push_back( make_pair( v, w ) );
                  G[v].push_back( make_pair( u, w ) );
            }
            fa = shortcut[0];
            fa[0] = -1;
            dfs( 0, -1 );
            prepare( n );
            nodes.clear();
            printf( "Case #%d:\n", cas + 1 );
            int ans = 0;
            while( q-- )
            {
                  int op;
                  scanf( "%d %d", &op, &u );
                  u--;
                  if( op == 1 )  // add
                  {
                        if( nodes.find( st[u] ) == nodes.end() )
                        {
                              ans += add( u );
                              nodes.insert( st[u] );
                        }
                  }
                  else   // delete
                  {
                        if( nodes.find( st[u] ) != nodes.end() )
                        {
                              nodes.erase( st[u] );
                              ans -= add( u );
                        }
                  }
                  printf( "%d\n", ans >> 1 );
            }
      }
      return 0;
}
/*

*/

  

你可能感兴趣的:(test)