BZOJ
据说当场10+的人数很少,虽然是道好题,但是不是毒瘤过头了啊QAQ
恰好面积为 k k k 并不好处理,不妨求面积小于等于 k k k,则最后答案为 P ( k ) − P ( k − 1 ) P(k)-P(k-1) P(k)−P(k−1)。
为了方便描述,我们约定从最下方的行开始编号为第1行, q q q 表示块安全的概率。
考虑第1行,由于面积都要小于等于 k k k,则第一行不会连续出现 k k k 个以上的安全方块,那么我们可以根据这个来进行转移,设 f n f_n fn 表示前 n n n 列合法的概率。基于求最大矩形面积的方法,即从行最小的不安全块开始,向两边递归求最大的面积,因此我们可以设状态 d p i , j dp_{i,j} dpi,j 表示 i i i 列的矩阵,且行数最小的不安全块的行数为 j + 1 j+1 j+1 的合法概率
f n = ∑ i = 1 k + 1 f n − i ( 1 − q ) ( ∑ j = 1 ⌊ k i ⌋ d p i − 1 , j ) f_n=\sum_{i=1}^{k+1} f_{n-i} (1-q) (\sum_{j=1}^{\lfloor \frac k i \rfloor}dp_{i-1,j}) fn=i=1∑k+1fn−i(1−q)(j=1∑⌊ik⌋dpi−1,j)
后面那串可以看做是常系数递推的常数,用多项式取模的方法可以优化至 O ( k 2 log n ) O(k^2\log n) O(k2logn),那么问题就在于求 d p dp dp
考虑枚举第 j + 1 j+1 j+1 行的最靠右的不安全块的位置,则有
d p i , j = ( 1 − q ) q j ∑ r = 1 i ( ∑ x ≥ j d p r − 1 , x ) ( ∑ x > j d p i − r , x ) dp_{i,j}=(1-q)q^j\sum_{r=1}^i (\sum_{x\geq j}dp_{r-1,x})(\sum_{x>j}dp_{i-r,x}) dpi,j=(1−q)qjr=1∑i(x≥j∑dpr−1,x)(x>j∑dpi−r,x)
显然 i j ≤ k ij\leq k ij≤k,因此总的状态数仅有 O ( k ln k ) O(k\ln k) O(klnk) 个,用后缀和优化一下即可做到 O ( k 2 ln k ) O(k^2\ln k) O(k2lnk)
总的时间复杂度为 O ( k 2 ln k + k 2 log n ) O(k^2\ln k+k^2\log n) O(k2lnk+k2logn)
#include
#include
using namespace std;
typedef long long ll;
const int maxn=1010,mod=998244353;
template <typename Tp> inline int getmin(Tp &x,Tp y){return y<x?x=y,1:0;}
template <typename Tp> inline int getmax(Tp &x,Tp y){return y>x?x=y,1:0;}
template <typename Tp> inline void read(Tp &x)
{
x=0;int f=0;char ch=getchar();
while(ch!='-'&&(ch<'0'||ch>'9')) ch=getchar();
if(ch=='-') f=1,ch=getchar();
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
if(f) x=-x;
}
int n,k,q,p,tx,ty,pw[maxn],dp[maxn][maxn],a[maxn];
int f[maxn],g[maxn],x[maxn],y[maxn],tmp[maxn<<1];
int pls(int x,int y){return x+y>=mod?x+y-mod:x+y;}
int dec(int x,int y){return x<y?x-y+mod:x-y;}
int power(int x,int y)
{
int res=1;
for(;y;y>>=1,x=(ll)x*x%mod)
if(y&1)
res=(ll)res*x%mod;
return res;
}
void mul(int *a,int *b,int *c,int k)
{
memset(tmp,0,sizeof(tmp));
for(int i=0;i<k;i++)
for(int j=0;j<k;j++)
tmp[i+j]=((ll)a[i]*b[j]+tmp[i+j])%mod;
for(int i=k+k;i>=k;i--)
if(tmp[i])
for(int j=k;j;j--) tmp[i-j]=dec(tmp[i-j],(ll)tmp[i]*g[k-j]%mod);
for(int i=0;i<k;i++) c[i]=tmp[i];
}
int solve(int k)
{
if(!k) return power(p,n);
int res=0;
memset(dp,0,sizeof(dp));
for(int i=1;i<=k+1;i++) dp[0][i]=1;
for(int i=1;i<=k;i++)
for(int j=k/i;j;j--)
{
for(int r=1;r<=i;r++)
dp[i][j]=((ll)dp[r-1][j]*dp[i-r][j+1]+dp[i][j])%mod;
dp[i][j]=(ll)dp[i][j]*p%mod*pw[j]%mod;
dp[i][j]=pls(dp[i][j],dp[i][j+1]);
}
for(int i=0;i<=k;i++){a[i+1]=(ll)dp[i][1]*p%mod;f[i]=dp[i][1];}
for(int i=1;i<=k;i++)
for(int j=1;j<=i&&j<=k+1;j++) f[i]=((ll)f[i-j]*a[j]+f[i])%mod;
memset(x,0,sizeof(x));memset(y,0,sizeof(y));
g[++k]=x[0]=y[1]=1;
for(int i=1;i<=k;i++) g[k-i]=dec(0,a[i]);
for(int i=n;i;i>>=1,mul(y,y,y,k))
if(i&1)
mul(x,y,x,k);
for(int i=0;i<k;i++) res=((ll)x[i]*f[i]+res)%mod;
return res;
}
int main()
{
read(n);read(k);read(tx);read(ty);
q=(ll)tx*power(ty,mod-2)%mod;p=dec(1,q);pw[0]=1;//q safe p dangerous
for(int i=1;i<=k;i++) pw[i]=(ll)pw[i-1]*q%mod;
printf("%d\n",dec(solve(k),solve(k-1)));
return 0;
}