给定一个 n×3 的矩形,你要在一些格子上放东西,一个格子最多只能放一个。而且一个格子上放了东西会对四周有影响。
输入会给定一个 3×3 的 01 矩阵,表示当一个 3×3 的子矩阵中心放了东西时,哪些地方不能放东西。
譬如矩阵
1≤n≤2500
首先我们可以写出一个很简单的状压 dp : fi,s,j 表示做到第 i 行,上一行的状态为 s ,已经放了 j 个东西的方案。但是这样很慢,怎么办呢?
可以发现,如果我们把 dp 的最后一维看成一个多项式, xi 的系数就是放了 i 个东西的方案数,那么转移的过程其实就是多项式乘法,而且最后我们相当于求答案多项式 xm 的系数。
考虑使用插值的方法求出这个多项式,这里采用傅里叶变换。我们做次数界次 dp ,每次把一个主次数界次单位根的次幂代入多项式然后计算。最后我们会得到次数界个值,其实就是答案多项式做了 DFT 的结果,最后做一次 IDFT 就好了。
可是中间的 dp 依然很慢,因为要做次数界(最多有 3n )次。考虑将 dp 的转移用矩阵乘法快速幂来优化就好了。
时间复杂度 O((23)33nlogn+3nlog3n) 。
#include
#include
#include
using namespace std;
const int P=998244353;
const int N=2505;
const int S=1<<3;
const int L=8192;
const int G=3;
int bitcnt[8]={0,1,1,2,1,2,2,3};
int omega[L+5],t[L+5],trs[L+5];
int mat[3][3],tmp[3][3];
bool legal[1<<6];
int f[N][S];
int a[L+5];
int n,m,len,wn;
struct matrix
{
int num[S][S];
int r,c;
matrix operator*(matrix const mat)const
{
matrix ret;ret.r=r,ret.c=mat.c,memset(ret.num,0,sizeof ret.num);
for (int i=0;ifor (int j=0;jfor (int k=0;k1ll*num[i][k]*mat.num[k][j]%P)%=P;
return ret;
}
}zero,one,F;
matrix operator^(matrix x,int y)
{
matrix ret=zero;
for (;y;y>>=1,x=x*x) if (y&1) ret=ret*x;
return ret;
}
int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%P) if (y&1) ret=1ll*ret*x%P;
return ret;
}
void DFT(int *a,int sig)
{
for (int i=0;ifor (int l=2;l<=len;l<<=1)
for (int h=l>>1,p=len/l,i=0;ifor (int w=omega[sig>0?i*p:len-i*p],j=i;jint u=t[j],v=1ll*t[j+h]*w%P;
t[j]=(u+v)%P,t[j+h]=(u-v+P)%P;
}
for (int i=0;ivoid NTT_pre()
{
for (len=1;len<=n*3;len<<=1);
wn=quick_power(G,(P-1)/len),omega[0]=1;
for (int i=1;i<=len;++i) omega[i]=1ll*omega[i-1]*wn%P;
for (int i=0;iint ret=0;
for (int x=i,j=1;j>=1,j<<=1) ret=(ret<<1)|(x&1);
trs[i]=ret;
}
}
void pre()
{
for (int sta=0;sta<1<<6;++sta)
{
for (int s=sta,i=0;i<2;++i)
for (int j=0;j<3;++j,s>>=1)
tmp[i][j]=s&1;
bool flag=1;
for (int i=0;flag&&i<3;++i)
for (int j=0;flag&&j<3;++j)
if (tmp[i][j])
for (int x=0;flag&&x<3;++x)
for (int y=0;flag&&y<3;++y)
if (!(x==1&&y==1)&&mat[x][y])
{
int u=x-1+i,v=y-1+j;
if (u>=0&&u<3&&v>=0&&v<3&&tmp[u][v]) flag=0;
}
legal[sta]=flag;
}
zero.r=zero.c=S;
for (int i=0;i1;
}
int dp(int w)
{
memset(f,0,sizeof f);
int pw[4];pw[0]=1;
for (int i=1;i<=3;++i) pw[i]=1ll*pw[i-1]*w%P;
F.r=1,F.c=S,memset(F.num,0,sizeof F.num);
for (int s=0;s0][s]+=quick_power(w,bitcnt[s]))%=P;
one.r=one.c=S,memset(one.num,0,sizeof one.num);
for (int s=0;sfor (int s_=0;s_if (legal[s|(s_<<3)])
(one.num[s][s_]+=pw[bitcnt[s_]]%P)%=P;
F=F*(one^(n-1));
int ret=0;
for (int s=0;s0][s])%=P;
return ret;
}
int main()
{
freopen("battle.in","r",stdin),freopen("battle.out","w",stdout);
scanf("%d%d",&n,&m);
if (m>3*n) printf("0\n");
{
for (int i=0;i<3;++i)
for (int j=0;j<3;++j)
scanf("%d",&mat[i][j]);
pre(),NTT_pre();
for (int i=0;i1);
for (int inv=quick_power(len,P-2),i=0;i1ll*a[i]*inv%P;
printf("%d\n",a[m]);
}
fclose(stdin),fclose(stdout);
return 0;
}