bzoj3257: 树的难题

题目链接

bzoj3257: 树的难题

题解

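树形 dp。设 dp[x][i][j] 表示只考虑 x 的子树、且不含 x 的连通块都已合法时,x 所在连通块"是否含有黑点"(i = 0/1)、"白点个数"为 j(j = 0、1、2,其中 2 表示不少于两个)所需的最小割边代价。合并儿子 v 时有两种选择:保留边 (x,v),把两边的状态合并;或者花费该边的边权把它割掉,此时要求 v 所在的连通块合法(没有黑点,或白点不超过一个)。答案取根节点所有合法状态的最小值。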

代码

#include <algorithm>
#include <cstdio>
#include <cstring>
/**
 * Reads the next decimal integer from stdin.
 *
 * Every character that is not a digit is skipped; a '-' seen while skipping
 * makes the result negative. The following run of digits is then consumed,
 * together with the single character that terminates it.
 *
 * @return the parsed value, or 0 when the input ends before any digit is found.
 */
inline int read() {
    int x = 0,f = 1;
    // getchar() returns int: keeping it in an int is what makes EOF distinguishable
    // from real characters (a plain char cannot hold it portably).
    int c = getchar();
    while(c < '0' || c > '9') {
        if(c == EOF) return 0; // truncated input: stop instead of spinning forever
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// From this line on every `int` token expands to `long long`, so the globals,
// edge weights and dp values declared below are 64-bit. read() is defined
// before the macro and therefore still returns a plain int.
#define int long long 
// Capacity of the per-vertex arrays declared below (vertices are indexed from 1).
const int maxn = 300007; 
// Number of vertices of the test case currently being processed; set in main.
int n; 
// One directed arc of the linked adjacency lists. Every undirected tree edge is
// stored as two arcs, which is why the pool holds twice maxn entries.
struct node {
    int v;     // vertex this arc leads to
    int next;  // index in edge[] of the next arc leaving the same vertex (0 ends the list)
    int w;     // weight attached to the arc
} edge[maxn << 1];
// head[u]: index in edge[] of the most recently added arc leaving u, 0 when u has none.
// num: number of arcs stored so far; add_edge pre-increments it, so arcs live in edge[1..num].
int head[maxn],num = 0; 
inline void add_edge(int u,int v,int w )  { 
    edge[++ num].v = v;edge[num].next = head[u];head[u] = num; edge[num].w = w; 
} 
// a[i]: colour code of vertex i as read from the input. dfs only distinguishes
// the values 0 and 1 from everything else -- presumably 0 = black and 1 = white,
// which matches the validity rule stated below; confirm against the problem statement.
int a[maxn]; 
// Sentinel (1e17) that dfs and main use for "this state cannot be reached".
#define INF  100000000000000000ll

// A component is acceptable when it has no black vertex or at most one white vertex.
// dp[x][b][w]: cheapest total weight of removed edges inside x's subtree such that
// the component containing x has b = 1 iff it holds a vertex with a == 0, and
// w = number of vertices with a == 1 in it, capped at 2. Components cut off below x
// are only created when they pass the rule above.
int dp[maxn][2][3];// colour counts inside the subtree of x
int tmp[2][3];// note: cutting an edge needs a scratch copy, since both choices read the old dp[x]
/**
 * Tree DP over the subtree rooted at x.
 *
 * On return dp[x][b][w] holds the minimum total weight of removed edges inside
 * x's subtree such that the component still containing x has
 *   b = 1 iff it contains a vertex with a == 0, and
 *   w = number of vertices with a == 1 in it, capped at 2,
 * while every component that was cut off satisfies "b == 0 or w <= 1".
 * States that cannot be reached hold INF.
 *
 * @param x  vertex whose subtree is processed
 * @param fa parent of x in the traversal; 0 (no such vertex) for the root
 *
 * NOTE(review): the recursion depth equals the height of the tree, so a
 * path-shaped input with close to maxn vertices may need an enlarged stack.
 */
void dfs(int x,int fa = 0) {
    // Initialise x's states here with the very sentinel the loops below test
    // against. Relying on a byte-wise memset elsewhere (0x3f3f... != INF) made the
    // INF checks dead code and left correctness to sums that merely did not overflow.
    for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) dp[x][i][j] = INF;
    // A single vertex: no edge removed, its own colour decides the state.
    dp[x][a[x] == 0][a[x] == 1] = 0;
    for(int e = head[x];e;e = edge[e].next) {
        int v = edge[e].v;
        if(v == fa) continue;
        dfs(v,x);
        // Both choices for the edge (x,v) must read the dp[x] from before the
        // merge, so the merged states are collected in tmp first.
        for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) tmp[i][j] = INF;

        for(int i = 0;i <= 1;++ i) for(int j = 0;j <= 2;++ j) {
            if(dp[x][i][j] >= INF) continue; // unreachable state of x
            for(int k = 0;k <= 1;++ k) for(int l = 0;l <= 2;++ l) {
                if(dp[v][k][l] >= INF) continue; // unreachable state of v
                const int both = dp[x][i][j] + dp[v][k][l];
                // Keep the edge: the two components fuse, counters saturate at 1 and 2.
                const int t1 = k + i >= 1 ? 1 : k + i;
                const int t2 = j + l >= 2 ? 2 : j + l;
                tmp[t1][t2] = std::min(tmp[t1][t2],both);
                // Cut the edge: allowed only when v's component is acceptable on
                // its own; x's state is unchanged and the edge weight is paid.
                if(!k || l <= 1) tmp[i][j] = std::min(tmp[i][j],both + edge[e].w);
            }
        }
        std::memcpy(dp[x],tmp,sizeof tmp);
    }
}
/**
 * Entry point. Reads the number of test cases; for each one reads n, the n
 * colour codes and the n - 1 weighted edges, runs the tree DP from vertex 1 and
 * prints the cheapest cost over all acceptable root states.
 *
 * Declared `signed` because the file-wide `#define int long long` would turn
 * `int main` into `long long main`, and omitting the type is not valid C++.
 */
signed main() {
    int T = read();
    while(T --) {
        n = read();
        // Reset only the vertices this test case can touch. Clearing the whole
        // arrays for every case costs time proportional to maxn per case no
        // matter how small n is. Vertex 1 is always reset because dfs(1) runs
        // unconditionally.
        const int last = n < 1 ? 1 : n;
        num = 0;
        for(int i = 1;i <= last;++ i) {
            head[i] = 0;
            // Use the same sentinel that dfs compares against.
            for(int b = 0;b < 2;++ b) for(int c = 0;c < 3;++ c) dp[i][b][c] = INF;
        }
        for(int i = 1;i <= n;++ i) a[i] = read();
        for(int i = 1;i < n;++ i) {
            int u = read();
            int v = read();
            int w = read();
            // The tree is undirected: store both directions.
            add_edge(u,v,w);
            add_edge(v,u,w);
        }
        dfs(1);
        int ans = INF;
        // The root component has to be acceptable too: no vertex with a == 0,
        // or fewer than two vertices with a == 1.
        for(int i = 0;i < 2;++ i)
            for(int j = 0;j < 3;++ j)
                if(!i || j < 2) ans = std::min(ans,dp[1][i][j]);
        printf("%lld\n",static_cast<long long>(ans));
    }
    return 0;
}
 

转载于:https://www.cnblogs.com/sssy/p/9649189.html

你可能感兴趣的:(bzoj3257: 树的难题)