2020 年百度之星·程序设计大赛 - 复赛(Battle for Wosneth2-概率)

Problem Description
你在打游戏的时候碰到了如下问题:

​ 有两个人记作Alice和Bob,生命值分别是n,m,命中率分别为p%,q%。两个人轮流攻击对方,从Alice开始攻击,每次攻击的时候,如果命中,那么能让对方的生命值减低1,直到一方的生命值不超过0为止。

求到最后Alice的生命值大于0的概率,对998244353取模。

对于一个分数a/b,其中gcd(a,b)=1,那么我们认为这个分数对998244353取模的值为一个数c(0≤c<998244353)满足bc≡a(mod998244353)。

Input
第一行一个正整数T(1≤T≤104)表示数据组数。

对于每组数据,第一行四个整数n,m,p,q(1≤n,m≤105,1≤p,q≤100)。

保证∑(n+m)≤5×106。

Output
每组测试数据,输出一个数,表示答案。

Sample Input
3
1 1 50 50
100000 1 99 100
11 45 14 19

Sample Output
665496236
713582462
419834392

Hint

第一组数据,Alice活下来的概率为2/3。

第二组数据,当且仅当Alice前100000轮全部没有命中,Alice会死亡,所以存活概率为1-0.01100000

以一轮为单位,如果两个人都没打中,那么相当于“再来一次”,我们可以去掉这种情况。
于是分别令只有Alice打中,只有Bob打中,两人都打中的概率为 A = p % ( 1 − q % ) D , B = ( 1 − p % ) q % D , C = p % q % D , D = 1 − ( 1 − p % ) ( 1 − q % ) A=\frac{p\%(1-q\%)}{D},B=\frac{(1-p\%)q\%}{D},C=\frac{p\%q\%}{D},D=1-(1-p\%)(1-q\%) A=Dp%(1q%),B=D(1p%)q%,C=Dp%q%,D=1(1p%)(1q%)

最后一次必定为Alice击中,m事先减去1点
枚举C出现的次数 i i i,B出现的次数为 j j j
a n s = ( A + C ) ∑ i = 0 m i n ( n − 1 , m ) ∑ j = 0 n − 1 − i C i A m − i B j ( m + j ) ! i ! j ! ( m − i ) ! ans=(A+C)\sum_{i=0}^{min(n-1,m)}\sum_{j=0}^{n-1-i} C^iA^{m-i}B^j\frac{(m+j)!}{i!j!(m-i)!} ans=(A+C)i=0min(n1,m)j=0n1iCiAmiBji!j!(mi)!(m+j)!

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

using namespace std;
#define For(i,n) for(int i=1;i<=n;i++)
#define Fork(i,k,n) for(int i=k;i<=n;i++)
#define ForkD(i,k,n) for(int i=n;i>=k;i--)
#define Rep(i,n) for(int i=0;i
#define ForD(i,n) for(int i=n;i;i--)
#define RepD(i,n) for(int i=n;i>=0;i--)
#define Forp(x) for(int p=pre[x];p;p=next[p])
#define Forpiter(x) for(int &p=iter[x];p;p=next[p])  
#define Lson (o<<1)
#define Rson ((o<<1)+1)
#define MEM(a) memset(a,0,sizeof(a));
#define MEMI(a) memset(a,0x3f,sizeof(a));
#define MEMi(a) memset(a,128,sizeof(a));
#define MEMx(a,b) memset(a,b,sizeof(a));
#define INF (0x3f3f3f3f)
#define F (998244353)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define vi vector 
#define pi pair
#define SI(a) ((a).size())
#define Pr(kcase,ans1,ans2) printf("Case #%d:\n%d %d\n",kcase,ans1,ans2);
#define PRi(a,n) For(i,n-1) cout<
#define PRi2D(a,n,m) For(i,n) { \
						For(j,m-1) cout<
#pragma comment(linker, "/STACK:102400000,102400000")
#define ALL(x) (x).begin(),(x).end()
#define gmin(a,b) a=min(a,b);
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
ll mul(ll a,ll b){return (a*b)%F;}
ll add(ll a,ll b){return (a+b)%F;}
ll sub(ll a,ll b){return ((a-b)%F+F)%F;}
void upd(ll &a,ll b){a=(a%F+b%F)%F;}
inline int read()
{
	int x=0,f=1; char ch=getchar();
	while(!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();}
	while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();}
	return x*f;
} 

ll pow2(ll a,int b,ll p=998244353)  //a^b mod p 
{  
    if (b==0) return 1%p;  
    if (b==1) return a%p;  
    ll c=pow2(a,b/2,p)%p;  
    c=c*c%p;  
    if (b&1) c=c*a%p;  
    return c%p;  
}  
ll inv(ll a,ll p=998244353) { //gcd(a,p)=1
	return pow2(a,p-2,p);
}
void updiv(ll &A,ll D){A=mul(A,inv(D));}
ll st(ll t,ll k){
	ll p=sub(1,pow2(t,k));
	updiv(p,sub(1,t));
	return p;
}
#define MAXN (512345)
ll inj[MAXN],jie[MAXN];
inline int C(int a,int b) {
	return (ll)jie[a]*inj[b]%F*inj[a-b]%F;
}
ll p=F;
void pre(int n) {
	jie[0]=1;For(i,n) jie[i]=mul(jie[i-1],i);
	inj[0]=inj[1]=1;Fork(i,2,n) inj[i]=(F-(F/i))*inj[F%i]%F;
	For(i,n) inj[i]=mul(inj[i],inj[i-1]);  
} 
ll s[MAXN];
ll cp[MAXN],ap[MAXN],bp[MAXN];
int main()
{
//	freopen("a.in","r",stdin);
//	freopen(".out","w",stdout);
	pre(MAXN-1);
	ap[0]=bp[0]=cp[0]=s[0]=1;
	int T=read();
	For(kcase,T) {
		ll n=read(),m=read(),p=read(),q=read(),A,B,C,D;
		p=mul(p,inv(100)),q=mul(q,inv(100));
		
		A=mul(p,sub(1,q));
		B=mul(sub(1,p),q);
		C=mul(p,q);
		D=add(add(A,B),C);
		updiv(A,D);
		updiv(B,D);
		updiv(C,D);
		m--;
		
		For(i,m) ap[i]=mul(ap[i-1],A);
		For(i,n) bp[i]=mul(bp[i-1],B);
		For(i,max(m,n)) cp[i]=mul(cp[i-1],C);
		Fork(j,0,n) s[j]=mul(bp[j],mul(jie[m+j],inj[j]));
		For(j,n) s[j]=add(s[j-1],s[j]);
		ll su=0;
		Fork(i,0,min(n-1,m)) {
			ll pt=mul(cp[i],ap[m-i]);
			pt=mul(pt,mul(inj[i],inj[m-i]));
			pt=mul(pt,s[n-i-1]);	
			upd(su,pt);		
		}
		su=mul(su,add(A,C));
		
		cout<<su<<endl;
	}
	return 0;
}

你可能感兴趣的:(组合数学)