有长为 2 n 2n 2n 的序列 A , B A,B A,B,求有多少个单调不减的序列 C C C,要求 C C C 的每一位均为 A A A 或 B B B 中对应位置上的数,且恰有 n n n 个数来自 A A A。 n ≤ 5 × 1 0 4 n\leq 5\times 10^4 n≤5×104。3s。
设 f [ i ] [ j = 0 / 1 ] [ k ] f[i][j=0/1][k] f[i][j=0/1][k] 表示 DP 确定了前 i i i 个数,第 i i i 个数是 A i / B i A_i/B_i Ai/Bi ,已经有 k k k 个数来自 A A A 的方案数。直接 DP 是 O ( n 2 ) O(n^2) O(n2) 的。
考虑把 k k k 这一维写成形式幂级数的形式,则由 f [ i ] f[i] f[i] 转移到 f [ i + 1 ] f[i+1] f[i+1] 的过程可以看作乘上一个矩阵 [ 0 / x 0 / x 0 / 1 0 / 1 ] \begin{bmatrix}0/x&\quad&0/x\\0/1&&0/1\end{bmatrix} [0/x0/10/x0/1],具体是 0 0 0 还是 x x x(1)取决于 A i , B i , A i + 1 , B i + 1 A_i,B_i,A_{i+1},B_{i+1} Ai,Bi,Ai+1,Bi+1 的大小关系。于是我们要计算 2 n 2n 2n 个矩阵的积。乘法使用分治,复杂度是 O ( n log 2 n ) O(n\log^2 n) O(nlog2n)。
代码:
#include
using namespace std;
int getint(){
int ans=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans*f;
}
const int N=1e5+10,mod=998244353;
void inc(int &a,int b){
a+=b;
if(a>=mod)a-=mod;
}
int a[N],b[N];
int qpow(int x,int y){
int ans=1;
while(y){
if(y&1)ans=ans*1ll*x%mod;
x=x*1ll*x%mod;
y>>=1;
}
return ans;
}
struct mat{
int n;
vector<int> a[2][2];
vector<int>* operator[](int x){
return a[x];
}
};
ostream& operator<< (ostream &out,mat& a){
for(int i=0;i<2;i++){out<<">";
for(int j=0;j<2;j++){
out<<"{ ";for(int k=0;k<=a.n;k++)out<<a[i][j][k]<<" ";out<<"} ";
}out<<endl;
}
return out;
}
int w[N<<1],iw[N<<1],maxn;
void init_w(int n){
maxn=n;
w[0]=iw[0]=1;
w[1]=qpow(3,(mod-1)/n);iw[1]=qpow(w[1],mod-2);
for(int i=2;i<=n;i++){
w[i]=w[i-1]*1ll*w[1]%mod;
iw[i]=iw[i-1]*1ll*iw[1]%mod;
}
}
int rev[N<<1];
inline void init_rev(int l){
rev[0]=0;
for(int i=1;i<(1<<l);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}
void ntt(vector<int> &a,int n,int tp=1){
int *W=(tp==1?w:iw);
for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1){
int d=maxn/i/2;
for(int j=0;j<n;j+=i<<1){
int t=0;
for(int k=0;k<i;k++,t+=d){
int x=a[j+k],y=a[i+j+k]*1ll*W[t]%mod;
if(x+y<mod)a[j+k]=x+y;
else a[j+k]=x+y-mod;
if(x-y>=0)a[i+j+k]=x-y;
else a[i+j+k]=x+mod-y;
}
}
}
if(tp==-1){
int invn=qpow(n,mod-2);
for(vector<int>::iterator i=a.begin();i!=a.end();++i)
*i=*i*1ll*invn%mod;
//a[i]=a[i]*1ll*invn%mod;
}
}
void operator*= (mat &a,const mat &b_){
mat c,b=b_;c.n=a.n+b.n;
int nn=1,l=0;
while(nn<=a.n+b.n)nn<<=1,++l;
init_rev(l);
for(int i=0;i<2;i++)for(int j=0;j<2;j++){
a[i][j].resize(nn);
b[i][j].resize(nn);
c[i][j].resize(nn);
ntt(a[i][j],nn);
ntt(b[i][j],nn);
}
for(int i=0;i<2;i++){
for(int j=0;j<2;j++){
vector<int>&cc=c[i][j];
for(int k=0;k<2;k++){
vector<int>&aa=a[k][j],&bb=b[i][k];
for(int l=0;l<nn;l++){
cc[l]=(aa[l]*1ll*bb[l]+cc[l])%mod;
}
}
}
}
for(int i=0;i<2;i++)
for(int j=0;j<2;j++)
ntt(c[i][j],nn,-1),swap(a[i][j],c[i][j]);
a.n=c.n;
}
mat m[N];
void calc(int l,int r){
if(l==r)return;
int mid=l+r>>1;
calc(l,mid);
calc(mid+1,r);
m[l]*=m[mid+1];
//cerr<
}
int main(){
int n=getint(),subtask=getint();
for(int i=1;i<=n*2;i++)a[i]=getint();
for(int i=1;i<=n*2;i++)b[i]=getint();
int maxn=1;while(maxn<=n*2)maxn<<=1;
init_w(maxn);
for(int i=1;i<=n*2;i++){
m[i].n=1;
for(int j=0;j<2;j++)for(int k=0;k<2;k++)m[i][j][k].resize(2);
if(a[i-1]<=a[i])m[i][0][0][1]=1;
if(b[i-1]<=a[i])m[i][0][1][1]=1;
if(a[i-1]<=b[i])m[i][1][0][0]=1;
if(b[i-1]<=b[i])m[i][1][1][0]=1;
}
calc(1,n*2);
cout<<(m[1][0][0][n]+m[1][1][0][n])%mod;
return 0;
}