hdu6035_Colorful Tree_(树形DP)

#include 
#include 
#include 
#include 
#include 
#include 
#define INF 0x3f3f3f3f
#define rep0(i, n) for (int i = 0; i < n; i++)
#define rep1(i, n) for (int i = 1; i <= n; i++)
#define rep_0(i, n) for (int i = n - 1; i >= 0; i--)
#define rep_1(i, n) for (int i = n; i > 0; i--)
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
#define mem(x, y) memset(x, y, sizeof(x))
#define MAXN 200000 + 10
/**
题目大意
树上每个节点有一个用数表示的颜色 定义一条路径的权值为
路上出现的颜色的种类 计算树中所有路径的权值之和
思路
树上所有路径的权值和 = ∑路径颜色种类 = 每种颜色所经过的路径数
直接计算每种颜色所经过的路径数非常非常困难 换了n种方法 死了一堆脑细胞 
也还是出不来
换个角度,计算每种颜色不经过次颜色的路径数就简单多了
 ans = 颜色数 * 路径总数 - 每种颜色不经过的路径数
*/
using namespace std;
typedef long long LL;
int col[MAXN];
vector g[MAXN];

void addEdge(int u, int v)
{
    g[u].push_back(v);
    g[v].push_back(u);

}
LL ans;    //各子树节点数
int n, dp[MAXN];
LL comb(int n)
{
    return (LL)n * (n - 1) / 2;
}
int dfs(int u, int fa)
{
    int myCol = col[u], num = 0, cnt = 0;
    for (int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i];
        if (v == fa)
            continue;
        int pre = dp[myCol], tmp;
        tmp = dfs(v, u);
        num += tmp;
        cnt += dp[myCol] - pre;
        ans -= comb(tmp - dp[myCol] + pre);


    }
    dp[myCol] += num + 1 - cnt;
    return num + 1;
}
bool book[MAXN];
int cols, colM;
int main()
{
    #ifndef ONLINE_JUDGE
        freopen("in.txt", "r", stdin);
    #endif // ONLINE_JUDGE
    int u, v, kase = 0;
    while (scanf("%d", &n) != EOF)
    {
        mem(book, 0);
        mem(dp, 0);
        ans = 0;
        colM = 0;
        cols = 0;
        for (int i = 1; i <= n; i++)
        {
            scanf("%d", col + i);
            colM = MAX(colM, col[i]);
            if (book[col[i]] == false)
            {
                book[col[i]] = true;
                cols++;
            }
            g[i].clear();
        }

        for (int i = 0; i < n - 1; i++)
        {
            scanf("%d %d", &u, &v);
            addEdge(u, v);
        }
        ans += (LL)cols * comb(n);

        //cout << cols << endl;
        dfs(1, 0);
        for (int i = 1; i <= colM; i++)
            if (book[i])
                ans -= comb(n - dp[i]);
        //cout << ans << endl;
        //dfs1(1, 0);
        printf("Case #%d: %lld\n", ++kase, ans);
    }




    return 0;
}
/*
10
1 2 1 4 2 4 7 3 7 6
1 2
3 1
4 2
5 2
6 1
7 5
8 6
9 3
10 4
8
1 2 4 3 2 1 4 3
1 2
1 3
2 4
2 5
3 6
7 4
8 3
2
1 1
1 2

12
1 1 1 1 1 1 1 1 1 1 1 1
1 2
1 3
1 4
2 5
2 6
3 7
4 8
4 9
6 10
6 11
8 12
4
1 2 2 3
1 2
3 1
4 2
*/

你可能感兴趣的:(DP)