树形dp问题分类【习题集】

树形 d p dp dp问题主要有如下两种:
第一种:简单的单向dp即可求出答案,比如从根部向下递归做dp或者从叶子往上做dp
第二种:双向dp即换根dp,需要做两次dfs一次向下一次向上
下面直接上例题

F. Tree with Maximum Cost

time limit per test:2 seconds
memory limit per test:256 megabytes
input standard:input
output standard:output

You are given a tree consisting exactly of n vertices. Tree is a connected undirected graph with n−1 edges. Each vertex v of this tree has a value av assigned to it.

Let dist(x,y) be the distance between the vertices x and y. The distance between the vertices is the number of edges on the simple path between them.

Let’s define the cost of the tree as the following value: firstly, let’s fix some vertex of the tree. Let it be v. Then the cost of the tree is ∑ i = 1 n d i s t ( i , v ) ∗ a i \sum_{i=1}^{n}{dist(i,v)*a_i} i=1ndist(i,v)ai

Your task is to calculate the maximum possible cost of the tree if you can choose v arbitrarily.

Input
The first line contains one integer n, the number of vertices in the tree ( 1 ≤ n ≤ 2 ⋅ 1 0 5 ) (1≤n≤2⋅10^5) (1n2105).

The second line of the input contains n integers a 1 , a 2 , … , a n a_1,a_2,…,a_n a1,a2,,an ( 1 ≤ a i ≤ 2 ⋅ 1 0 5 ) (1≤a_i≤2⋅10^5) (1ai2105), where ai is the value of the vertex i i i.

Each of the next n−1 lines describes an edge of the tree. Edge i is denoted by two integers ui and vi, the labels of vertices it connects ( 1 ≤ u i , v i ≤ n , u i   ! = v i 1≤u_i,v_i≤n, u_i\ !=v_i 1ui,vin,ui !=vi).

It is guaranteed that the given edges form a tree.

Output

Print one integer — the maximum possible cost of the tree if you can choose any vertex as v.

Examples

input

8
9 4 1 7 10 1 6 5
1 2
2 3
1 4
1 5
5 6
5 7
5 8

output

121

input

1
1337

output

0

Note

Picture corresponding to the first example:

You can choose the vertex 3 as a root, then the answer will be 2⋅9+1⋅4+0⋅1+3⋅7+3⋅10+4⋅1+4⋅6+4⋅5=18+4+0+21+30+4+24+20=121.

In the second example tree consists only of one vertex so the answer is always 0.

  • 题意

    • 给你一棵树,每个节点都有一个权值,让你求
      m a x { ∑ i = 1 n d i s ( c u r , i ) ∗ v a l u e [ i ] } max\{\sum_{i=1}^{n}dis(cur,i)*value[i]\} max{i=1ndis(cur,i)value[i]}
  • 题解

    • 显然暴力 O ( n 2 ) O(n^2) O(n2)是一定不行的,所以考虑从根节点往下dp,由于根节点的选取对答案无影响,这里选取1作为根节点,首先说明几个数组的含义:
      s [ c u r ] = ∑ i d i s ( i , c u r ) ∗ a [ i ] s[cur]=\sum_{}^{i}dis(i,cur)*a[i] s[cur]=idis(i,cur)a[i] s u m [ c u r ] = ∑ i a [ i ] sum[cur]=\sum_{}^{i}a[i] sum[cur]=ia[i] d p [ c u r ] = ∑ i = 1 n d i s ( i , c u r ) ∗ a [ i ] dp[cur]=\sum_{i=1}^{n}{dis(i,cur)*a[i]} dp[cur]=i=1ndis(i,cur)a[i]
    • 其中 i i i是以 c u r cur cur为根的子树中的所有节点
    • 给出一个例图:
      树形dp问题分类【习题集】_第1张图片
      • 节点0只是自己添加的作为1的父亲节点,比如当前从1到6转移,那么如果 ∑ i = 1 n d i s ( i , 1 ) ∗ a [ i ] \sum_{i=1}^{n}dis(i,1)*a[i] i=1ndis(i,1)a[i]已经计算,假设我们知道了 s [ 6 ] s[6] s[6](定义见上文),那么显然 d p [ 6 ] = ( d p [ 1 ] − s [ 6 ] − s u m [ 6 ] + s u m [ 1 ] − s u m [ 6 ] ) + ( s [ 6 ] ) dp[6]=(dp[1]-s[6]-sum[6]+sum[1]-sum[6])+(s[6]) dp[6]=(dp[1]s[6]sum[6]+sum[1]sum[6])+(s[6])
      • 也就是分别计算以 c u r cur cur为根的子树的贡献和除去这颗子树的部分的贡献,可以发现 s [ c u r ] s[cur] s[cur]消掉了
    • 下面给出转移方程:
      s u m [ c u r ] = v a l u e [ c u r ] + ∑ s u m [ s o n i ] sum[cur]=value[cur]+\sum_{}^{}sum[son_i] sum[cur]=value[cur]+sum[soni] d p [ c u r ] = d p [ f a [ c u r ] ] + s u m [ 1 ] − 2 ∗ s u m [ c u r ] dp[cur]=dp[fa[cur]]+sum[1]-2*sum[cur] dp[cur]=dp[fa[cur]]+sum[1]2sum[cur]
  • 附代码:

    #include
    
    using namespace std;
    typedef long long ll;
    const int maxn = 200005;
    int n, u, v, a[maxn], fath[maxn];
    vector<int> vec[maxn];
    ll sum[maxn], s[maxn];
    ll dp[maxn];
    
    ll solve(int cur, int fa)
    {
        dp[cur] = dp[fath[cur]] + sum[1] - 2 * sum[cur];
        ll res = dp[cur];
        for(int i = 0; i < vec[cur].size(); i++)
        {
            if(vec[cur][i] != fa)
            {
                res = max(res, solve(vec[cur][i], cur));
            }
        }
        return res;
    }
    
    void dfs(int cur, int fa)
    {
        sum[cur] = a[cur];
        fath[cur] = fa;
        for(int i = 0; i < vec[cur].size(); i++)
        {
            if(vec[cur][i] != fa)
            {
                dfs(vec[cur][i], cur);
                sum[cur] += sum[vec[cur][i]];
                s[cur] += s[vec[cur][i]] + sum[vec[cur][i]];
            }
        }
    }
    
    int main()
    {
        scanf("%d", &n);
        for(int i = 1; i <= n; i++) scanf("%d", &a[i]);
        for(int i = 1; i < n; i++)
        {
            scanf("%d %d", &u, &v);
            vec[u].push_back(v);
            vec[v].push_back(u);
        }
        dfs(1, 0);
        dp[0] = s[1] + sum[1];
        printf("%lld\n", solve(1, 0));
    }
    

鸽子施工队

Description

鸽克多正在为他的国家规划一张设计图,在设计图上有n个城市(按1,2,3,··,,n标号),城市间共有n - 1条道路,使得任意两个城市间都有路径可以互达。但不幸的是,每条道路的施工方可能会放鸽子,即在实际施工结束后,每条道路有0.5的概率无法投入使用,导致可能存在一些城市没有路径可以互达定义。每座城市的交通指数为从该城市出发,分别到其他可达的城市的最短路径长度之和(路径长度为经过的道路数量)。对于给出的一张设计图,鸽克多想知道按该设计图施工后所有城市的交通指数之和的期望。

Input

有多组输入数据,第一行一个整数T表示数据组数,对于每组数据,有n行,每组数据第一行一个整数n表示城市数,接下来n − 1行,每行两个整数u(1 ⩽ u ⩽ n),v(1 ⩽ v ⩽ n),表示在该设计图上城市u和城市v之间有一条道路。1 ⩽ n ⩽ 1 0 5 10^5 105。保证对于所有数据 ∑ n \sum{n} n ⩽ 2 × 1 0 6 10^6 106

Output

对于每组输入数据,输出一行一个整数,表示按该设计图施工后所有城市的交通指数之和的期望,注意答案可以表示为一个分数 P Q \frac{P}{Q} QP, P 、 Q P、Q PQ互质,且 Q ≠ 0 m o d Q \neq 0 mod Q=0mod ( 1 0 9 + 7 ) (10^9 + 7) (109+7),你只需要输出 P × Q − 1 m o d ( 1 0 9 + 7 ) P × Q^{-1} mod (10^{9} + 7) P×Q1mod(109+7)。( Q − 1 Q^{−1} Q1 Q Q Q对模 1 0 9 + 7 10^9 + 7 109+7的逆元,即 Q × Q − 1 m o d ( 1 0 9 + 7 ) = 1 Q × Q^{-1} mod (10^9 + 7) = 1 Q×Q1mod(109+7)=1

Sample Input 1

4
1
2
1 2
3
1 2
2 3
5
1 2
1 3
2 4
2 5

Sample Output 1

0
1
3
500000013

Hint

输入文件较大,不建议使用cin, cout,推荐使用 scanf, printf

  • 题意:

    • 给你一棵树,让你求这个式子:
      ∑ i = 1 n ∑ j = 1 n d i s ( i , j ) × ( 1 2 ) d i s ( i , j ) \sum_{i=1}^{n}\sum_{j=1}^{n}{dis(i,j) \times(\frac{1}{2})^{dis(i,j)}} i=1nj=1ndis(i,j)×(21)dis(i,j)
  • 题解

    • 给出几个数组的定义:
      s [ c u r ] = ∑ i d i s ( i , c u r ) × ( 1 2 ) d i s ( i , c u r ) s[cur]=\sum_{}^{i}{dis(i,cur) \times(\frac{1}{2})^{dis(i,cur)}} s[cur]=idis(i,cur)×(21)dis(i,cur) s u m [ c u r ] = ∑ i ( 1 2 ) d i s ( i , c u r ) sum[cur]=\sum_{}^{i}{(\frac{1}{2})^{dis(i,cur)}} sum[cur]=i(21)dis(i,cur)

    • 其中 i i i c u r cur cur的子孙节点,即不包括 c u r cur cur
      d p _ s [ c u r ] = ∑ i d i s ( i , c u r ) × ( 1 2 ) d i s ( i , c u r ) dp\_s[cur]=\sum_{}^{i}{dis(i,cur) \times(\frac{1}{2})^{dis(i,cur)}} dp_s[cur]=idis(i,cur)×(21)dis(i,cur) d p _ s u m [ c u r ] = ∑ i ( 1 2 ) d i s ( i , c u r ) dp\_sum[cur]=\sum_{}^{i}{(\frac{1}{2})^{dis(i,cur)}} dp_sum[cur]=i(21)dis(i,cur)

    • 其中 i i i表示除去以 c u r cur cur为根的子树中的节点的所有节点

    • 类似上面的思想,读者可以自己手推一下,这里直接给出递推方程:
      s u m [ c u r ] = ∑ ( 1 2 s u m [ s o n i ] + 1 2 ) sum[cur]=\sum_{}^{}{(\frac{1}{2}sum[son_i]+\frac{1}{2})} sum[cur]=(21sum[soni]+21) s [ c u r ] = s u m [ c u r ] + ∑ 1 2 s [ s o n i ] s[cur]=sum[cur]+\sum_{}^{}{\frac{1}{2}s[son_i]} s[cur]=sum[cur]+21s[soni] d p _ s [ c u r ] = 1 2 ( d p _ s [ f a [ c u r ] ] + d p _ s u m [ f a [ c u r ] ] + 1 ) + 1 2 ( s [ f a [ c u r ] ] dp\_s[cur]=\frac{1}{2}(dp\_s[fa[cur]]+dp\_sum[fa[cur]]+1)+\frac{1}{2}(s[fa[cur]] dp_s[cur]=21(dp_s[fa[cur]]+dp_sum[fa[cur]]+1)+21(s[fa[cur]] − 1 2 ( s [ c u r ] + s u m [ c u r ] + 1 ) ) + 1 2 ( s u m [ f a [ c u r ] ] − 1 2 s u m [ c u r ] − 1 2 ) -\frac{1}{2}(s[cur]+sum[cur]+1))+\frac{1}{2}(sum[fa[cur]]-\frac{1}{2}sum[cur]-\frac{1}{2}) 21(s[cur]+sum[cur]+1))+21(sum[fa[cur]]21sum[cur]21) d p _ s u m [ c u r ] = 1 2 d p _ s u m [ f a [ c u r ] ] + 1 2 + 1 2 ( s u m [ f a [ c u r ] ] − 1 2 s u m [ c u r ] − 1 2 ) dp\_sum[cur]=\frac{1}{2}dp\_sum[fa[cur]]+\frac{1}{2}+\frac{1}{2}(sum[fa[cur]]-\frac{1}{2}sum[cur]-\frac{1}{2}) dp_sum[cur]=21dp_sum[fa[cur]]+21+21(sum[fa[cur]]21sum[cur]21)

    • 显然节点 c u r cur cur的结果为 d p _ s [ c u r ] + s [ c u r ] dp\_s[cur]+s[cur] dp_s[cur]+s[cur],最后将所有的相加即可,复杂度 O ( n ) O(n) O(n)

  • 附代码:

    #include
    
    using namespace std;
    typedef long long ll;
    const int maxn = 100005;
    const ll mod = 1e9 + 7;
    
    int t, n, u, v, fath[maxn];
    vector<int> vec[maxn];
    ll sum[maxn], s[maxn], dp_s[maxn], dp_sum[maxn], inv2;
    
    ll quick_pow(ll a, ll b)
    {
        ll res = 1ll;
        while(b)
        {
            if(b & 1) res = res * a % mod;
            a = a * a % mod;
            b >>= 1;
        }
        return res;
    }
    
    ll solve(int cur, int fa)
    {
        ll up = fa == 0 ? 0ll : ((dp_s[fath[cur]] + dp_sum[fath[cur]] + 1ll) % mod) * inv2 % mod;
        ll down = (((s[fath[cur]] - ((s[cur] + sum[cur] + 1) * inv2 % mod)) % mod + mod) % mod) * inv2 % mod;
        ll rest = (((sum[fath[cur]] - (inv2 * sum[cur] % mod) - inv2) % mod + mod) % mod) * inv2 % mod;
    
        dp_s[cur] = (up + down + rest) % mod;
        dp_sum[cur] = fa == 0 ? 0ll : ((dp_sum[fath[cur]] * inv2 % mod + inv2 + ((sum[fath[cur]] - (sum[cur] * inv2 % mod) - inv2) * inv2 % mod)) % mod + mod) % mod;
    
        ll res = (dp_s[cur] + s[cur]) % mod;
    
        for(int i = 0; i < vec[cur].size(); i++)
        {
            if(vec[cur][i] != fa)
            {
                res = (res + solve(vec[cur][i], cur)) % mod;
            }
        }
        return res;
    }
    
    void dfs(int cur, int fa)
    {
        fath[cur] = fa;
        for(int i = 0; i < vec[cur].size(); i++)
        {
            if(vec[cur][i] != fa)
            {
                dfs(vec[cur][i], cur);
                sum[cur] = (sum[cur] + (sum[vec[cur][i]] * inv2 % mod) + inv2) % mod;
                s[cur] = (s[cur] + ((inv2 * s[vec[cur][i]]) % mod)) % mod;
            }
        }
        s[cur] = (s[cur] + sum[cur]) % mod;
    }
    
    
    void init()
    {
        for(int i = 1; i <= n; i++) vec[i].clear();
        memset(sum, 0, sizeof(sum));
        memset(s, 0, sizeof(s));
        memset(dp_s, 0, sizeof(dp_s));
        memset(dp_sum, 0, sizeof(dp_sum));
    }
    
    int main()
    {
        scanf("%d", &t);
        inv2 = quick_pow(2ll, mod - 2);
        while(t--)
        {
            scanf("%d", &n);
            init();
            for(int i = 1; i < n; i++)
            {
                scanf("%d %d", &u, &v);
                vec[u].emplace_back(v);
                vec[v].emplace_back(u);
            }
            dfs(1, 0);
            s[0] = (s[1] + sum[1] + 1ll) * inv2 % mod;
            sum[0] = (sum[1] * inv2 % mod + inv2) % mod;
            printf("%lld\n", solve(1, 0));
        }
    }
    

D. 0-1-Tree

time limit per test 2 seconds
memory limit per test 256 megabytes
input standard input
output standard output
You are given a tree (an undirected connected acyclic graph) consisting of n vertices and n−1 edges. A number is written on each edge, each number is either 0 (let's call such edges 0-edges) or 1 (those are 1-edges).

Let’s call an ordered pair of vertices (x,y) (x≠y) valid if, while traversing the simple path from x to y, we never go through a 0-edge after going through a 1-edge. Your task is to calculate the number of valid pairs in the tree.

Input

The first line contains one integer n (2 ≤ ≤ n ≤ ≤ 200000) — the number of vertices in the tree.

Then n−1 lines follow, each denoting an edge of the tree. Each edge is represented by three integers x i , y i x_i, y_i xi,yi and c i c_i ci ( 1 ≤ x i , y i ≤ n , 0 ≤ c i ≤ 1 , x i   ! = y i ) (1 ≤ x_i,y_i≤ n, 0 ≤ c_i ≤ 1, x_i\ !=y_i) (1xi,yin,0ci1,xi !=yi) — the vertices connected by this edge and the number written on it, respectively.

It is guaranteed that the given edges form a tree.

Output

Print one integer — the number of valid pairs of vertices.

Example

input

7
2 1 1
3 2 0
4 2 1
5 2 0
6 7 1
7 2 1

output

34

Note

The picture corresponding to the first example:
树形dp问题分类【习题集】_第2张图片

  • 题意

    • 让你求有多少点对 ( x , y ) (x,y) (x,y)满足从x到y的路径上没有经过边权为1然后紧接着经过边权为0的情况
  • 题解

    • 先DFS一遍求以每个点为根的子树中有多少点到根的路径上又多少满足上述条件的点,然后再从根部往下DFS统计以每一个点作为终点(y)的点对数量
  • 附代码:

    #include
    
    using namespace std;
    const int maxn = 200005;
    typedef long long ll;
    
    struct node
    {
        int to, val;
    };
    
    int n, u, v, w;
    vector<node> vec[maxn];
    int dp[maxn][2];
    
    ll solve(int cur, int fa, int v)
    {
        ll res = dp[cur][0] + dp[cur][1], sum0 = 0, sum1 = 0;
        for(int i = 0; i < vec[cur].size(); i++){
            auto sun = vec[cur][i];
            if(sun.to != fa){
                if(sun.val) sum1 += dp[sun.to][0] + 1;
                else{
                    sum0+=dp[sun.to][0]+1;
                    for(int j = 0; j < vec[sun.to].size(); j++){
                        auto son = vec[sun.to][j];
                        if(son.to != cur && son.val) sum0 -= dp[son.to][0] + 1;
                    }
                }
            }
        }
    
        for(int i = 0; i < vec[cur].size(); i++){
            auto sun = vec[cur][i];
            if(sun.to != fa){
                if(sun.val) dp[sun.to][1] = sum0 + sum1 - (dp[sun.to][0] + 1) + dp[cur][1] + 1;
                else{
                    if(!v) dp[sun.to][1] += dp[cur][1];
                 	ll tot=dp[sun.to][0]+1;
                    for(int j = 0; j < vec[sun.to].size(); j++){
                        auto son = vec[sun.to][j];
                        if(son.to != cur && son.val) tot -= dp[son.to][0] + 1;
                    }
    
                    dp[sun.to][1] += sum0 - tot + 1;
                }
                res += solve(sun.to, cur, sun.val);
            }
    
        }
    
    
        return res;
    
    }
    
    void dfs(int cur, int fa)
    {
        for(int i = 0; i < vec[cur].size(); i++){
            if(vec[cur][i].to != fa){
                dfs(vec[cur][i].to, cur);
            }
        }
        for(int i = 0; i < vec[cur].size(); i++){
            auto sun = vec[cur][i];
            if(sun.to != fa){
                if(sun.val) dp[cur][0] += dp[sun.to][0] + 1;
                else{
                    dp[cur][0] += dp[sun.to][0]+1;
                    for(int j = 0; j < vec[sun.to].size(); j++){
                        auto grand = vec[sun.to][j];
                        if(grand.to != cur && grand.val) dp[cur][0] -= dp[grand.to][0] + 1;
                    }
                }
            }
        }
    }
    
    int main()
    {
        scanf("%d", &n);
        for(int i = 1; i < n; i++){
            scanf("%d %d %d", &u, &v, &w);
            vec[u].push_back(node{v, w});
            vec[v].push_back(node{u, w});
        }
        dfs(1, 0);
        printf("%lld\n", solve(1, 0, 0));
    }
    

A. The Fair Nut and the Best Path

time limit per test3 seconds
memory limit per test256 megabytes
inputstandard input
outputstandard output

The Fair Nut is going to travel to the Tree Country, in which there are ? cities. Most of the land of this country is covered by forest. Furthermore, the local road system forms a tree (connected graph without cycles). Nut wants to rent a car in the city ? and go by a simple path to city ?. He hasn’t determined the path, so it’s time to do it. Note that chosen path can consist of only one vertex.

A filling station is located in every city. Because of strange law, Nut can buy only ?? liters of gasoline in the ?-th city. We can assume, that he has infinite money. Each road has a length, and as soon as Nut drives through this road, the amount of gasoline decreases by length. Of course, Nut can’t choose a path, which consists of roads, where he runs out of gasoline. He can buy gasoline in every visited city, even in the first and the last.

He also wants to find the maximum amount of gasoline that he can have at the end of the path. Help him: count it.

Input

The first line contains a single integer n ( 1 ≤ n ≤ 3 ⋅ 1 0 5 ) n (1≤n≤3⋅10^5) n(1n3105) — the number of cities.

The second line contains ? integers w 1 , w 2 , … , w n ( 0 ≤ w i ≤ 109 ) w_1,w_2,…,w_n (0≤w_i≤109) w1,w2,,wn(0wi109) — the maximum amounts of liters of gasoline that Nut can buy in cities.

Each of the next ?−1 lines describes road and contains three integers u , v , v ( 1 ≤ u , v ≤ n , 1 ≤ c ≤ 1 0 9 , u   ! = v ) u, v, v (1≤u,v≤n, 1≤c≤10^9, u\ !=v) u,v,v(1u,vn,1c109,u !=v), where ? and ? — cities that are connected by this road and ? — its length.

It is guaranteed that graph of road connectivity is a tree.

Output

Print one number — the maximum amount of gasoline that he can have at the end of the path.

Examples

input
3
1 3 3
1 2 2
1 3 2
output
3
input
5
6 3 2 5 0
1 2 10
2 3 3
2 4 1
1 5 1
output
7

Note

The optimal way in the first example is 2→1→3.
The optimal way in the second example is 2→4.

  • 题意:

    • 给你一棵树,求经过每条边都有一定的油量花费,每个节点都可以加一定的油,问:找出两个点 u , v u,v u,v使得从 u u u出发到 v v v后剩余的油量最大(u可以等于v),求最大值
  • 解法:

    • 首先需要明确的是从 u u u v v v后剩余的油量与从 v v v u u u后剩余的油量相同
    • d f s dfs dfs一遍求以 c u r cur cur为根的子树中的所有节点到cur的油量剩余最大值,记为 d p [ 0 ] [ c u r ] dp[0][cur] dp[0][cur],然后再 d f s dfs dfs一次计算答案,第二次 d f s dfs dfs计算的是除去以 c u r cur cur为根的子树的节点到 c u r cur cur的油量剩余最大值,记为 d p [ 1 ] [ c u r ] dp[1][cur] dp[1][cur],注意只有 c u r cur cur一个节点的情况只能统计在一个 d f s dfs dfs中,所以答案就是
      a n s = max ⁡ 1 n d p [ 0 ] [ c u r ] + d p [ 1 ] [ c u r ] ans=\max_{1}^{n}{dp[0][cur]+dp[1][cur]} ans=1maxndp[0][cur]+dp[1][cur]
  • 附代码:

    #include
    
    using namespace std;
    typedef long long ll;
    const int maxn=300005;
    
    struct node{
    		int pos,cost;
    		node(int a=0,int b=0){
    			pos=a;cost=b;
    		}
    };
    vector<node> vec[maxn];
    int val[maxn],n,u,v,w;
    ll dp[2][maxn];
    set<pair<ll,int> >s;
    
    void dfs(int cur,int fa)
    {
    	dp[0][cur]=val[cur];
    	for(int i=0;i<vec[cur].size();i++){
    		auto nxt=vec[cur][i];
    		if(nxt.pos!=fa){
    			dfs(nxt.pos,cur);
    			dp[0][cur]=max(dp[0][cur],dp[0][nxt.pos]+1LL*val[cur]-nxt.cost);
    		}
    	}
    }
    
    ll solve(int cur,int fa)
    {
    	ll res=(dp[0][cur]+dp[1][cur]);
    
    	for(int i=0;i<vec[cur].size();i++){
    		auto nxt=vec[cur][i];
    		if(nxt.pos!=fa){
    			s.insert(make_pair(dp[0][nxt.pos]-nxt.cost+val[cur],nxt.pos));
    		}
    	}
    	for(int i=0;i<vec[cur].size();i++){
    		auto nxt=vec[cur][i];
    		if(nxt.pos!=fa){
    			s.erase(make_pair(dp[0][nxt.pos]-nxt.cost+val[cur],nxt.pos));
    			if(!s.empty()){
    				auto maxx=*s.rbegin();
    				dp[1][nxt.pos]=max(dp[1][nxt.pos],max(maxx.first-nxt.cost,1LL*val[cur]-nxt.cost));
    			}else dp[1][nxt.pos]=max(dp[1][nxt.pos],1LL*val[cur]-nxt.cost);
    			s.insert(make_pair(dp[0][nxt.pos]-nxt.cost+val[cur],nxt.pos));
    		}
    	}
    
    	s.clear();
    
    	for(int i=0;i<vec[cur].size();i++){
    		auto nxt=vec[cur][i];
    		if(nxt.pos!=fa){
    			res=max(res,solve(nxt.pos,cur));
    		}
    	}
    	return res;
    }
    
    int main()
    {
    	scanf("%d",&n);
    	for(int i=1;i<=n;i++) scanf("%d",&val[i]);
    	for(int i=1;i<n;i++){
    		scanf("%d %d %d",&u,&v,&w);
    		vec[u].push_back(node(v,w));
    		vec[v].push_back(node(u,w));
    	}
    	dfs(1,0);
    	printf("%lld\n",solve(1,0));
    
    }
    

你可能感兴趣的:(树上dp)