[JZOJ5251]决战

题目大意

给定一个 n×3 的矩形,你要在一些格子上放东西,一个格子最多只能放一个。而且一个格子上放了东西会对四周有影响。
输入会给定一个 3×3 01 矩阵,表示当一个 3×3 的子矩阵中心放了东西时,哪些地方不能放东西。
譬如矩阵

010111010

表示一个东西上下左右都不能放东西。
请求出恰好放了 m 个东西的方案,答案对 998244353 取模。

1n2500

题目分析

首先我们可以写出一个很简单的状压 dp fi,s,j 表示做到第 i 行,上一行的状态为 s ,已经放了 j 个东西的方案。但是这样很慢,怎么办呢?
可以发现,如果我们把 dp 的最后一维看成一个多项式, xi 的系数就是放了 i 个东西的方案数,那么转移的过程其实就是多项式乘法,而且最后我们相当于求答案多项式 xm 的系数。
考虑使用插值的方法求出这个多项式,这里采用傅里叶变换。我们做次数界次 dp ,每次把一个主次数界次单位根的次幂代入多项式然后计算。最后我们会得到次数界个值,其实就是答案多项式做了 DFT 的结果,最后做一次 IDFT 就好了。
可是中间的 dp 依然很慢,因为要做次数界(最多有 3n )次。考虑将 dp 的转移用矩阵乘法快速幂来优化就好了。
时间复杂度 O((23)33nlogn+3nlog3n)

代码实现

#include 
#include 
#include 

using namespace std;

const int P=998244353;
const int N=2505;
const int S=1<<3;
const int L=8192;
const int G=3;

int bitcnt[8]={0,1,1,2,1,2,2,3};
int omega[L+5],t[L+5],trs[L+5];
int mat[3][3],tmp[3][3];
bool legal[1<<6];
int f[N][S];
int a[L+5];
int n,m,len,wn;

struct matrix
{
    int num[S][S];
    int r,c;

    matrix operator*(matrix const mat)const
    {
        matrix ret;ret.r=r,ret.c=mat.c,memset(ret.num,0,sizeof ret.num);
        for (int i=0;ifor (int j=0;jfor (int k=0;k1ll*num[i][k]*mat.num[k][j]%P)%=P;
        return ret;
    }
}zero,one,F;

matrix operator^(matrix x,int y)
{
    matrix ret=zero;
    for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;
    return ret;
}

int quick_power(int x,int y)
{
    int ret=1;
    for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
    return ret;
}

void DFT(int *a,int sig)
{
    for (int i=0;ifor (int l=2;l<=len;l<<=1)
        for (int h=l>>1,p=len/l,i=0;ifor (int w=omega[sig>0?i*p:len-i*p],j=i;jint u=t[j],v=1ll*t[j+h]*w%P;
                t[j]=(u+v)%P,t[j+h]=(u-v+P)%P;
            }
    for (int i=0;ivoid NTT_pre()
{
    for (len=1;len<=n*3;len<<=1);
    wn=quick_power(G,(P-1)/len),omega[0]=1;
    for (int i=1;i<=len;++i) omega[i]=1ll*omega[i-1]*wn%P;
    for (int i=0;iint ret=0;
        for (int x=i,j=1;j>=1,j<<=1) ret=(ret<<1)|(x&1);
        trs[i]=ret;
    }
}

void pre()
{
    for (int sta=0;sta<1<<6;++sta)
    {
        for (int s=sta,i=0;i<2;++i)
            for (int j=0;j<3;++j,s>>=1)
                tmp[i][j]=s&1;
        bool flag=1;
        for (int i=0;flag&&i<3;++i)
            for (int j=0;flag&&j<3;++j)
                if (tmp[i][j])
                    for (int x=0;flag&&x<3;++x)
                        for (int y=0;flag&&y<3;++y)
                            if (!(x==1&&y==1)&&mat[x][y])
                            {
                                int u=x-1+i,v=y-1+j;
                                if (u>=0&&u<3&&v>=0&&v<3&&tmp[u][v]) flag=0;
                            }
        legal[sta]=flag;
    }
    zero.r=zero.c=S;
    for (int i=0;i1;
}

int dp(int w)
{
    memset(f,0,sizeof f);
    int pw[4];pw[0]=1;
    for (int i=1;i<=3;++i) pw[i]=1ll*pw[i-1]*w%P;
    F.r=1,F.c=S,memset(F.num,0,sizeof F.num);
    for (int s=0;s0][s]+=quick_power(w,bitcnt[s]))%=P;
    one.r=one.c=S,memset(one.num,0,sizeof one.num);
    for (int s=0;sfor (int s_=0;s_if (legal[s|(s_<<3)])
                (one.num[s][s_]+=pw[bitcnt[s_]]%P)%=P;
    F=F*(one^(n-1));
    int ret=0;
    for (int s=0;s0][s])%=P;
    return ret;
}

int main()
{
    freopen("battle.in","r",stdin),freopen("battle.out","w",stdout);
    scanf("%d%d",&n,&m);
    if (m>3*n) printf("0\n");
    {
        for (int i=0;i<3;++i)
            for (int j=0;j<3;++j)
                scanf("%d",&mat[i][j]);
        pre(),NTT_pre();
        for (int i=0;i1);
        for (int inv=quick_power(len,P-2),i=0;i1ll*a[i]*inv%P;
        printf("%d\n",a[m]);
    }
    fclose(stdin),fclose(stdout);
    return 0;
}

你可能感兴趣的:(状态压缩动态规划,矩阵乘法,快速傅里叶变换,纪中OJ,线性代数)