【BZOJ4944】【NOI2017】泳池 概率DP 常系数线性递推 特征多项式 多项式取模

题目大意

  有一个 1001×n 1001 × n 的的网格,每个格子有 q q 的概率是安全的, 1q 1 − q 的概率是危险的。

  定义一个矩形是合法的当且仅当:

  • 这个矩形中每个格子都是安全的
  • 必须紧贴网格的下边界

  问你最大的合法子矩形大小为 k k 的概率是多少。

   n109,k1000 n ≤ 10 9 , k ≤ 1000

  吉老师:这题本来是 k20000 k ≤ 20000

题解

  一道好题。

  我们计算最大子矩形不超过 i i 的答案 si s i ,那么答案就是 sksk1 s k − s k − 1

  显然最后一行连续的安全格子不会超过 k k 个。

  设 gi,j g i , j 表示长度为 j j ,高度为 i i 的海域全部是安全的,剩下的部分未知,最大子矩形 k ≤ k 的概率。

  设 hi,j h i , j 表示长度为 j j ,高度为 i+1 i + 1 的海域中,前 i i 行全部是安全的,剩下的未知且 (i+1,j) ( i + 1 , j ) 是危险的,最大子矩形 k ≤ k 的概率。

  边界:

gk,1gi,0hi,0=qk(1q)=1=1(1)(2)(3) (1) g k , 1 = q k ( 1 − q ) (2) g i , 0 = 1 (3) h i , 0 = 1

  那么我们从 k1 k − 1 1 1 DP,对于 i i j j 列,枚举第 i+1 i + 1 行的下一个危险的格子在哪个地方,然后转移:
gi,jhi,j=k=0jhi,kgi+1,jk=k=0j1hi,kgi+1,jk1qi(1q)(4)(5) (4) g i , j = ∑ k = 0 j h i , k g i + 1 , j − k (5) h i , j = ∑ k = 0 j − 1 h i , k g i + 1 , j − k − 1 q i ( 1 − q )

  因为第 i i 行的宽度不会超过 ki ⌊ k i ⌋ ,所以的暴力的时间复杂度是 ki=1ki2=O(k2) ∑ i = 1 k ⌊ k i ⌋ 2 = O ( k 2 )

  这已经足够了,但我们可以做的更好。

  设

Ai(x)Bi(x)ci=j0gi,jxj=j0hi,jxj=qi(1q)(6)(7)(8) (6) A i ( x ) = ∑ j ≥ 0 g i , j x j (7) B i ( x ) = ∑ j ≥ 0 h i , j x j (8) c i = q i ( 1 − q )

那么
Ai(x)Bi(x)Bi(x)=Bi(x)Ai+1(x)=cixAi+1(x)Bi(x)+1=11cixAi+1(x)(9)(10)(11) (9) A i ( x ) = B i ( x ) A i + 1 ( x ) (10) B i ( x ) = c i x A i + 1 ( x ) B i ( x ) + 1 (11) B i ( x ) = 1 1 − c i x A i + 1 ( x )

  时间复杂度是 ki=1kilogki=O(klog2k) ∑ i = 1 k ⌊ k i ⌋ log ⁡ ⌊ k i ⌋ = O ( k log 2 ⁡ k )

  设 fi f i 为前 i i 列最大子矩形 k ≤ k 的概率,那么

fi=j=1kfij1g1,j(1q) f i = ∑ j = 1 k f i − j − 1 g 1 , j ( 1 − q )

  这就是一个常系数线性递推。
aifi=g1,i1(1q)=j=1kfijaj(12)(13) (12) a i = g 1 , i − 1 ( 1 − q ) (13) f i = ∑ j = 1 k f i − j a j

  时间复杂度:

  • 暴力: O(nk) O ( n k ) 70 70 pts
  • 矩阵快速幂: O(k3logn) O ( k 3 log ⁡ n ) 90 90 pts
  • 特征多项式+暴力: O(k2logn) O ( k 2 log ⁡ n ) 100 100 pts
  • 特征多项式+NTT取模: O(klogklogn) O ( k log ⁡ k log ⁡ n ) 100 100 pts

  这里简单讲一下最后一个做法

  矩阵快速幂是给你一个矩阵 A A ,求 (An)1,1 ( A n ) 1 , 1

  设矩阵的大小为 k k

  根据Cayley-Hamilton定理, |λIA| | λ I − A | 是一个关于 λ λ k k 次多项式,记为 g(λ) g ( λ ) 。对于任意矩阵 A A ,有 g(A)=0 g ( A ) = 0

  对于常系数线性递推的矩阵,设 fi=kj=1fijaj f i = ∑ j = 1 k f i − j a j g(λ)=λkki=1aiλki g ( λ ) = λ k − ∑ i = 1 k a i λ k − i

  所以我们只需要求 Anmodg(A) A n mod g ( A ) 。可以用快速幂(倍增取模)求解。

  然后还要求出 f1fk f 1 … f k ,可以通过其他方法计算(多项式求逆或者题目给你了)。

  最后一次卷积可以得到答案。

  如果要求 fnk+1fn f n − k + 1 … f n ,那就把 f1f2k f 1 … f 2 k 带进去卷积。

  总时间复杂度: O(klog2k+klogklogn) O ( k log 2 ⁡ k + k log ⁡ k log ⁡ n )

代码

  暴力取模

#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair pll;
void sort(int &a,int &b)
{
    if(a>b)
        swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    do
    {
        s=s*10+c-'0';
    }
    while((c=getchar())>='0'&&c<='9');
    return s;
}
int upmin(int &a,int b)
{
    if(breturn 1;
    }
    return 0;
}
int upmax(int &a,int b)
{
    if(b>a)
    {
        a=b;
        return 1;
    }
    return 0;
}
ll p=998244353;
void add(ll &a,ll b)
{
    a=(a+b)%p;
}
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
ll inv(ll a)
{
    return fp(a,p-2);
}
ll pw1[1010];
ll pw2[1010];
ll q;
ll q2;
ll g[1010][1010];
ll h[1010][1010];
ll f[2010];
ll a[2010];
ll c[2010];
ll d[2010];
ll final[2010];
void mul(ll *a,ll *b,ll *e,int len)
{
    static ll c[2010];
    int i,j;
    for(i=0;i<=2*len;i++)
        c[i]=0;
    for(i=0;i<=len;i++)
        for(j=0;j<=len;j++)
            add(c[i+j],a[i]*b[j]);
    for(i=2*len;i>=len;i--)
    {
        ll v=c[i]*inv(e[len]);
        if(v)
            for(j=0;j<=len;j++)
                c[i-len+j]=(c[i-len+j]-e[j]*v)%p;
    }
    for(i=0;i<=len;i++)
        a[i]=c[i];
}
ll solve(int n,int k)
{
    if(!k)
        return fp(q2,n);
    memset(g,0,sizeof g);
    memset(h,0,sizeof h);
    g[k][1]=q2*pw1[k]%p;
    g[k][0]=1;
    int i,j,l;
    for(i=k-1;i>=1;i--)
    {
        int m=k/i;
        g[i][0]=1;
        h[i][0]=1;
        for(j=0;j<=m;j++)
        {
            for(l=j+1;l<=m;l++)
                add(h[i][l],h[i][j]*g[i+1][l-j-1]%p*q2%p*pw1[i]%p);
            for(l=j;l<=m;l++)
                if(l)
                    add(g[i][l],h[i][j]*g[i+1][l-j]%p);
        }
    }
    memset(f,0,sizeof f);
    f[0]=1;
    for(i=1;i<=2*(k+1);i++)
        for(j=0;j1]*q2%p*g[1][j]);
    if(n<=2*(k+1))
    {
        ll s=0;
        for(i=0;i<=n&&i<=k;i++)
            add(s,f[n-i]*g[1][i]);
        return s;
    }
    int len=k+1;
    for(i=0;i*g[1][len-i-1]%p;
    a[len]=1;
    memset(c,0,sizeof c);
    c[1]=1;
    memset(d,0,sizeof d);
    d[0]=1;
    int m=n-k-1;
    while(m)
    {
        if(m&1)
            mul(d,c,a,len);
        mul(c,c,a,len);
        m>>=1;
    }
    memset(final,0,sizeof final);
    for(i=1;i<=k+1;i++)
        for(j=0;j<=k;j++)
            add(final[i],d[j]*f[i+j]);
    ll s=0;
    for(i=1;i<=k+1;i++)
        add(s,final[i]*g[1][k+1-i]);
    return s;
}
int main()
{
    open("bzoj4944");
    int n,k,x,y;
    scanf("%d%d%d%d",&n,&k,&x,&y);
    q=x*inv(y)%p;
    q2=(y-x)*inv(y)%p;
    pw1[0]=pw2[0]=1;
    int i;
    for(i=1;i<=k;i++)
    {
        pw1[i]=pw1[i-1]*q%p;
        pw2[i]=pw2[i-1]*q2%p;
    }
    ll ans1=solve(n,k);
    ll ans2=solve(n,k-1);
    ll ans=((ans1-ans2)%p+p)%p;
    printf("%lld\n",ans);
    return 0;
}

  NTT取模

#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair pll;
void sort(int &a,int &b)
{
    if(a>b)
        swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
    char str[100];
    sprintf(str,"%s.in",s);
    freopen(str,"r",stdin);
    sprintf(str,"%s.out",s);
    freopen(str,"w",stdout);
#endif
}
int rd()
{
    int s=0,c;
    while((c=getchar())<'0'||c>'9');
    do
    {
        s=s*10+c-'0';
    }
    while((c=getchar())>='0'&&c<='9');
    return s;
}
int upmin(int &a,int b)
{
    if(breturn 1;
    }
    return 0;
}
int upmax(int &a,int b)
{
    if(b>a)
    {
        a=b;
        return 1;
    }
    return 0;
}
const ll p=998244353;
const int maxn=300000;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
namespace ntt
{
    const ll g=3;
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;iif(rev[i]for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j1;
                for(k=j;k2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
                    a[k]=(u+v)%p;
                    a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;ix[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;ix[i]=y[i]*(2-x[i]*y[i]%p)%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            if(a[0]==1)
                b[0]=1;
            else if(a[0]==0)
                b[0]=0;
            else
                //我也不会
                ;
            return;
        }
        sqrt(a,b,m>>1);
//      copy_clear(c,b,m>>1);
        int i;
        for(i=m;i<m<<1;i++)
            b[i]=0;
        inverse(b,d,m);
        init(m<<1);
        for(i=m;i<m<<1;i++)
            b[i]=d[i]=0;
        ll inv2=fp(2,p-2);
        copy_clear(x,a,m);
        ntt(x,1);
        ntt(d,1);
        for(i=0;ix[i]=x[i]*d[i]%p;
        ntt(x,-1);
        for(i=0;i<m;i++)
            b[i]=((b[i]+x[i])%p*inv2)%p;
    }
    void derivative(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m-1;i++)
            b[i]=(i+1)*a[i+1]%p;
        b[m-1]=0;
    }
    void differential(ll *a,ll *b,int m)
    {
//      int i;
//      for(i=m-1;i>=1;i--)
//          b[i]=a[i-1]*inv[i]%p;
        b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
        static ll c[maxn],d[maxn];
        derivative(a,c,m);
        inverse(a,d,m);
        init(m<<1);
        int i;
        for(i=m;i0;
        ntt(c,1);
        ntt(d,1);
        for(i=0;i*d[i]%p;
        ntt(c,-1);
        differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=1;
            return;
        }
        exp(a,b,m>>1);
        int i;
        for(i=m>>1;i<m;i++)
            b[i]=0;
        ln(b,y,m);
        init(m<<1);
        copy_clear(x,a,m);
        x[0]++;
        for(i=0;i<m;i++)
            x[i]=(x[i]-y[i])%p;
        copy_clear(y,b,m);
        ntt(x,1);
        ntt(y,1);
        for(i=0;ix[i]=x[i]*y[i]%p;
        ntt(x,-1);
        copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
        int k=1;
        while(k<=n1-n2+1)
            k<<=1;
        int i;
        for(i=0;i<=n1;i++)
            d[i]=a[i];
        for(i=0;i<=n2;i++)
            e[i]=b[i];
        reverse(d,d+n1+1);
        reverse(e,e+n2+1);
        for(i=n1-n2+1;i1;i++)
            d[i]=e[i]=0;
        inverse(e,f,k);
        for(i=n1-n2+1;i1;i++)
            f[i]=0;
        init(k<<1);
        ntt::ntt(d,1);
        ntt::ntt(f,1);
        for(i=0;i*f[i]%p;
        ntt::ntt(e,-1);
        for(i=0;i<=n1-n2;i++)
            c[i]=e[i];
        reverse(c,c+n1-n2+1);
    }
};
void add(ll &a,ll b)
{
    a=(a+b)%p;
}
ll inv(ll a)
{
    return fp(a,p-2);
}
ll pw1[maxn];
ll pw2[maxn];
ll q;
ll q2;
ll f[maxn];
ll a[maxn];
ll c[maxn];
ll d[maxn];
ll final[maxn];
ll g[2][maxn];
ll h[maxn];
ll e[maxn];

void mul(ll *a,ll *b,ll *c,int n)
{
    static ll d[maxn],e[maxn];
    int k=1;
    while(k<=n)
        k<<=1;
    ntt::init(k<<1);
    int i;
    for(i=0;i1;i++)
        d[i]=e[i]=0;
    for(i=0;i<=n;i++)
    {
        d[i]=a[i];
        e[i]=b[i];
    }
    ntt::ntt(d,1);
    ntt::ntt(e,1);
    for(i=0;i1;i++)
        d[i]=d[i]*e[i]%p;
    ntt::ntt(d,-1);
    //d=a*b
    for(i=0;i1;i++)
        e[i]=0;
    int n2=(k<<1)-1;
    while(!d[n2])
        n2--;
    ntt::module(d,c,e,n2,n);
    for(i=0;ifor(i=0;ifor(i=k;i1;i++)
        d[i]=0;
    ntt::init(k<<1);
    ntt::ntt(d,1);
    ntt::ntt(e,1);
    for(i=0;i1;i++)
        d[i]=d[i]*e[i]%p;
    ntt::ntt(d,-1);
    for(i=0;i%p;
}
void powmod(ll *a,ll *b,ll *c,int m,int n)
{
    if(!n)
        return;
    powmod(a,b,c,m,n>>1);
    mul(a,a,c,m);
    if(n&1)
        mul(a,b,c,m);
}
ll solve(int n,int k)
{
    memset(g,0,sizeof g);
    memset(h,0,sizeof h);
    int now=0;
    g[now][1]=q2*pw1[k]%p;
    g[now][0]=1;
    h[0]=1;
    int i,j;
    for(i=k-1;i>=1;i--)
    {
        now^=1;
        int m=k/i;
        ll c=q2*pw1[i]%p;
        int len=1;
        while(len<=m)
            len<<=1;
        for(j=1;j*g[now^1][j-1];
        e[0]=1;
        ntt::inverse(e,h,len);
        for(j=m+1;j1;j++)
            h[j]=0;
        ntt::init(len<<1);
        ntt::ntt(g[now^1],1);
        ntt::ntt(h,1);
        for(j=0;j1;j++)
            g[now][j]=g[now^1][j]*h[j]%p;
        ntt::ntt(g[now],-1);
        for(j=m+1;j1;j++)
            g[now][j]=0;
    }
    memset(a,0,sizeof a);
    for(i=0;i<=k;i++)
        a[i+1]=-g[now][i]*q2%p;
    a[0]=1;
    int len=1;
    while(len<=k+1)
        len<<=1;
    ntt::inverse(a,f,len<<1);
    if(n<=2*(k+1))
    {
        ll s=0;
        for(i=0;i<=n&&i<=k;i++)
            add(s,f[n-i]*g[now][i]);
        return s;
    }
    memset(a,0,sizeof a);
    memset(c,0,sizeof c);
    memset(d,0,sizeof d);
    for(i=0;i<=k;i++)
        a[i]=-g[now][k-i]*q2%p;
    a[k+1]=1;
    if(k)
        c[1]=1;
    else
        c[0]=-a[0];
    d[0]=1;
    int m=n-k;
    powmod(d,c,a,k+1,m);
//  while(m)
//  {
//      if(m&1)
//          mul(d,c,a,k+1);
//      mul(c,c,a,k+1);
//      m>>=1;
////        for(i=0;i<=k;i++)
////            printf("%lld ",(d[i]+p)%p);
////        printf("\n");
//  }
    reverse(d,d+k+1);
    ntt::init(len<<2);
    ntt::ntt(d,1);
    ntt::ntt(f,1);
    for(i=0;i2;i++)
        final[i]=d[i]*f[i]%p;
    ntt::ntt(final,-1);
    ll s=0;
    for(i=0;i<=k;i++)
        add(s,g[now][i]*final[2*k-i]);
    return s;
//  for(i=0;i<=k;i++)
//      g[now][i]=(g[now][i]+p)%p;
//  memset(f,0,sizeof f);
//  f[0]=1;
//  for(i=1;i<=2*(k+1);i++)
//      for(j=0;j1]*q2%p*g[now][j]);
//  if(n<=2*(k+1))
//  {
//      ll s=0;
//      for(i=0;i<=n&&i<=k;i++)
//          add(s,f[n-i]*g[now][i]);
//      return s;
//  }
//  int len=k+1;
//  for(i=0;i*g[now][len-i-1]%p;
//  a[len]=1;
//  memset(c,0,sizeof c);
//  c[1]=1;
//  memset(d,0,sizeof d);
//  d[0]=1;
//  int m=n-k-1;
//  while(m)
//  {
//      if(m&1)
//          mul(d,c,a,len);
//      mul(c,c,a,len);
//      m>>=1;
//  }
//  memset(final,0,sizeof final);
//  for(i=1;i<=k+1;i++)
//      for(j=0;j<=k;j++)
//          add(final[i],d[j]*f[i+j]);
//  ll s=0;
//  for(i=1;i<=k+1;i++)
//      add(s,final[i]*g[now][k+1-i]);
//  return s;
}
int main()
{
    open("bzoj4944");
    int n,k,x,y;
    scanf("%d%d%d%d",&n,&k,&x,&y);
    q=x*inv(y)%p;
    q2=(y-x)*inv(y)%p;
    pw1[0]=pw2[0]=1;
    int i;
    for(i=1;i<=k;i++)
    {
        pw1[i]=pw1[i-1]*q%p;
        pw2[i]=pw2[i-1]*q2%p;
    }
    ll ans1=solve(n,k);
    ll ans2=solve(n,k-1);
    ll ans=((ans1-ans2)%p+p)%p;
    printf("%lld\n",ans);
    return 0;
}

你可能感兴趣的:(数学&数论,DP,DP--概率&期望DP,FFT&NTT,FFT&NTT--多项式求逆,FFT&NTT--多项式取模,常系数线性递推)