【SPOJ】1825 Free tour II 点分治

传送门:【SPOJ】1825 Free tour II


题目分析:敲了两遍。。。

本题是论文题,具体见漆子超论文《分治算法在树的路径问题中的应用》。


在以root为根的第 i 棵子树上,我们用G[ i ,j ]表示root的第 i 棵子树的路径上严格有 j 个黑点的路径的最长长度。用F[ i ,j ]表示在root为根的第 i 棵子树的路径上不超过 j 个黑点的路径的最长长度。因为所有子树里包含黑点数最多的路径的包含黑点数X可以O(N)求出,我们按照每棵子树的X从小到大的顺序遍历,这样就能将G和F数组降低一维,以G[ i ]表示当前遍历的子树路径上严格有 i 个黑点的路径的最长长度,以F[ i ]表示在该子树之前所遍历的所有子树的路径上不超过 i 个黑点的路径的最长长度。

G[ i ]可以通过一次dfs求出,而F[ i ]可以在该子树遍历完以后用G[ i ]和F[ i ]的比较以及F[ i ]和F[ i - 1]的比较来更新。

而遍历顺序可以通过求出每个子树的X,然后把子树的遍历顺序按照X从小到大排序(用个结构体能很方便的解决)。

这样每层是大约O(NlogN)的复杂度,而最多只有logN层(参见点分治的复杂度),所以总复杂度大约为O(Nlog^2N)。


第一次AC了感觉掌握的还是不好,于是又敲了一遍,这才敢写文章。但是我感觉掌握的还是很不好,慢慢来,总有一天我会做的更好!



代码如下:


#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;

typedef long long LL ;

#define travel( e , H , u ) for ( Edge* e = H[u] ; e ; e = e -> next )
#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define FOR( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define clr( a , x ) memset ( a , x , sizeof a )
#define cpy( a , x ) memcpy ( a , x , sizeof a )

const int MAXN = 200005 ;
const int MAXE = 400005 ;
const int INF = 0x3f3f3f3f ;

struct Edge {
	int v , c ;
	Edge* next ;
} E[MAXE] , *H[MAXN] , *edge ;

struct Node {
	int v , c ;
	int num ;
	Node () {}
	Node ( int v , int num , int c ) : v ( v ) , num ( num ) , c ( c ) {}
	bool operator < ( const Node& a ) const {
		return num < a.num ;
	}
} T[MAXN] ;

int siz[MAXN] ;
int num[MAXN] ;
int G[MAXN] ;
int F[MAXN] ;
bool vis[MAXN] ;
bool color[MAXN] ;
int root ;
int node_num ;
int n , m , K ;
int ans ;

void clear () {
	ans = 0 ;
	edge = E ;
	num[0] = n ;
	clr ( H , 0 ) ;
	clr ( vis , 0 ) ;
	clr ( color , 0 ) ;
}

void addedge ( int u , int v , int c ) {
	edge -> v = v ;
	edge -> c = c ;
	edge -> next = H[u] ;
	H[u] = edge ++ ;
}

void get_size ( int u , int fa = 0 ) {
	siz[u] = 1 ;
	travel ( e , H , u ) {
		int v = e -> v ;
		if ( !vis[v] && v != fa ) {
			get_size ( v , u ) ;
			siz[u] += siz[v] ;
		}
	}
}

void get_root ( int u , int fa = 0 ) {
	num[u] = 0 ;
	travel ( e , H , u ) {
		int v = e -> v ;
		if ( !vis[v] && v != fa ) {
			get_root ( v , u ) ;
			num[u] = max ( num[u] , siz[v] ) ;
		}
	}
	num[u] = max ( num[u] , node_num - siz[u] ) ;
	if ( num[u] < num[root] ) root = u ;
}

void get_num ( int u , int fa = 0 ) {
	num[u] = color[u] ;
	travel ( e , H , u ) {
		int v = e -> v ;
		if ( !vis[v] && v != fa ) {
			get_num ( v , u ) ;
			num[u] = max ( num[u] , color[u] + num[v] ) ;
		}
	}
}

void get_G ( int u , int fa , int dep , int val ) {
	G[dep] = max ( G[dep] , val ) ;
	travel ( e , H , u ) {
		int v = e -> v ;
		if ( !vis[v] && v != fa ) {
			get_G ( v , u , dep + color[v] , val + e -> c ) ;
		}
	}
}
			
void dfs ( int u ) {
	get_size ( u ) ;//得到树的大小
	node_num = siz[u] ;
	root = 0 ;
	get_root ( u ) ;//求树的重心
	int rt = root , cnt = 0 ;
	vis[rt] = 1 ;//标记,将rt的所有子树分开
	travel ( e , H , rt ) if ( !vis[e -> v] ) dfs ( e -> v ) ;//递归求解
	travel ( e , H , rt ) {
		int v = e -> v ;
		if ( !vis[v] ) {
			get_num ( v ) ;//得到子树内节点数最多的路径的节点数
			T[cnt ++] = Node ( v , num[v] , e -> c ) ;
		}
	}
	sort ( T , T + cnt ) ;//修改访问顺序
	int limit = K - color[rt] ;
	FOR ( i , 0 , T[cnt - 1].num ) F[i] = -INF ;
	rep ( i , 0 , cnt ) {
		FOR ( j , 0 , T[i].num ) G[j] = -INF ;
		get_G ( T[i].v , rt , color[T[i].v] , T[i].c ) ;//dfs得到G[ ]
		if ( i ) {
			FOR ( j , 0 , T[i].num ) {
				if ( j > limit ) break ;
				int tmp = min ( T[i - 1].num , limit - j ) ;//与下面的更新有关,取之前已经更新过的
				if ( F[tmp] == -INF ) break ;
				ans = max ( ans , F[tmp] + G[j] ) ;
			}
		}
		FOR ( j , 0 , T[i].num ) {//更新到T[i].num,减小一点时间复杂度是一点
			if ( j > limit ) break ;
			F[j] = max ( F[j] , G[j] ) ;
			if ( j ) F[j] = max ( F[j] , F[j - 1] ) ;
			if ( j <= limit ) ans = max ( ans , F[j] ) ;
		}
	}
	vis[rt] = 0 ;//取消标记,将子树合并到大树上去
}

void solve () {
	int x , y , c ;
	clear () ;
	while ( m -- ) {
		scanf ( "%d" , &x ) ;
		color[x] = 1 ;
	}
	rep ( i , 1 , n ) {
		scanf ( "%d%d%d" , &x , &y , &c ) ;
		addedge ( x , y , c ) ;
		addedge ( y , x , c ) ;
	}
	dfs ( 1 ) ;
	printf ( "%d\n" , ans ) ;
}

int main () {
	while ( ~scanf ( "%d%d%d" , &n , &K , &m ) ) solve () ;
	return 0 ;
}


你可能感兴趣的:(spoj)