hdu6829 Borrow

题目链接
思路:如果三者之和不能被3整除,显然无解。
设三者平均数为m,,从大到小为x,y,z 。那么我们先将最大的数拿x-m出来,因为对于次数来说,顺序没有任何意义。
接下来再枚举分到z上的数有多少,设分到z上的为t。那么三者就成了m,y+x-m-t,m+t
三者构成等差数列,等价于0,c,2c。
我们设 f c f_c fc为0 ,c,2c状态下的期望次数,那么显然 f 0 = 0 f_0=0 f0=0
2c要拿c个出来分给其他两个小的,分完之后,得到三个数又等价于一个新的状态,即得转移方程:
f x = x + ∑ i = 0 x ( i x ) 2 x f i f_x=x+\sum_{i=0}^{x} \frac{(^x_i)}{2^x}f_i fx=x+i=0x2x(ix)fi,因为组合数 ( i x ) = ( x − i x ) (^x_i)=(^x_{x-i}) (ix)=(xix)所以后面写成 f i f_i fi是 没关系的,代码中体现出来的更直观一点…

f x = x + ∑ i = 0 x − 1 ( i x ) 2 x f i + f x 2 x f_x=x+\sum_{i=0}^{x-1}\frac{(^x_i)}{2^x}f_i+\frac{f_x}{2^x} fx=x+i=0x12x(ix)fi+2xfx

2 x − 1 2 x f x = x + ∑ i = 0 x − 1 ( i x ) 2 x f i \frac{2^x-1}{2^x}f_x=x+\sum_{i=0}^{x-1}\frac{(^x_i)}{2^x}f_i 2x2x1fx=x+i=0x12x(ix)fi

2 x − 1 2 x f x = x + ∑ i = 0 x − 1 f i 2 x − 1 2 x \frac{2^x-1}{2^x}f_x=x+\sum_{i=0}^{x-1}f_i\frac{2^x-1}{2^x} 2x2x1fx=x+i=0x1fi2x2x1

f x = x 2 x 2 x − 1 + ∑ i = 0 x − 1 ( i x ) f i 2 x − 1 f_x=\frac{x2^x}{2^x-1}+\sum_{i=0}^{x-1}\frac{(^x_i)f_i}{2^x-1} fx=2x1x2x+i=0x12x1(ix)fi

打表发现, f x = 2 x f_x=2x fx=2x
那么就做完了…

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include 
using namespace std;
typedef long long LL;
const int N = 1e6 + 10;
#define fi first
#define se second
#define pb push_back
#define wzh(x) cerr<<#x<<'='<
int x, y, z;
int t;
LL fac[N], inv[N];
const LL mod = 998244353;
LL pm(LL x, LL y) {
	LL z = 1;
	while (y) {
		if (y & 1)z = z * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return z;
}
LL s[N], f[N];
LL fa[N], in[N];
void P() {
	fa[0] = 1;
	for (int i = 1; i < N; i++) {
		fa[i] = fa[i - 1] * i % mod;
	}
	in[N - 1] = pm(fa[N - 1], mod - 2);
	for (int i = N - 2; i >= 0; i--)in[i] = in[i + 1] * (i + 1) % mod;
}
LL get(int x, int y) {
	return fa[x] * in[y] % mod * in[x - y] % mod;
}
int main() {
	ios::sync_with_stdio(false);
	fac[0] = 1;
	for (int i = 1; i < N; i++) {
		fac[i] = fac[i - 1] * 2 % mod;
		inv[i] = pm(fac[i] - 1, mod - 2);
	}
	P();
	for (cin >> t; t; t--) {
		cin >> x >> y >> z;
		if ((x + y + z) % 3) {
			cout << -1 << '\n';
		} else {
			int m = (x + y + z) / 3;
			if (y < z)swap(y, z);
			if (x < y)swap(x, y);
			if (y < z)swap(y, z);
			LL ans = 0;
			for (int i = 0; i <= (x - m); i++) {
				int now = min({z + i, m, y + x - m + i});
				vector<int>v;
				v.pb(z + i - now);
				v.pb(m - now);
				v.pb(y + x - m + i - now);
				sort(v.begin(), v.end());
				ans = ans + get(x - m, i) % mod * 2 * v[1] % mod;
				ans %= mod;
			}
			cout << ans*pm(fac[x - m], mod - 2) % mod + (x - m) % mod << '\n';
		}
	}
	return 0;
}

你可能感兴趣的:(概率dp)