JSOI 2008 最小生成树计数
今天的题目终于良心一点辣
一个套路+模版题。
考虑昨天讲的那几个结论,我们有当我们只保留最小生成树中权值不超过 $ k $ 的边的时候形成的联通块是一定的。
我们可以先拿 kruskal 跑一棵最小生成树,然后我们可以从小到大枚举边权,把所有除开枚举到的边权的边全部加入并且缩点。现在我们就在这个缩点后的点集进行生成树计数就好了。答案就是每种边权算出答案的积。
因为我们知道,连入 $ k $ 边权的边后对于 $ 1 $ 到 $ k - 1 $ 的边加入后的最小生成树的影响是固定的,加入后得到的联通块必然一样。所以加入每个权值实际上是独立的!
但是这题很sb的一点在于,它同种边权数量不超过 10 可以直接暴力。如果是100就得矩阵树定理了。(但是这题模数非质数很烦)
#include "iostream"
#include "algorithm"
#include "cstring"
#include "cstdio"
#include "vector"
using namespace std;
#define P 31011
#define MAXN 106
int n , m;
int power( int a , int x ) {
int cur = a % P , ans = 1;
while( x ) {
if( x & 1 ) ans = ans * cur % P;
cur = cur * cur % P , x >>= 1;
}
return ans;
}
struct ed {
int u , v , w;
} E[1006] , T[MAXN] ; int cn = 0;
bool cmp( ed a , ed b ) { return a.w < b.w; }
int fa[MAXN];
vector w;
int find( int x ) { return x == fa[x] ? x : fa[x] = find( fa[x] ); }
int to[MAXN];
#define pii pair
#define fi first
#define se second
#define mp make_pair
int main() {
cin >> n >> m;
for( int i = 1 ; i <= m ; ++ i )
scanf("%d%d%d",&E[i].u,&E[i].v,&E[i].w);
sort( E + 1 , E + 1 + m , cmp );
for( int i = 1 ; i <= n ; ++ i ) fa[i] = i;
for( int i = 1 ; i <= m ; ++ i ) {
int u = find( E[i].u ) , v = find( E[i].v ) ;
if( u == v ) continue;
fa[u] = v;
T[++ cn] = E[i];
w.push_back( E[i].w );
}
w.erase( unique( w.begin() , w.end() ) , w.end() );
vector eds;
int ans = 1;
for( int a : w ) {
int cur = 0;
int sz = 0;
for( int i = 1 ; i <= n ; ++ i ) fa[i] = i;
for( int i = 1 ; i <= cn ; ++ i ) if( T[i].w != a )
fa[find( T[i].u )] = find( T[i].v );
for( int i = 1 ; i <= n ; ++ i ) if( find( i ) == i ) to[i] = ++ sz;
eds.clear();
for( int i = 1 ; i <= m ; ++ i ) {
int u = find( E[i].u ) , v = find( E[i].v );
if( u == v || E[i].w != a ) continue;
eds.emplace_back( mp( to[u] , to[v] ) );
}
int S = eds.size();
for( int i = 0 ; i < ( 1 << S ) ; ++ i ) {
int ok = 0;
for( int j = 1 ; j <= sz ; ++ j ) fa[j] = j;
for( int j = 0 ; j < S ; ++ j ) if( i & ( 1 << j ) ) {
if( find( eds[j].fi ) != find( eds[j].se ) )
fa[find(eds[j].fi)] = find(eds[j].se);
else { ok = 1; break; }
}
for( int j = 1 ; j <= sz ; ++ j ) if( find( j ) != find( 1 ) ) { ok = 1 ; break; }
if( !ok ) ++ cur;
}
ans = ans * cur % P;
}
cout << ans << endl;
}