HDOJ 4670: Cube number on a tree

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=4670


题目大意:

树的每个节点有一个点权,所有的点权都可以被给定的30个质数表示出来。

在树上找合法点对。

合法点对指的是,两点间路径上的所有点(含端点)的点权乘积是立方数的点对。

注意:点对中的两个点可以是相同的,这个坑了我好久,切~


算法:

树的点分治。

每次处理的时候,用一个map保存之前遍历过的子树中的节点到根的路径值,另一个保存当前正在遍历的这棵子树里的路径值。

然后很容易就可以求出以这个根为LCA的合法点对。


PS:这么个题还写了4K,有点儿伤哦。。下回果断搞个FOREACH宏

树分治技能初步get,耶!


代码:

#pragma comment(linker,"/STACK:102400000,102400000")
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<sstream>
#include<cstdlib>
#include<cstring>
#include<string>
#include<climits>
#include<cmath>
#include<queue>
#include<vector>
#include<stack>
#include<set>
#include<map>
#define INF 0x3f3f3f3f
#define eps 1e-8
#define mp make_pair
#define pb push_back
#define st first
#define nd second
using namespace std;

const int MAXN = 50010;
const int MAXM = 30;
typedef long long LL;

int siz[MAXN], maxb[MAXN];
bool vis[MAXN];
vector <int> node;
vector <int> mm[MAXN];
LL tnum[MAXN];
LL prim[MAXM];
LL pow3[MAXM];
map <LL, LL> map1, map2;
typedef map <LL, LL> :: iterator mapit;
int n,k;
LL ans;

inline LL hash(LL x)
{
    LL tmp = 0LL;
    for (int i = 0; i < k; i ++)
    {
        int cot = 0;
        while (x % prim[i] == 0)
        {
            x /= prim[i];
            cot ++;
        }
        tmp += (cot % 3) * pow3[i];
    }
    return tmp;
}

inline LL fadd(LL x, LL y)
{
    LL tmp = 0LL;
    for (int i = 0 ; i < k; i ++)
    {
        LL ret1 = (x / pow3[i]) % 3;
        LL ret2 = (y / pow3[i]) % 3;
        tmp += (ret1 + ret2) % 3 * pow3[i];
    }
    return tmp;
}

inline LL frev(LL x)
{
    LL tmp = 0LL;
    for (int i = 0 ; i < k; i ++)
    {
        LL ret = (x / pow3[i]) % 3;
        tmp += ((3 - ret) %3) * pow3[i];
    }
    return tmp;
}

void pre_dfs(int u, int p)
{
    node.pb(u);
    maxb[u] = 0;
    siz[u] = 1;
    for (int i = 0; i < mm[u].size(); i ++)
    {
        int v = mm[u][i];
        if (v == p || vis[v])
        {
            continue;
        }
        pre_dfs(v, u);
        siz[u] += siz[v];
        maxb[u] = max(maxb[u], siz[v]);
    }
}

void dfs(int u, int p, LL tmp)
{
    mapit it;
    (it = map2.find(tmp)) == map2.end() ?
    map2[tmp] = 1 : (it -> nd) ++;
    for (int i = 0; i < mm[u].size(); i ++)
    {
        int v = mm[u][i];
        if (v == p || vis[v])
        {
            continue;
        }
        dfs(v, u, fadd(tmp, tnum[v]));
    }
}

void cal(int root)
{
    map1[0] = 1LL;
    if(! tnum[root])
    {
        ans ++;
    }
    for (int i = 0; i < mm[root].size(); i ++)
    {
        int v = mm[root][i];
        if (vis[v])
        {
            continue;
        }
        dfs(v, root, tnum[v]);
        for (mapit it1 = map2.begin(); it1 != map2.end(); it1 ++)
        {
            mapit it2 = map1.find(frev(fadd(it1 -> st, tnum[root])));
            if (it2 != map1.end())
            {
                ans += (it1 -> nd) * (it2 -> nd);
            }
        }
        for (mapit it1 = map2.begin(); it1 != map2.end(); it1 ++)
        {
            mapit it2;
            (it2 = map1.find(it1 -> st)) == map1.end() ?
            map1[it1 -> st] = it1 -> nd : (it2 -> nd) += (it1 -> nd);
        }
        map2.clear();
    }
    map1.clear();
}

void solve(int u)
{
    node.clear();
    pre_dfs(u, -1);
    int num = node.size();
    int root, tmp = INT_MAX;
    for (int i = 0; i < num; i ++)
    {
        maxb[node[i]] = max(maxb[node[i]], num - maxb[node[i]] - 1);
        if (tmp > maxb[node[i]])
        {
            tmp = maxb[node[i]];
            root = node[i];
        }
    }
    vis[root] = true;
    cal(root);
    for (int i = 0; i < mm[root].size(); i ++)
    {
        int v = mm[root][i];
        if (!vis[v])
        {
            solve(v);
        }
    }
}

int main()
{
    pow3[0] = 1LL;
    for(int i = 1; i < MAXM; i ++)
    {
        pow3[i] = pow3[i - 1] * 3;
    }
    while (scanf("%d %d", &n, &k) == 2)
    {
        ans = 0LL;
        memset(vis, 0, sizeof(vis));
        for (int i = 0; i < n; i ++)
        {
            mm[i].clear();
        }
        for (int i = 0; i < k; i ++)
        {
            scanf("%I64d", &prim[i]);
        }
        for (int i = 0; i < n; i ++)
        {
            scanf("%I64d", &tnum[i]);
            tnum[i] = hash(tnum[i]);
        }
        for (int i = 1; i < n; i ++)
        {
            int u, v;
            scanf("%d %d", &u, &v);
            u --;
            v --;
            mm[u].push_back(v);
            mm[v].push_back(u);
        }
        solve(0);
        printf("%I64d\n", ans);
    }
    return 0;
}


你可能感兴趣的:(HDOJ 4670: Cube number on a tree)