【概率DP+常系数线性齐次递推+NTT】BZOJ4944 NOI2017泳池

【题目大意】
给定一块底边长为 n n ,高度为1001的矩形,矩形的每个格子有 q q 的概率是安全的, 1q 1 − q 的概率是危险的。一个子矩形是合法的当且仅当这个子矩形的下底边贴着大矩形的底边且子矩形内所有格子都是安全的。问最大合法子矩形的面积为 k k 的概率是多少。 n1e9,k1000 n ≤ 1 e 9 , k ≤ 1000 .

【解题思路】
听这题听说了几次了,刚好没有什么安排就来做一下。下面的推导次数界可能会有点小问题,不过不太影响。

首先题目要求的是最大安全矩形面积恰好 k k 的概率,我们可以计算最大子矩形面积不超过 k k 和不超过 k1 k − 1 的答案做差。

gi,j g i , j 表示高度为 i i ,长度为 j j 的海域都是安全的,剩下部分未知(最大子矩形面积 k ≤ k )的概率。
hi,j h i , j 表示高度为 i+1 i + 1 ,长度为 j j 的海域前 i i 行都是安全的,且 (i+1,j) ( i + 1 , j ) 这个位置是危险的,剩下部分未知(最大子矩形面积 k ≤ k )的概率。
我们有边界:

gk,1=qk(11),gi,0=1,hi,0=1 g k , 1 = q k ( 1 − 1 ) , g i , 0 = 1 , h i , 0 = 1

k1 k − 1 到1进行 DP D P ,对于 (i,j) ( i , j ) 这个点,枚举 i+1 i + 1 行的下一个危险格子在哪里,然后进行转移。
gi,j=k=0jhi,kgi+1,jkhi,j=k=0j1hi,kgi+1,jk1qi(1q) g i , j = ∑ k = 0 j h i , k g i + 1 , j − k h i , j = ∑ k = 0 j − 1 h i , k g i + 1 , j − k − 1 q i ( 1 − q )

因为第 i i 行的宽度不会超过 ki ⌊ k i ⌋ ,即暴力 DP D P 的复杂度应该就是
i=1kki2=O(k2) ∑ i = 1 k ⌊ k i ⌋ 2 = O ( k 2 )

这部分预处理已经可以满足要求了。

下面考虑答案的计算。
fi f i 为前 i i 列最大子矩形 k ≤ k 的概率,那么

fi=j=1kfijg1,j1(1q) f i = ∑ j = 1 k f i − j g 1 , j − 1 ( 1 − q )

我们令 ai=g1,i1(1q) a i = g 1 , i − 1 ( 1 − q ) ,那么这个就是一个常系数线性递推的形式
fi=j=1kajfij f i = ∑ j = 1 k a j f i − j

然后矩阵快速幂是 O(k3logk) O ( k 3 l o g k ) 的,我们考虑用特征多项式做到更优秀的复杂度。

我们矩阵乘法的转移矩阵为 A A ,我们只关注 An A n 是什么而不关注中间的项。
我们将 A A 看成变量,构造一个多项式 B B 满足

An=i=0k1biAi A n = ∑ i = 0 k − 1 b i ∗ A i

如果我们将初始矩阵设为 St S t ,我们将上面等式的两边同时乘上 St S t ,因为我们只关注最后矩阵的第0项,而上面这个等式在只取第0项的时候也是成立的,那么我们最终可以得到
Ans=i=0k1biSti A n s = ∑ i = 0 k − 1 b i S t i

也就是说,事实上 StB S t ∗ B 就是最后的答案。

于是现在我们要构造出这个 B B

如果我们将 A A 写成 A=Q(A)g(A)+R(A) A = Q ( A ) g ( A ) + R ( A ) 的形式,我们钦点 g(A) g ( A ) 的次数为 k k ,如果此时 g(A)=0 g ( A ) = 0 ,则事实上 R=B R = B (因为 R R 的次数一定比 g g 低,我们可以将 R R 写成上面的多个幂次的求和形式)。

现在问题就是求 B=An(mod g(A)) B = A n ( m o d   g ( A ) ) ,因为是在模意义下,而我们知道 A A k1 k − 1 次下的答案,所以这个是可以通过多项式快速幂,多项式取模来做到 O(klogklogn) O ( k log ⁡ k log ⁡ n ) 的,对于这题可以直接暴力取模 O(k2logn) O ( k 2 log ⁡ n ) 做到。

现在的问题是构造一个 g g 出来。
根据Cayley-Hamilton定理, |λIA| | λ I − A | 是一个关于 λ λ k k 次多项式( I I 是单位矩阵),记为 g(λ) g ( λ ) ,且对于任意矩阵 A A 都有 g(A)=0 g ( A ) = 0
对于上面的这个式子,我们有一个结论: g(λ)=λkki=1aiλki g ( λ ) = λ k − ∑ i = 1 k a i λ k − i ,其中 k k 是矩阵 A A 的大小, ai a i 就是 A A 的第 i i 项。
然后这个东西我们直接算就是 O(k) O ( k ) 的了。

因此整个计算答案的部分我们已经可以做到 O(klogklogn) O ( k log ⁡ k log ⁡ n ) ,远低于题目要求的范围。
但是我们对 g,h g , h DP D P 复杂度都已经是 O(k2) O ( k 2 ) 的了,我们是不是很亏?所以我们要想办法优化前面的复杂度。

我们试着将两个 DP D P 结果写成生成函数的形式,设

Ai(x)=j0gi,jxjBi(x)=j0hi,jxjci=qi(1q) A i ( x ) = ∑ j ≥ 0 g i , j x j B i ( x ) = ∑ j ≥ 0 h i , j x j c i = q i ( 1 − q )

那么

Ai(x)=Bi(x)Ai+1(x)Bi(x)=cixAi+1(x)Bi(x)+1Bi(x)=11cixAi+1(x) A i ( x ) = B i ( x ) A i + 1 ( x ) B i ( x ) = c i x A i + 1 ( x ) B i ( x ) + 1 B i ( x ) = 1 1 − c i x A i + 1 ( x )

于是对于 k1 k − 1 到1行,我们每一行都可以用多项式求逆来计算当前 DP D P 值,所以复杂度是:

i=1kkilogki=O(klog2k) ∑ i = 1 k ⌊ k i ⌋ log ⁡ ⌊ k i ⌋ = O ( k log 2 ⁡ k )

于是最终我们得到了一个比较优秀的总做法,总复杂度是: O(klog2k+klogklogn) O ( k log 2 ⁡ k + k log ⁡ k log ⁡ n ) ,常数极大,未必跑得过将一个 logk log ⁡ k 换成 k k 的暴力(但实测确实复杂度优秀的算法跑得会快很多),于是我们可以自适应一下,在 k k 比较小的时候暴力卷积,在比较大的时候用 NTT N T T 应该会得到不错的效果。
当然谁会去这么无聊分开写呢。

【参考代码】

#include
using namespace std;

typedef long long LL;
const int N=3e5+10,M=N*4;
const LL mod=998244353,g=3,inv2=(mod+1)>>1;

LL qpow(LL x,LL y) {LL ret=1;for(;y;y>>=1,x=x*x%mod)if(y&1)ret=ret*x%mod;return ret;}
void up(LL &x,LL y) {x+=y;if(x>=mod)x-=mod;if(x<0)x+=mod;}

namespace NTT
{
    int n,L,rev[N];
    LL w1[N],w2[N],d[N],e[N],Q[N],P[N];
    LL f[N],x[N],y[N],z[N];

    void init(int m)
    {
        for(n=1,L=0;n<m;n<<=1,++L);
        for(int i=2;i<=n;i<<=1) 
            w1[i>>1]=qpow(g,(mod-1)/i),w2[i>>1]=qpow(w1[i>>1],mod-2);
        rev[0]=0;
        for(int i=1;i>1]>>1)|((i&1)<<(L-1));
    }

    void ntt(LL *a,int f)
    {
        for(int i=0;iif(i>rev[i]) swap(a[i],a[rev[i]]);
        for(int i=1;i1)
        {
            LL wn=(f==1?w1[i]:w2[i]);
            for(int j=0;j1))
            {
                LL w=1;
                for(int k=0;k*wn%mod)
                {
                    LL x=a[j+k],y=w*a[i+j+k]%mod;
                    a[j+k]=(x+y)%mod;a[i+j+k]=(x-y+mod)%mod;
                }
            }
        }
        if(!~f) for(int i=0,inv=qpow(n,mod-2);i*inv%mod;
    }

    void clear(LL *a,LL *b,int m)
    {
        for(int i=0;i<m;++i) a[i]=b[i];
        for(int i=m;i0;
    }

    void inverse(LL *a,LL *b,int m)
    {
        if(m==1) {b[0]=qpow(a[0],mod-2);return;}
        inverse(a,b,m>>1); init(m<<1);
        clear(x,a,m);clear(y,b,m>>1);
        ntt(x,1);ntt(y,1);
        for(int i=0;ix[i]=y[i]*(2-x[i]*y[i]%mod)%mod,up(x[i],mod);
        ntt(x,-1);
        for(int i=0;i<m;++i) b[i]=x[i];
    }

    void module(LL *a,LL *b,LL *c,int n1,int n2)
    {
        int k=1; while(k<=n1-n2+1) k<<=1; k<<=1;
        for(int i=0;i<=n1;++i) d[i]=a[i];
        for(int i=0;i<=n2;++i) e[i]=b[i];
        reverse(d,d+n1+1); reverse(e,e+n2+1);
        for(int i=n1-n2+1;i0;
        inverse(e,f,k>>1);
        for(int i=n1-n2+1;i0;
        init(k); ntt(d,1); ntt(f,1);
        for(int i=0;i*f[i]%mod;
        ntt(e,-1);
        for(int i=0;i<=n1-n2;++i) c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }

    void mul(LL *a,LL *b,LL *c,int n)
    {
        int k=1; while(k<=n) k<<=1; k<<=1; 
        for(int i=0;i0;
        for(int i=0;i<=n;++i) Q[i]=a[i],P[i]=b[i];
        init(k);ntt(Q,1);ntt(P,1);
        for(int i=0;i*P[i]%mod;
        ntt(Q,-1);
        for(int i=0;i0;
        int n2=k-1; while(!Q[n2]) --n2;
        module(Q,c,P,n2,n);
        for(int i=0;ifor(int i=0;i<(k>>1);++i) Q[i]=c[i];
        for(int i=(k>>1);i0;
        init(k); ntt(Q,1); ntt(P,1);
        for(int i=0;i*P[i]%mod;
        ntt(Q,-1);
        for(int i=0;i*a,LL *b,LL *c,int m,int n)
    {
        if(!n) return;
        powmod(a,b,c,m,n>>1); 
        mul(a,a,c,m); if(n&1) mul(a,b,c,m);
    }
}

namespace SOL
{
    LL n,K,X,Y,q,q2,ans;
    LL fac1[N],fac2[N];
    LL g[2][N],h[N],fin[N];
    LL a[N],b[N],c[N],d[N],e[N],f[N];

    int DP(LL K)
    {
        int now=1,las=0;
        memset(g,0,sizeof(g));memset(h,0,sizeof(h));
        h[0]=1;g[0][0]=1;g[0][1]=q2*fac1[K]%mod;
        for(int i=K-1;i;--i,now^=1,las^=1)
        {
            LL dt=K/i,ct=q2*fac1[i]%mod,m=1;
            while(m<=dt) m<<=1;
            e[0]=1; for(int j=1;j<m;++j) e[j]=-ct*g[las][j-1];
            NTT::inverse(e,h,m); m<<=1;
            for(int j=dt+1;j<m;++j) h[j]=0;
            NTT::init(m); NTT::ntt(g[las],1); NTT::ntt(h,1);
            for(int j=0;j<m;++j) g[now][j]=g[las][j]*h[j]%mod;
            NTT::ntt(g[now],-1);
            for(int j=dt+1;j<m;++j) g[now][j]=0;
        }
        memset(a,0,sizeof(a));
        a[0]=1; for(int i=1;i<=K+1;++i) a[i]=-g[las][i-1]*q2%mod;
        return las;
    }

    LL solve(LL K)
    {
        if(K==0) return qpow(1-q+mod,n);
        NTT::init(K);int las=DP(K);

        LL ret=0,m=1,pw=n-K; while(m<=K+1) m<<=1;
        NTT::inverse(a,f,m<<1);
        if(n<=(K+1)<<1)
        {
            for(int i=0;i<=n && i<=K;++i) up(ret,f[n-i]*g[las][i]%mod);
            return ret;
        }

        memset(a,0,sizeof(a));memset(c,0,sizeof(c));memset(d,0,sizeof(d));
        a[K+1]=1; for(int i=0;i<=K;++i) a[i]=-g[las][K-i]*q2%mod,up(a[i],mod);
        if(K) c[1]=1; else c[0]=-a[0]; d[0]=1;
        NTT::powmod(d,c,a,K+1,pw); reverse(d,d+K+1);
        NTT::init(m<<2); NTT::ntt(d,1); NTT::ntt(f,1);
        for(int i=0;i<m<<2;++i) fin[i]=d[i]*f[i]%mod;
        NTT::ntt(fin,-1);
        for(int i=0,j=K<<1;i<=K;++i) up(ret,g[las][i]*fin[j-i]%mod);
        return ret;
    }   

    void Dream_Lolita()
    {
        scanf("%lld%lld%lld%lld",&n,&K,&X,&Y);
        q=X*qpow(Y,mod-2)%mod;q2=(Y-X)*qpow(Y,mod-2)%mod;
        fac1[0]=fac2[0]=1;
        for(int i=1;i<=K;++i) 
            fac1[i]=fac1[i-1]*q%mod,fac2[i]=fac2[i-1]*q2%mod;
        up(ans,solve(K));up(ans,-solve(K-1));
        printf("%lld\n",ans);
    }
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("BZOJ4944.in","r",stdin);
    freopen("BZOJ4944.out","w",stdout);
#endif
    SOL::Dream_Lolita();

    return 0;
}

【总结】
论常系数线性齐次递推中,多项式技巧是如何优化矩阵快速幂的。

你可能感兴趣的:(数论-FFT/NTT,DP-概率与期望)