拆点成边来建图 + BEST 定理：ABC336G

https://www.luogu.com.cn/problem/AT_abc336_g

考虑一个状态 $(a,b,c,d)$ 要出现 $k$ 次。把序列看成每次在末尾追加 1 个字符，这相当于要从状态 $(a,b,c)$ 走到状态 $(b,c,d)$ 恰好 $k$ 次。因此可以这样建图：以 8 个三元组为点，每个四元组 $(a,b,c,d)$ 对应 $k$ 条从 $(a,b,c)$ 到 $(b,c,d)$ 的边。

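问题转化为求这个图的欧拉路径 / 欧拉回路条数，用 BEST 定理 $\operatorname{ec}(G)=t_w(G)\prod_{v}(\deg^+(v)-1)!$ 计算，其中 $t_w(G)$ 为以 $w$ 为根的内向生成树个数，可用矩阵树定理求出。由于起点和终点都相同的重边之间没有本质区别，算出的答案还需要除以 $\prod X_{a,b,c,d}!$。对于欧拉路径，先补一条从终点连回起点的边，就变成了欧拉回路计数。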

对于欧拉回路，BEST 定理数的是不计旋转的回路，而序列还确定了第一条边是哪一条；每种选择的答案都一样，所以再乘上总边数即可。

#include <bits/stdc++.h>
using namespace std;
// debug(...) prints to stdout only when compiled with -DLOCAL; otherwise it
// expands to a no-op expression.
#ifdef LOCAL
 #define debug(...) fprintf(stdout, ##__VA_ARGS__)
#else
 #define debug(...) void(0)
#endif
// Every `int` below (including function signatures) becomes a 64-bit
// long long, so a product of two residues below mo cannot overflow.
// This is also why the entry point is declared `signed main`.
#define int long long
// Reads one integer from stdin. Non-digit characters before the number are
// skipped; a '-' among them makes the result negative. Returns 0 if EOF is
// reached before any digit instead of looping forever.
inline int read() {
	int x = 0, f = 1;
	int ch = getchar();  // kept as an integer so EOF is distinguishable from a byte
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
// Convenience macros from the author's template.
#define Z(x) (x)*(x)
#define pb push_back
#define fi first
#define se second
//#define M
// Prime modulus for all counting; pw(x, mo - 2) is used as the inverse of x.
#define mo 998244353
// Fixed array dimension for a[][], chu[] and ru[]; vertices use indices 1..8.
#define N 100
// Reduces `a` in place to its canonical residue in [0, mo), including when
// `a` is negative.
void Mod(int &a) {
	a %= mo;
	if (a < 0) a += mo;
}
// Shared global state.
//  n        : number of graph vertices (set to 8 in main).
//  a[][]    : out-degree Laplacian (a[u][u] += out-degree, a[u][v] -= #edges u->v).
//  chu / ru : out-degree / in-degree of each vertex (pinyin 出 / 入).
//  fac[]    : factorials mod mo; prod: product of fac[X] over the 16 inputs.
//  st / ed  : Euler-trail start / end vertex (0 if none); sum: total edge count.
//  m and T appear unused in this file.
int n, m, i, j, k, T, st, ed, sum;
int fac[1000010], a[N][N], cnt, nw, t, prod, chu[N], ru[N], u, v, i1, i2, i3, i4; 

// Computes a^b mod mo by binary exponentiation (expects b >= 0; returns 1 for
// b <= 0). The base is normalised into [0, mo) first. As a result, the return
// value is always the canonical residue, and the first squaring cannot
// overflow even when `a` is negative or >= mo. With b = mo - 2 this is the
// modular inverse of a (Fermat, mo prime); pw(0, mo - 2) returns 0.
int pw(int a, int b) {
	a %= mo;
	if (a < 0) a += mo;
	int ans = 1;
	for (; b > 0; b >>= 1) {
		if (b & 1) ans = ans * a % mo;
		a = a * a % mo;
	}
	return ans;
}

// Returns, in [0, mo), the determinant of the minor of a[][] that keeps only
// vertices v with chu[v] > 0 and deletes the root's row and column.
// - For the out-degree Laplacian built in main, the directed matrix-tree
//   theorem says this counts spanning arborescences oriented toward the root.
//   That is the t_w(G) factor of the BEST theorem, which is the same for
//   every root of an Eulerian graph.
// - The root is the first vertex 1..n with a non-zero diagonal entry.
// - Uses Gauss-Jordan elimination with row swaps; a[][] is overwritten.
// - Reads the globals n, a and chu. Only local loop counters are used, so the
//   shared globals i / j / k / nw are left untouched.
int solve() {
	int root = 1;
	while (root <= n && !a[root][root]) ++root;  // n + 1 if every diagonal is 0

	int det = 1;
	for (int col = 1; col <= n; ++col) {
		if (!chu[col] || col == root) continue;  // absent vertex or deleted root

		// Pivot: first row at or below `col` (never the root) that is non-zero here.
		int piv = col;
		while (piv <= n && (piv == root || a[piv][col] % mo == 0)) ++piv;
		if (piv > n) return 0;  // singular minor: no spanning arborescence
		if (piv != col) {
			for (int c = 1; c <= n; ++c) std::swap(a[col][c], a[piv][c]);
			det = (mo - det) % mo;  // swapping two rows negates the determinant
		}

		// Normalise the pivot row so every later product stays below mo * mo.
		for (int c = 1; c <= n; ++c) Mod(a[col][c]);
		det = det * a[col][col] % mo;
		const int inv = pw(a[col][col], mo - 2);

		// Clear column `col` in every other kept row (Gauss-Jordan). Each pivot
		// is therefore never modified again, and det is the product of pivots.
		for (int r = 1; r <= n; ++r) {
			if (r == col || r == root) continue;
			int f = a[r][col] * inv;
			Mod(f);
			for (int c = 1; c <= n; ++c) {
				if (c == root) continue;  // the root column is deleted anyway
				a[r][c] -= a[col][c] * f;
				Mod(a[r][c]);
			}
		}
	}
	return det;
}

// Turns the arborescence count from solve() into a sequence count:
// - multiplies by the BEST-theorem factor, the product over vertices with
//   edges of (chu[v] - 1)!;
// - divides by prod (the product of X! over the 16 inputs), because parallel
//   edges with the same endpoints are interchangeable.
// Reads the globals n, chu, fac and prod; the result is in [0, mo).
static int count_circuits() {
	int cnt = solve();
	debug(">> %lld\n", cnt);
	for (int node = 1; node <= n; ++node)
		if (chu[node]) cnt = cnt * fac[chu[node] - 1] % mo;
	debug(">>》 %lld\n", cnt);
	return cnt * pw(prod, mo - 2) % mo;
}

// Reads the 16 counts X (input order 0000, 0001, ..., 1111) and builds the
// multigraph on the 8 three-bit states. Prints the number of matching binary
// sequences mod mo.
signed main()
{
	#ifdef LOCAL
	  freopen("in.txt", "r", stdin);
	  freopen("out.txt", "w", stdout);
	#endif
	n = 8;
	fac[0] = 1;
	for (int x = 1; x <= 1000000; ++x) fac[x] = fac[x - 1] * x % mo;

	// Input number `code` has bits i1 i2 i3 i4. It adds X parallel edges from
	// state i1 i2 i3 (vertex (code >> 1) + 1) to state i2 i3 i4 (vertex (code & 7) + 1).
	prod = 1;
	sum = 0;
	for (int code = 0; code < 16; ++code) {
		const int x = read();
		prod = prod * fac[x] % mo;
		sum += x;
		const int from = (code >> 1) + 1, to = (code & 7) + 1;
		a[from][to] -= x;  // off-diagonal part of the out-degree Laplacian
		chu[from] += x;
		ru[to] += x;
	}

	// N = 0: there are no length-4 windows, so all 2^3 length-3 sequences qualify.
	// NOTE(review): the judge's constraints may exclude this case — confirm.
	if (sum == 0) return printf("8"), 0;

	for (int node = 1; node <= 8; ++node) {
		a[node][node] += chu[node];  // diagonal = out-degree (self-loops cancel out)
		const int diff = chu[node] - ru[node];
		if (diff == 0) continue;
		if (diff == 1 && !st) { st = node; continue; }   // the unique trail start
		if (diff == -1 && !ed) { ed = node; continue; }  // the unique trail end
		return printf("0"), 0;  // degrees admit no Euler trail
	}

	if (st && ed) {
		// Add the edge ed -> st. Euler trails st -> ed then correspond one-to-one
		// with Euler circuits of the augmented graph (cut at the new edge).
		++ru[st];
		++chu[ed];
		++a[ed][ed];
		--a[ed][st];
		printf("%lld", count_circuits());
		return 0;
	}

	int used = 0;
	for (int node = 1; node <= 8; ++node)
		if (chu[node]) ++used;
	// A lone vertex can only carry self-loops: the sequence is all 0s or all 1s.
	if (used == 1) return printf("1"), 0;
	// BEST counts circuits up to rotation. A sequence also fixes which of the
	// `sum` pairwise-distinct edges comes first.
	printf("%lld", count_circuits() * sum % mo);
	return 0;
}

你可能感兴趣的:(图论,BEST定理)