LOJ#2433. 「ZJOI2018」线图 题解

隔了好久，才发了这篇 LOJ 的题解。

但是，ZJOI 的题目里为什么总有“九条可怜”这个名字？？？

算了,不生气。

直接看代码吧。（注：本文代码在网页转载时丢失了尖括号 `<...>` 之间的内容，例如 `#include` 后的头文件名和模板参数，个别行也因此被合并、截断，无法直接编译。）

/*
*Made by Ying Youyu.
*/
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
// NOTE(review): the template arguments below were restored after they were
// stripped from the published code (it read e.g. "vector E[5005]", "map mp");
// each element/key type follows from how the container is used in this file.
// NOTE(review): kind, ntd and dp are indexed by shape id and hold ids < 1250;
// confirm this covers every rooted shape get_hash can register, since main()
// enumerates trees with up to kk + 1 vertices.

// Adjacency lists of the input tree, traversed by find().
vector<int> E[5005];
// kind[id]: sorted ids of the child subtrees of rooted shape `id` (set in get_hash).
vector<int> kind[1250];
// Adjacency lists of the small tree currently being hashed / evaluated
// (vertices 1..limit; rebuilt by count() and written by brout()).
vector<int> T[15];
// rnd[id]: random weight of shape `id`; get_hash/get_p hash a node as
// 1 + sum of rnd[child id]. Presumably filled in main() -- not visible here.
int rnd[2005];
// dp[u][id]: tree-DP table filled by find() for input vertex u and shape id.
int dp[5005][1250];
// crt: not referenced in the visible code.
// ntd[id]: weight set by get_hash, the product of inv[m] over each run of m
// equal non-leaf child ids of shape `id`.
int crt[20] , ntd[1250];
// C: not referenced in the visible code.
int C[5005][20];
// inv[m]: modular inverses used by get_hash (inv[run length]) and get_node
// (inv[2] acts as 1/2). Presumably inverse factorials filled in main() -- TODO confirm.
int inv[21];
const int mod = 998244353;
// Modular inverse of 6 (6 * inv6 == 1 mod `mod`); unused in the visible code.
const int inv6 = 166374059;
// n: presumably the vertex count of the input tree (main sums dp[1..n]);
// kk: the k of L^k (main enumerates trees with 2..kk+1 vertices);
// limit: vertex count of the small tree being enumerated;
// cnt: number of registered shape ids (set to tot in main).
// The global `ans` is shadowed by locals in get_node() and main().
int n , kk , ans = 0 , limit , cnt = 0;
// p[i]: parent of vertex i in the tree being enumerated by dfs();
// tot: number of shape ids registered in hashh; e_tot: number of edge ids in `edge`.
int p[20] , tot = 0 , e_tot = 0;
// mp[id]: value computed by count() for the tree shape with id `id`.
map<int, int> mp;
// hashh[h]: shape id assigned to the 64-bit subtree hash h.
map<long long, int> hashh;
// edge[{u, v}] (u <= v): id of an undirected edge, assigned by get_num().
map<pair<int, int>, int> edge;
// State (seed) and constants (a, b) of the custom rand().
int seed = 91478513 , a = 16554871 , b = 35659598;
// Modular exponentiation by repeated squaring: returns a^b mod `mod`.
// The exponent is consumed one bit at a time, least significant bit first.
int power(int a,int b)
{
    long long result = 1 , base = a;
    for(; b; b >>= 1){
        if(b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return (int)result;
}
// Custom pseudo-random generator: updates the global seed to
// ((seed XOR a) * b) mod `mod` and returns the new seed.
// NOTE(review): reuses the name of the C library's rand(); this may clash with
// the <cstdlib> declaration on some toolchains.
inline int rand()
{
    const long long next = 1LL * (seed ^ a) * b % mod;
    seed = (int)next;
    return seed;
}
int get_hash(int fa,int u)
{
    long long h = 1;
    vector hh;
    for(int i = 0;i < T[u].size();i++){
        if(T[u][i] != fa){
            int g = get_hash(u , T[u][i]);
            h += rnd[g];
            hh.push_back(g);
        }
    }
    map::iterator it = hashh.find(h);
    if(it != hashh.end()) return it->second;
    hashh.insert(pair{h , ++tot});
    kind[tot] = hh;ntd[tot] = 1;
    sort(kind[tot].begin() , kind[tot].end());
    int qq = 1 , start = 0;ntd[tot] = 1;
    for(;start < kind[tot].size() && kind[tot][start] == 1;start++);
    for(int i = start + 1;i < kind[tot].size();i++){
        if(kind[tot][i] == 1) continue;
        if(kind[tot][i] == kind[tot][i - 1]) qq++;
        else{
            ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
            qq = 1;
        }
    }
    ntd[tot] = (1LL * ntd[tot] * inv[qq]) % mod;
    return tot;
}
// Read-only variant of get_hash: returns the registered id of the subtree of T
// rooted at u (fa = parent, 0 at the root), or -1 if that shape -- or any of
// its child subtrees -- is not yet in hashh. Never registers new shapes.
int get_p(int fa,int u)
{
    long long h = 1;
    for(int v : T[u]){
        if(v == fa) continue;
        const int g = get_p(u , v);
        if(g == -1) return -1;
        h += rnd[g];
    }
    const auto it = hashh.find(h);
    return it == hashh.end() ? -1 : it->second;
}
// Returns the id of the undirected edge {u, v} (argument order is irrelevant).
// A pair seen for the first time gets id e_tot + 1 and e_tot is incremented,
// so ids are 1, 2, 3, ... in first-seen order. get_node() resets `edge` and e_tot.
inline int get_num(int u,int v)
{
    if(u > v) swap(u , v);
    // One lookup: emplace leaves an existing entry untouched and returns it.
    const auto res = edge.emplace(make_pair(u , v) , e_tot + 1);
    if(res.second) ++e_tot;
    return res.first->second;
}
// Returns |V(L^k(G))| modulo `mod`: the number of vertices of the k-fold line
// graph of the undirected graph G, given as adjacency lists G[1..cnode].
// The closed forms assume G is simple (no loops or multi-edges).
//   k <= 0 : L^0(G) = G, so the answer is cnode.
//   k == 1 : vertices of L(G) are the edges of G.
//   k == 2 : sum over vertices v of C(deg v, 2)  (= edges of L(G)).
//   k == 3 : sum over edges uv of C(d, 2), where d = deg u + deg v - 2 is the
//            degree of uv in L(G).
//   k == 4 : the k == 3 formula applied to L(G), expanded into sums over G.
//   k >= 5 : builds L(G) explicitly (edge ids from get_num) and recurses on k - 1.
// Side effect: clears the global `edge` map and resets e_tot (get_num refills them).
int get_node(int cnode , int k , vector<int> G[])
{
    edge.clear();e_tot = 0;
    // Number of edges of G (each edge appears in two adjacency lists).
    long long q = 0;
    for(int v = 1;v <= cnode;v++) q += G[v].size();
    q /= 2;
    // Base cases for k < 2; previously k - 1 never reached a base case and recursed forever.
    if(k <= 0) return cnode % mod;
    if(k == 1) return (int)(q % mod);
    if(k == 2){
        long long ans = 0;
        for(int v = 1;v <= cnode;v++){
            const long long d = (long long)G[v].size();
            ans = (ans + d * (d - 1)) % mod;
        }
        // Reduce modulo `mod`: the unreduced product overflowed the int return value.
        return (int)(ans * inv[2] % mod);
    }
    if(k == 3){
        long long ans = 0;
        for(int u = 1;u <= cnode;u++){
            for(int v : G[u]){
                if(u > v) continue;
                const long long d = (long long)G[u].size() + (long long)G[v].size() - 2;
                if(d < 2) continue;
                ans = (ans + d * (d - 1) % mod * inv[2]) % mod;
            }
        }
        return (int)ans;
    }
    if(k == 4){
        // Each edge e of G with L(G)-degree d_e contributes d_e(d_e-1)(d_e-2)/2;
        // each pair of edges e, f sharing a vertex contributes (d_e-1)(d_f-1),
        // summed per vertex as (x^2 - y)/2 with x = sum, y = sum of squares.
        vector<vector<int>> ed(cnode + 1);
        long long ans = 0;
        for(int u = 1;u <= cnode;u++){
            for(int v : G[u]){
                if(u > v) continue;
                const int d0 = (int)G[u].size() + (int)G[v].size() - 2;
                ed[u].push_back(d0 - 1);
                ed[v].push_back(d0 - 1);
                if(d0 > 1) ans = (ans + 1LL * d0 * (d0 - 1) % mod * (d0 - 2)) % mod;
            }
        }
        for(int u = 1;u <= cnode;u++){
            long long x = 0 , y = 0;
            for(int w : ed[u]){
                if(w <= 0) continue;
                x = (x + w) % mod;
                // Square in 64 bits; the original squared in int before widening.
                y = (y + 1LL * w * w) % mod;
            }
            ans = (ans + (x * x % mod - y + mod)) % mod;
        }
        return (int)(ans * inv[2] % mod);
    }
    // k >= 5: build L(G) -- one vertex per edge of G, two joined when the edges
    // share an endpoint -- then recurse. Edges whose endpoints both have degree 1
    // get no id; they would be isolated in L(G) and contribute nothing for k - 1 >= 1.
    vector<vector<int>> G2(q + 1);
    for(int v = 1;v <= cnode;v++){
        if(G[v].size() < 2) continue;
        vector<int> pst(G[v].size());
        for(size_t j = 0;j < G[v].size();j++) pst[j] = get_num(v , G[v][j]);
        for(size_t j = 0;j < pst.size();j++){
            for(size_t l = j + 1;l < pst.size();l++){
                G2[pst[j]].push_back(pst[l]);
                G2[pst[l]].push_back(pst[j]);
            }
        }
    }
    return get_node(e_tot , k - 1 , G2.data());
}
// connect[mask]: per-subset flag used by brout(); bit j-1 of `mask` stands for
// vertex j of the current small tree. Set to true in brout (presumably marking
// connected vertex subsets -- the surrounding code is partly lost, see below).
// NOTE(review): holds masks over at most 10 vertices; confirm this is enough for
// the largest `limit` (main() enumerates limit up to kk + 1).
bool connect[(1<<10) + 1];
// neigh[j]: OR-ed into `mask` in brout when vertex j qualifies; presumably the
// bitmask of vertex j's neighbours (where it is filled is not visible).
int neigh[20];
// Inclusion-exclusion helper for count(). G is a private copy of the current
// small tree (vertices 1..limit), passed in because this function writes to the
// global T. For each vertex subset i that passes its filter, it pushes the edges
// of the subgraph induced by i into T, hashes T from the subset's first vertex pp
// (get_hash(0, pp)) and adds that shape's stored value mp[...]. The sum is
// returned modulo `mod`.
// NOTE(review): this function was corrupted when the code was extracted from a
// web page. Text between '<' and '>' was stripped (a "(1<<..." shift and
// everything up to the next '>' vanished), merging several lines. The subset
// loop bounds, the connectivity filter and the declarations of `mask`, `pp` and
// `ans` are therefore missing; restore them from the original source before compiling.
int brout(vector G[])
{
    // Garbled: outer loop over subset masks `i`, fused with the start of a
    // neighbour-mask computation (text between "(1<" and ">> j-1" was lost).
    for(int i = 0;i < (1<> j-1) & 1) mask |= neigh[j];
        }
        // Garbled: for each vertex j in `mask`, sets connect[i | (1<<(j-1))]; the
        // rest of the line appears to begin a loop that skips vertices not in subset i.
        for(int j = 1;j <= limit;j++){
            if((mask >> j - 1) & 1) {connect[i | (1<> j-1) & 1) == 0) continue;
            // pp = first vertex j reaching this point (presumably the smallest vertex of subset i).
            if(!pp) pp = j;
            // Keep only edges whose other endpoint is also in subset i.
            for(int k = 0;k < G[j].size();k++){
                if((i>>G[j][k]-1) & 1) {T[j].push_back(G[j][k]);}
            }
        }
        // Add the stored value of the induced subtree's shape, rooted at pp.
        ans = (ans + mp[get_hash(0 , pp)]) % mod;
    }
    return ans;
}
// Evaluates the tree currently described by the parent array p[2..limit]
// (vertex i is joined to p[i]) and stores its value in mp under the id of that
// tree rooted at vertex 1:
//     mp[id] = get_node(limit, kk, T) - brout(copy of T)   (mod `mod`).
// Does nothing if the id already has a value. If some re-rooting of the same tree
// (rooted at a vertex i >= 2) already has a value, that value is copied instead.
void count()
{
    // Rebuild the global tree T from the parent array.
    for(int i = 1;i <= limit;i++) T[i].clear();
    for(int i = 2;i <= limit;i++){
        T[i].push_back(p[i]);
        T[p[i]].push_back(i);
    }
    const int h = get_hash(0 , 1);
    if(mp.find(h) != mp.end()) return;
    for(int root = 2;root <= limit;root++){
        const int g = get_p(0 , root);
        if(g == -1) continue;
        const auto known = mp.find(g);
        if(known != mp.end()){
            mp.emplace(h , known->second);
            return;
        }
    }
    // brout() writes to the global T, so hand it its own copy of the tree
    // (index 0 left empty, as before).
    vector<vector<int>> W(limit + 1);
    for(int i = 1;i <= limit;i++) W[i] = T[i];
    const int t = get_node(limit , kk , T);
    const int g = brout(W.data());
    // Take the difference in 64 bits and normalise it into [0, mod), so an
    // unreduced t cannot overflow `t + mod`.
    long long value = ((long long)t - g) % mod;
    if(value < 0) value += mod;
    mp.emplace(h , (int)value);
}
// Enumerates parent arrays: with x vertices already placed, vertex x + 1 is
// attached to each earlier vertex 1..x in turn. Starting from dfs(1), this visits
// every tree on vertices 1..limit in which each vertex's parent has a smaller
// label, and calls count() once for each.
void dfs(int x)
{
    if(x == limit){
        count();
        return;
    }
    const int next = x + 1;
    for(int parent = 1;parent <= x;parent++){
        p[next] = parent;
        dfs(next);
    }
}
// Tree DP on the input tree E over the subtree of u (fa = parent; main() calls
// find(0, 1) for the root). Children are processed first. Every vertex gets
// dp[u][1] = 1; for a vertex with children, entries dp[u][i] are then assigned
// from a local table cop[siz][...], where siz is the number of children.
// NOTE(review): the code from the declaration of `cop` onward was corrupted when
// extracted from a web page (text between '<' and '>' was stripped). The rest of
// find(), any function that followed it, and the beginning of main() are merged
// into two lines below and cannot be compiled as shown.
void find(int fa,int u)
{
    dp[u][1] = 1;
    // Recurse into the children first (post-order).
    for(int i = 0;i < E[u].size();i++){
        if(E[u][i] != fa) find(u , E[u][i]);
    }
    // A non-root leaf keeps only dp[u][1].
    if(E[u].size() == 1 && fa != 0) return;
    // Number of children: the root has no parent edge to exclude.
    int siz = (fa == 0) ? E[u].size() : E[u].size() - 1;
    // Garbled (see NOTE above). What survives:
    //   - what appears to be a subset DP over the children in `cop`;
    //   - the write-back into dp[u][i], guarded by a comparison against kind[i].size();
    //   - the start of main(), up to what appears to be an insertion of the pair {1, 0}.
    int cop[siz + 1][(1<>p) & 1){
                        cop[g][k] = (cop[g][k] + 1LL * cop[g - 1][k ^ (1<= kind[i].size())  dp[u][i] = (1LL * cop[siz][(1<{1 , 0});
    // main(), continued: for every size limit = 2 .. kk + 1, enumerate all trees
    // on `limit` vertices (dfs) and record their values in mp.
    for(int i = 2;i <= kk + 1;i++){
        limit = i;
        dfs(1);
    }cnt = tot;
    // cnt now holds the number of shape ids registered during enumeration.
    find(0 , 1);
    // Answer = sum over vertices i and shape ids j of dp[i][j] * mp[j], modulo `mod`.
    // NOTE(review): mp[j] uses operator[], which inserts a zero entry for ids without a value.
    int ans = 0;
    for(int i = 1;i <= n;i++){
        for(int j = 1;j <= cnt;j++){
            ans = (ans + 1LL * dp[i][j] * mp[j]) % mod;
        }
    }
    printf("%d\n",ans);
    return 0;
}

 

你可能感兴趣的:(题解)