题目链接:
点击打开链接
题目大意:
给出一棵树,给出树上的一些链,每个链有一个权,问在链之间不相交的情况下,能够得到最大的权值之和。
题目分析:
首先我们要做一个预处理,利用离线的Lca求出每条链的两个端点的最近公共祖先,同时利用时间戳标记我们到达某个点的时间,和扫描完整棵子树,回到的这个点的时间。预处理做完之后我们进行动态规划。(在做动态的规划的过程中说明预处理的意义)
我们定义两个数组:sum[MAX]和d[MAX]
sum[MAX]表示当前点不被子树中的任何一条链经过的最优解,且当前选中的链只在子树中
d[MAX]表示当前点最优解,在当前选中的链只在子树中的情况下
那么从根开始,那么有两种情况,选择根或者根本用不到根,如果不用根,因为他的每个儿子所在的子树的d[i]中的链都只在当前子树中,所以sum[u] = sigma(d[k]),k是u的儿子
如果选择当前的点,那么当前点所在的链一定会选择一个,那么利用预处理出的公共祖先,我们可以将i点作为公共祖先的链加入到一个集合I中,那么我们枚举这些链,然后对于每条链,我们要去掉之前选择的矛盾的链,那么首先我们加上sun[u],得到不去掉其他链的结果,那么对于当前链上的一个节点,它的d[i]一定是选择i为根的子树当中的链,所以就减去d[i],然后加上sum[i],也就是将当前的情况修正为不选则i的最优解,最后再加上要添加的链的权值,所以我们就得到
dp[u] = max ( sum[u] , sigma ( sum[k](k在链上 )-dp[k](k在链上) )
对于这个求和操作我们如果直接做的话,很浪费时间
那么我们应该如何做呢?
需要做一个预处理,就用到了之前的时间戳了。
首先对于点u,l[u]到r[u]范围的点都在它的子树当中,那么对于点u,我们如果在l[u]位置加上dp[u] , 在r[u]位置加上-dp[u],
那么每次我们查询到的到l[u]的前缀和就是从公共祖先到u的这条链上的dp[i]的和。
为什么呢?
因为首先如果点v不在这条链上,且与u同在lca(u)的子树当中,那么如果点v与u在同一条树链上,但是dep[v]>dep[u],那么因为l[u]是第一次到达u的时间戳,所以不会加上dp[v]的值,那么如果v与u不在同一条链上,那么如果那条链的遍历顺序早于u,那么在扫描到u之前已经回溯到v,所以dp[v]-dp[v]相当于没加,如果遍历顺序大于u,那么再前缀和中v的dp[v]还没有被加,那么得到结论:前缀和就是当前lca(u)到u的dp[i]之和,sum同理
那么这道题就很清楚了
代码如下:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include
#include
#include
#include
#include
#define MAX 200007
using namespace std;
int n,m,t;
typedef long long LL;
LL d[MAX];
LL sum[MAX];
LL c1[MAX<<1];
LL c2[MAX<<1];
int lowbit ( int x )
{
return x&-x;
}
void add1 ( int x , LL v )
{
while ( x <= n )
{
c1[x] += v;
x += lowbit ( x );
}
}
void add2 ( int x , LL v )
{
while ( x <= n )
{
c2[x] += v;
x += lowbit ( x );
}
}
LL sum1 ( int x )
{
LL res = 0;
while ( x )
{
res += c1[x];
x -= lowbit ( x );
}
return res;
}
LL sum2 ( int x )
{
LL res = 0;
while ( x )
{
res += c2[x];
x -= lowbit ( x );
}
return res;
}
typedef pair PII;
vector e[MAX];
vector chain[MAX];
vector a[MAX];
vector w[MAX];
vector val[MAX];
int fa[MAX];
int times;
bool used[MAX];
int l[MAX];
int r[MAX];
int _find ( int x )
{
return fa[x] == x ? x: fa[x] = _find ( fa[x]);
}
void LCA ( int u )
{
fa[u] = u;
l[u] = ++times;
used[u] = true;
for ( int i = 0 ; i < e[u].size() ; i++ )
{
int v = e[u][i];
if ( used[v] ) continue;
LCA ( v );
fa[v] = u;
}
for ( int i = 0 ; i < chain[u].size() ; i++ )
{
int v = chain[u][i];
if ( !used[v] ) continue;
int x = _find ( v );
a[x].push_back ( make_pair ( u , v ));
w[x].push_back ( val[u][i] );
}
r[u] = ++times;
}
void dfs ( int u , int p )
{
sum[u] = 0;
d[u] = 0;
for ( int i = 0 ; i < e[u].size() ; i++ )
{
int v = e[u][i];
if ( v == p ) continue;
dfs ( v , u );
sum[u] += d[v];
}
for ( int i = 0 ; i < a[u].size() ; i++ )
{
int x = a[u][i].first;
int y = a[u][i].second;
LL temp = sum1(l[x]) + sum1(l[y]) + sum[u]
-sum2(l[x]) - sum2(l[y]);
d[u] = max ( temp + w[u][i] , d[u] );
}
d[u] = max ( d[u] , sum[u] );
add1 ( l[u] , sum[u] );
add1 ( r[u] , -sum[u] );
add2 ( l[u] , d[u] );
add2 ( r[u] , -d[u] );
}
void init ( )
{
times = 0;
memset ( c1 , 0 , sizeof ( c1 ) );
memset ( c2 , 0 , sizeof ( c2 ));
memset ( used , 0 , sizeof ( used ));
for ( int i = 0 ; i < MAX ; i++ )
{
e[i].clear();
val[i].clear();
a[i].clear();
w[i].clear();
chain[i].clear();
}
}
int main ( )
{
int u,v,x;
scanf ( "%d" , &t );
while ( t-- )
{
init();
scanf ( "%d%d" , &n , &m );
int nn = n-1;
while ( nn-- )
{
scanf ( "%d%d" , &u , &v );
e[u].push_back ( v );
e[v].push_back ( u );
}
n = n*2;
while ( m-- )
{
scanf ( "%d%d%d" , &u , &v , &x );
chain[u].push_back ( v );
chain[v].push_back ( u );
val[u].push_back ( x );
val[v].push_back ( x );
}
//cout <<"YES" << endl;
LCA ( 1 );
dfs ( 1, -1 );
printf ( "%I64d\n" , d[1] );
}
}