某 SCOI 模拟赛 T1 集合划分(divide)【生成函数 NTT 分治】

题意

有长为 2 n 2n 2n 的序列 A , B A,B A,B,求有多少个单调不减的序列 C C C,要求 C C C 的每一位均为 A A A B B B 中对应位置上的数,且恰有 n n n 个数来自 A A A n ≤ 5 × 1 0 4 n\leq 5\times 10^4 n5×104。3s。

题解

f [ i ] [ j = 0 / 1 ] [ k ] f[i][j=0/1][k] f[i][j=0/1][k] 表示 DP 确定了前 i i i 个数,第 i i i 个数是 A i / B i A_i/B_i Ai/Bi ,已经有 k k k 个数来自 A A A 的方案数。直接 DP 是 O ( n 2 ) O(n^2) O(n2) 的。

考虑把 k k k 这一维写成形式幂级数的形式,则由 f [ i ] f[i] f[i] 转移到 f [ i + 1 ] f[i+1] f[i+1] 的过程可以看作乘上一个矩阵 [ 0 / x 0 / x 0 / 1 0 / 1 ] \begin{bmatrix}0/x&\quad&0/x\\0/1&&0/1\end{bmatrix} [0/x0/10/x0/1],具体是 0 0 0 还是 x x x(1)取决于 A i , B i , A i + 1 , B i + 1 A_i,B_i,A_{i+1},B_{i+1} Ai,Bi,Ai+1,Bi+1 的大小关系。于是我们要计算 2 n 2n 2n 个矩阵的积。乘法使用分治,复杂度是 O ( n log ⁡ 2 n ) O(n\log^2 n) O(nlog2n)

代码:

#include
using namespace std;
int getint(){
	int ans=0,f=1;
	char c=getchar();
	while(c<'0'||c>'9'){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		ans=ans*10+c-'0';
		c=getchar();
	}
	return ans*f;
}
const int N=1e5+10,mod=998244353;
void inc(int &a,int b){
	a+=b;
	if(a>=mod)a-=mod;
}
int a[N],b[N];

int qpow(int x,int y){
	int ans=1;
	while(y){
		if(y&1)ans=ans*1ll*x%mod;
		x=x*1ll*x%mod;
		y>>=1;
	}
	return ans;
}
struct mat{
	int n;
	vector<int> a[2][2];
	vector<int>* operator[](int x){
		return a[x];
	}
};
ostream& operator<< (ostream &out,mat& a){
	for(int i=0;i<2;i++){out<<">";
		for(int j=0;j<2;j++){
			out<<"{ ";for(int k=0;k<=a.n;k++)out<<a[i][j][k]<<" ";out<<"} ";
		}out<<endl;
	}
	return out;
}
int w[N<<1],iw[N<<1],maxn;
void init_w(int n){
	maxn=n;
	w[0]=iw[0]=1;
	w[1]=qpow(3,(mod-1)/n);iw[1]=qpow(w[1],mod-2);
	for(int i=2;i<=n;i++){
		w[i]=w[i-1]*1ll*w[1]%mod;
		iw[i]=iw[i-1]*1ll*iw[1]%mod;
	}
}
int rev[N<<1];
inline void init_rev(int l){
	rev[0]=0;
	for(int i=1;i<(1<<l);i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
}
void ntt(vector<int> &a,int n,int tp=1){
	int *W=(tp==1?w:iw);
	for(int i=0;i<n;i++)if(rev[i]>i)swap(a[i],a[rev[i]]);
	for(int i=1;i<n;i<<=1){
		int d=maxn/i/2;
		for(int j=0;j<n;j+=i<<1){
			int t=0;
			for(int k=0;k<i;k++,t+=d){
				int x=a[j+k],y=a[i+j+k]*1ll*W[t]%mod;
				if(x+y<mod)a[j+k]=x+y;
				else a[j+k]=x+y-mod;
				if(x-y>=0)a[i+j+k]=x-y;
				else a[i+j+k]=x+mod-y;
			}
		}
	}
	if(tp==-1){
		int invn=qpow(n,mod-2);
		for(vector<int>::iterator i=a.begin();i!=a.end();++i)
			*i=*i*1ll*invn%mod;
			//a[i]=a[i]*1ll*invn%mod;
	}
}
void operator*= (mat &a,const mat &b_){
	mat c,b=b_;c.n=a.n+b.n;
	int nn=1,l=0;
	while(nn<=a.n+b.n)nn<<=1,++l;
	init_rev(l);
	for(int i=0;i<2;i++)for(int j=0;j<2;j++){
		a[i][j].resize(nn);
		b[i][j].resize(nn);
		c[i][j].resize(nn);
		ntt(a[i][j],nn);
		ntt(b[i][j],nn);
	}
	for(int i=0;i<2;i++){
		for(int j=0;j<2;j++){
			vector<int>&cc=c[i][j];
			for(int k=0;k<2;k++){
				vector<int>&aa=a[k][j],&bb=b[i][k];
				for(int l=0;l<nn;l++){
					cc[l]=(aa[l]*1ll*bb[l]+cc[l])%mod;
				}
			}
		}
	}
	for(int i=0;i<2;i++)
		for(int j=0;j<2;j++)
			ntt(c[i][j],nn,-1),swap(a[i][j],c[i][j]);
	a.n=c.n;
}

mat m[N];
void calc(int l,int r){
	if(l==r)return;
	int mid=l+r>>1;
	calc(l,mid);
	calc(mid+1,r);
	m[l]*=m[mid+1];
	//cerr<
}

int main(){
	int n=getint(),subtask=getint();
	for(int i=1;i<=n*2;i++)a[i]=getint();
	for(int i=1;i<=n*2;i++)b[i]=getint();
	int maxn=1;while(maxn<=n*2)maxn<<=1;
	init_w(maxn);
	for(int i=1;i<=n*2;i++){
		m[i].n=1;
		for(int j=0;j<2;j++)for(int k=0;k<2;k++)m[i][j][k].resize(2);
		if(a[i-1]<=a[i])m[i][0][0][1]=1;
		if(b[i-1]<=a[i])m[i][0][1][1]=1;
		if(a[i-1]<=b[i])m[i][1][0][0]=1;
		if(b[i-1]<=b[i])m[i][1][1][0]=1;
	}
	calc(1,n*2);
	cout<<(m[1][0][0][n]+m[1][1][0][n])%mod;
	return 0;
}

你可能感兴趣的:(题解,#,来源-模拟赛,#,数学-组合数学-生成函数)