「ZJOI2019」打麻将-DP

Description

链接

Solution

考虑如何判断一个牌的集合是胡的。

  • 首先第一种牌型只需记录有多少中牌出现次数超过 2 2 2

  • 对于第二种牌型,可以用DP解决。设 f 0 / 1 , i , j , k f_{0/1,i,j,k} f0/1,i,j,k已经考虑了前 i i i种大小,表示之前是否预留了对子,之前预留了 j j j i − 1 , i i-1,i i1,i,以及 k k k i i i用于凑顺子(大小相邻的麻将牌)。每次转移只需枚举第 i + 1 i+1 i+1种牌有多少张即可。

可以发现这个 D P DP DP有效状态很少,只有 2092 2092 2092种,所以可以把它当成一个自动机,这样就可以识别所有种类的胡牌了。

然后考虑解决本问题,一个经典套路是把每个集合的贡献为每次摸牌后检查有是否胡了,有则结束,没有则加 1 1 1。也就是设 f ( i ) f(i) f(i)表示摸了 i i i张牌还没胡的集合数,答案为

∑ i = 1 4 n − 13 f ( i ) ⋅ i ! ⋅ ( 4 n − 13 − i ) ! ( 4 n − 13 ) ! \frac{\sum_{i=1}^{4n-13}f(i)\cdot i! \cdot (4n-13-i)!}{(4n-13)!} (4n13)!i=14n13f(i)i!(4n13i)!

f i , j , k f_{i,j,k} fi,j,k表示已经考虑了大小为 1.. i 1..i 1..i的麻将牌,已经摸了 j j j张牌,在自动机上的节点为 k k k。每次转移时枚举下一种牌有多少乘上组合数转移即可。

#include 
using namespace std;

inline void chkmax(int &a, int b) {if (a < b) a = b;}

inline int gi()
{
	char c = getchar();
	while(c < '0' || c > '9') c = getchar();
	int sum = 0;
	while('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
	return sum;
}

typedef long long ll;
const int maxn = 405, maxs = 2105, mod = 998244353;

struct node
{
	int f[3][3];

	node() {memset(f, -1, sizeof(f));}

	int* operator [] (const int a) {return f[a];}
	bool operator < (const node &a) const {
		for (int i = 0; i < 3; ++i)
			for (int j = 0; j < 3; ++j)
				if (f[i][j] != a.f[i][j]) return f[i][j] < a.f[i][j];
		return 0;
	}

	bool operator == (const node &a) const {
		for (int i = 0; i < 3; ++i)
			for (int j = 0; j < 3; ++j)
				if (f[i][j] != a.f[i][j]) return 0;
		return 1;
	}
	
};

struct mahjong
{

	node f, g;
	int cnt;

	mahjong() {f[0][0] = cnt = 0;}
	
	bool operator < (const mahjong &a) const {
		return f == a.f ? (g == a.g ? cnt < a.cnt : g < a.g) : f < a.f;
	}

	bool operator == (const mahjong &a) const {
		return f == a.f && g == a.g && cnt == a.cnt;
	}

	mahjong trans(int x)
	{
		mahjong ans;
		ans.cnt = min(7, cnt + (x >= 2));
		for (int i = 0; i < 3; ++i)
			for (int j = 0; j < 3; ++j) {
				if (~f[i][j]) {
					for (int k = 0; i + j + k <= x && k < 3; ++k)
						chkmax(ans.f[j][k], min(4, f[i][j] + i + ((x - i - j - k) >= 3)));
					if (x >= 2)
						for (int k = 0; i + j + k <= x - 2 && k < 3; ++k)
							chkmax(ans.g[j][k], min(4, f[i][j] + i));
				}
				if (~g[i][j]) {
					for (int k = 0; i + j + k <= x && k < 3; ++k)
						chkmax(ans.g[j][k], min(4, g[i][j] + i + ((x - i - j - k) >= 3)));
				}
			}
		return ans;
	}
	
} ;

map<mahjong, int> Id;
int tot, ch[maxs][5], h[maxs];
int n, s[maxn];
int fac[maxn], ifac[maxn], dp[2][maxn][maxs], *f[maxn], *g[maxn];

inline bool check(mahjong &s)
{
	if (s.cnt >= 7) return 1;
	for (int i = 0; i < 3; ++i)
		for (int j = 0; j < 3; ++j) if (s.g[i][j] >= 4) return 1;
	return 0;
}

int dfs(mahjong s)
{
	if (check(s)) return 0;
	int &t = Id[s];
	if (t) return t;
	t = ++tot;
	for (int i = 0; i <= 4; ++i)
		ch[t][i] = dfs(s.trans(i));
	return t;
}

inline int C(int n, int m)
{
	return (ll)fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}

inline int A(int n, int m)
{
	return (ll)fac[n] * ifac[n - m] % mod;
}

int main()
{
	dfs(mahjong());
	
	n = gi();
	for (int w, i = 1; i <= 13; ++i) w = gi(), gi(), ++s[w];

	fac[0] = 1; ifac[0] = ifac[1] = 1;
	for (int i = 1; i <= (n << 2); ++i) fac[i] = (ll)fac[i - 1] * i % mod;
	for (int i = 2; i <= (n << 2); ++i) ifac[i] = (ll)(mod - mod / i) * ifac[mod % i] % mod;
	for (int i = 1; i <= (n << 2); ++i) ifac[i] = (ll)ifac[i - 1] * ifac[i] % mod;

	for (int i = 0; i <= (n << 2); ++i) f[i] = dp[1][i], g[i] = dp[0][i];

	g[0][1] = 1;
	for (int i = 0, sum = 0; i < n; ++i, sum += s[i]) {
		swap(f, g);
		for (int j = 0; j <= (i << 2); ++j)
			for (int k = 1; k <= tot; ++k) g[j][k] = 0;
		for (int j = sum; j <= (i << 2); ++j)
			for (int k = 1; k <= tot; ++k) {
				if (!f[j][k]) continue;
				for (int t = s[i + 1]; t <= 4; ++t)
					if (ch[k][t]) g[j + t][ch[k][t]] = (g[j + t][ch[k][t]] + (ll)C(4 - s[i + 1], t - s[i + 1]) % mod * f[j][k]) % mod;
			}
	}

	int ans = 0;
	for (int i = 1; i <= (n << 2) - 13; ++i) {
		int sum = 0;
		for (int j = 1; j <= tot; ++j) sum = (sum + g[i + 13][j]) % mod;
		ans = (ans + (ll)sum * fac[i] % mod * fac[4 * n - 13 - i]) % mod;
	}
	printf("%lld\n", ((ll)ans * ifac[4 * n - 13] + 1) % mod);
	
	return 0;
}

你可能感兴趣的:(文章类型——题解,source——各省省选,算法——DP)