SPOJ Two Paths(树形dp,最大不相交路径长度乘积)

题目链接:
SPOJ Two Paths
题意:
给一个 n 个节点和 n1 条边的树,求两条不相交(无公共节点)的路径长度乘积最大值?(路径长度就是路径上边的数量)
数据范围: n105
分析:
这道题和Codeforces 633 F The Chocolate Spree是其实一样的。本来以为会好些点,实际上还是写了好久。。。。主要是细节太多了,有的地方数组的定义也不大一样,不多说了。。。

#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
typedef long long ll;
const int MAX_N = 100010;

int n, total;
int head[MAX_N];
ll up[MAX_N], down[MAX_N], best[MAX_N];
ll predown[MAX_N], ppredown[MAX_N], sufdown[MAX_N], ssufdown[MAX_N];
ll prebest[MAX_N], sufbest[MAX_N];
vector<int> child;

struct Edge {
    int to, next;
}edge[MAX_N * 2];

inline void AddEdge(int from, int to)
{
    edge[total].to = to;
    edge[total].next = head[from];
    head[from] = total++;
}
//down[i]:从i到叶子节点的最长路径,ddown:次长路径
//best[i]:i子树内部的最长链,best可以经过根节点
void dfs(int u, int p)
{
    ll Max = 0, MMax = 0;
    int cnt = 0;
    for (int i = head[u]; i != -1; i = edge[i].next) {
        int v = edge[i].to;
        if (v == p) continue;
        dfs(v, u);
        cnt++;
        if (down[v] > Max) {
            MMax = Max;
            Max = down[v];
        } else if (down[v] > MMax) {
            MMax = down[v];
        }
        best[u] = max(best[u], best[v]);
    }
    if (cnt == 0) return;
    down[u] = Max + 1;
    if (cnt > 1) best[u] = max(best[u], Max + MMax + 2); //细节
    else best[u] = max(best[u], Max + MMax + 1);  //细节
    //printf("down[%d] = %lld best[%d] = %lld\n", u, down[u], u, best[u]);
}

void solve()
{
    ll ans = 0;
    queueint, int> > que;
    que.push(make_pair(1, -1));
    while (!que.empty()) {
        pair<int, int> cur = que.front();
        que.pop();
        int u = cur.first, p = cur.second;
        child.clear();
        child.push_back(0);
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int v = edge[i].to;
            if (v == p) continue;
            child.push_back(v);
        }
        int size = child.size();
        // 前缀down最大和次大,前缀best最大
        // predown和ppredown包括了兄弟节点到父亲的路径
        prebest[0] = predown[0] = ppredown[0] = 0;
        for (int i = 1; i < size; ++i) {
            int v = child[i];
            prebest[i] = max(prebest[i - 1], best[v]);

            predown[i] = predown[i - 1], ppredown[i] = ppredown[i - 1];
            if (down[v] + 1 > predown[i]) {
                ppredown[i] = predown[i];
                predown[i] = down[v] + 1;
            } else if (down[v]+ 1 > ppredown[i]) {
                ppredown[i] = down[v] + 1;
            }
        }

        sufdown[size] = ssufdown[size] = sufbest[size] = 0;
        for (int i = size - 1; i >= 1; --i) {
            int v = child[i];
            sufbest[i] = max(sufbest[i + 1], best[v]);

            sufdown[i] = sufdown[i + 1], ssufdown[i] = ssufdown[i + 1];
            if (down[v] + 1 > sufdown[i]) {
                ssufdown[i] = sufdown[i];
                sufdown[i] = down[v] + 1;
            } else if (down[v] + 1 > ssufdown[i]) {
                ssufdown[i] = down[v] + 1;
            }
        }
        //up[i]包含i到i的父亲的路径
        for (int i = 1; i < size; ++i) {
            int v = child[i];
            ll outside = up[u] + max(predown[i - 1], sufdown[i + 1]);
            outside = max(outside, predown[i - 1] + ppredown[i - 1]);
            outside = max(outside, sufdown[i + 1] + ssufdown[i + 1]);
            outside = max(outside, predown[i - 1] + sufdown[i + 1]);
            outside = max(outside, prebest[i - 1]);
            outside = max(outside, sufbest[i + 1]);
            //printf("v = %d outside = %lld best[v] = %lld\n", v, outside, best[v]);
            ans = max(ans, outside * best[v]);
        }
        // predown/sufdown 算上了兄弟节点到父亲节点的路径
        // up[v]包含了v到父亲u的路径
        for (int i = 1; i < size; ++i) {
            int v = child[i];
            up[v] = 1 + max(up[u], max(predown[i - 1], sufdown[i + 1]));
            que.push(make_pair(v, u));
        }
    }
    printf("%lld\n", ans); 
}

int main()
{
    freopen("G.in", "r", stdin);

    while (~scanf("%d", &n)) {
        memset(head, -1, sizeof(head));
        total = 0;
        for (int i = 1; i < n; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            AddEdge(u, v);
            AddEdge(v, u);
        }
        memset(down, 0, sizeof(down));
        memset(best, 0, sizeof(best));
        memset(up, 0, sizeof(up));
        dfs(1, -1);
        solve();
    }
    return 0;
}

你可能感兴趣的:(树形dp)