Labeling the Tree with Distances(换根DP,多项式哈希,EDU)

#include 
using namespace std;
#define all(a) (a).begin(), (a).end()
#define ll long long
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const int N = 2e5 + 10;

//https://codeforces.com/problemset/problem/1794/E
/*
换根DP + 多项式哈希
EDU

f(x) = 0 * B^0 + c1 * B^1 + c2 * B^2 + ...
ci * B^i :ci 表示到点 x 距离为 i 的数量

通过换根 DP 可以求出每个点作为根时的 hash值
又因为有一个权值可以任意,所以一共有 n 种哈希值
只需判断每个点为根时的哈希值是否是 n 个中的一个即可
*/

int n, a[N];
ll p, mod = (ll)1e18 + 9, pp[N];
vector g[N];
ll f[N], f2[N];

ll mul(ll x, ll y)
{
    return (__int128)x * y % mod;
}

void dfs1(int x, int fa)
{
    f[x] = 1;
    for(auto &to : g[x])
    {
        if(to == fa) continue;
        dfs1(to, x);
        f[x] = (f[x] + mul(f[to], p)) % mod;
    }
}

void dfs2(int x, int fa)
{
    if(fa == -1) f2[x] = f[x]; 
    else f2[x] = (mul((f2[fa] - mul(f[x], p) + mod) % mod, p) + f[x]) % mod;
    
    for(auto &to : g[x]) if(to != fa) dfs2(to, x);
}

void solve()
{
    cin >> n;
    for(int i = 1; i < n; i++) cin >> a[i];

    for(int i = 1, u, v; i < n; i++)
    {
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());
    p = rng() + 123456;

    pp[0] = 1;
    for(int i = 1; i <= n; i++) pp[i] = mul(pp[i - 1], p);
    
    ll sum = 0;
    for(int i = 1; i < n; i++) sum = (sum + pp[a[i]]) % mod;
    
    vector v(n);
    for(int i = 0; i < n; i++) v[i] = (sum + pp[i]) % mod;
    sort(all(v));

    dfs1(1, -1);
    dfs2(1, -1);

    vector ans;
    for(int i = 1; i <= n; i++) if(binary_search(all(v), f2[i])) ans.push_back(i); 

    cout << ans.size() << '\n';
    for(auto &it : ans) cout << it << ' ';
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int t = 1;
    // cin >> t;
    while (t--)
        solve();
    return 0;
}

你可能感兴趣的:(算法,哈希,换根DP)