【Atcoder】 [ABC262Ex] Max Limited Sequence

题目链接

Atcoder方向
Luogu方向

题目解法

先讲一下某位歌姬的故事 的解法,算是这题的弱化版(只是需要多一个 l , r l,r l,r 的离散化)
首先把区间限制离散化,把限制变成一段区间之内最大值为 m i m_i mi
显然可以对区间分别考虑
考虑一段区间被多次覆盖的情况,这段区间的限制即为所有限制的最小值
考虑不同的 m i m_i mi 限制互不影响,所以可以对同一个 m i m_i mi 统一做,答案累乘即可
现在问题变成了对于固定的 m i m_i mi,有许多限制长成在 l j l_j lj r j r_j rj 之间必须有一个染色,求方案数
这也是一个经典的 d p dp dp
考虑令 d p i , j dp_{i,j} dpi,j 为到 i i i,上一个染色为 j j j 的方案数
考虑对于 i i i 的限制为 k − i k-i ki 中必须有一个染色( k k k 是极大的)
那么转移即为:
d p i , i = ∑ j = 0 i − 1 d p i − 1 , j dp_{i,i}=\sum_{j=0}^{i-1}dp_{i-1,j} dpi,i=j=0i1dpi1,j
d p i , j = d p i − 1 , j ∗ ( m l e n − ( m − 1 ) l e n )    ( j ≥ m x ) dp_{i,j}=dp_{i-1,j}*(m^{len}-(m-1)^{len})\;(j\ge mx) dpi,j=dpi1,j(mlen(m1)len)(jmx)
d p i , j = 0    ( j < m x ) dp_{i,j}=0\;(jdpi,j=0(j<mx)
其中 l e n len len 为当前一段的长度
这样就可以用 O ( T q 2 ) O(Tq^2) O(Tq2) 的时间做完了
还是放一下那题的代码:

#include 
using namespace std;
typedef pair<int,int> pii;
const int N=1100,P=998244353;
int n,q,A,cnt,disc[N],h[N],l[N],r[N],m[N];
int tot,lst[N],dp[N][N],lim[N];
vector<int> vec[N];
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int qmi(int a,int b){
    int res=1;
    for(;b;b>>=1){
        if(b&1) res=1ll*res*a%P;
        a=1ll*a*a%P;
    }
    return res;
}
int calc(int x){
    for(int i=1;i<=cnt;i++) lim[i]=0;
    for(int i=1;i<=q;i++) if(m[i]==x) lim[r[i]]=max(lim[r[i]],l[i]);
	for(int i=0;i<vec[x].size();i++) for(int j=0;j<=i;j++) dp[i][j]=0;
	int tru=lst[x];
	dp[0][0]=1;
	for(int i=1;i<vec[x].size();i++){
        int p=vec[x][i];
		int mx=lim[i],len=disc[p+1]-disc[p];
		//do not reach the bound
		for(int j=mx;j<i;j++) dp[i][j]=1ll*dp[i-1][j]*qmi(tru-1,len)%P;
		//reach the bound
		for(int j=0;j<i;j++) dp[i][i]=(dp[i-1][j]+dp[i][i])%P;
		dp[i][i]=1ll*dp[i][i]*(qmi(tru,len)-qmi(tru-1,len)+P)%P;
	}
	int res=0,nn=vec[x].size()-1;
	for(int i=0;i<=nn;i++) res=(res+dp[nn][i])%P;
	return res;
}
void work(){
	n=read(),q=read(),A=read();
    cnt=tot=0;
	disc[++cnt]=1,disc[++cnt]=n+1;
	for(int i=1;i<=q;i++){
		l[i]=read(),r[i]=read(),m[i]=read();
		disc[++cnt]=l[i],disc[++cnt]=r[i]+1,lst[i]=m[i];
	}
	sort(disc+1,disc+cnt+1);sort(lst+1,lst+q+1);
	cnt=unique(disc+1,disc+cnt+1)-disc-1;
	tot=unique(lst+1,lst+q+1)-lst-1;
    memset(h,0x3f,sizeof(h));
	for(int i=1;i<=q;i++){
		l[i]=lower_bound(disc+1,disc+cnt+1,l[i])-disc;
		r[i]=lower_bound(disc+1,disc+cnt+1,r[i]+1)-disc-1;
		m[i]=lower_bound(lst+1,lst+tot+1,m[i])-lst;
        for(int j=l[i];j<=r[i];j++) h[j]=min(h[j],m[i]);
	}
	for(int i=1;i<=tot;i++) vec[i].push_back(0);
	int ans=1;
	for(int i=1;i<cnt;i++){
        if(h[i]==h[0]) ans=1ll*ans*qmi(A,disc[i+1]-disc[i])%P;
        else vec[h[i]].push_back(i);
    }
    bool flg=1;
    for(int i=1;i<=q;i++){
        l[i]=lower_bound(vec[m[i]].begin(),vec[m[i]].end(),l[i])-vec[m[i]].begin();
        r[i]=upper_bound(vec[m[i]].begin(),vec[m[i]].end(),r[i])-vec[m[i]].begin()-1;
        if(l[i]>r[i]){ flg=0;break;}
    }
    if(flg){
        for(int i=1;i<=tot;i++) ans=1ll*ans*calc(i)%P;
        printf("%d\n",ans);
    }
    else puts("0");
    for(int i=1;i<=tot;i++) vec[i].clear();
}
int main(){
	int T=read();
	while(T--) work();
	return 0;
}

考虑到 d p dp dp 转移的操作为区间乘与区间查询和,这都是线段树的基本操作,考虑用线段树维护即可
同理,一开始求一段区间的最小限制时也用线段树维护
时间复杂度 O ( q l o g n ) O(qlogn) O(qlogn)

#include 
using namespace std;
typedef pair<int,int> pii;
const int N=400100,P=998244353,inf=0x3f3f3f3f;
int n,q,A,h[N],l[N],r[N],m[N];
int tot,lst[N];
vector<int> vec[N];
vector<pii> lim[N];
struct Segment1{
	int seg[N<<2];
	void modify(int l,int r,int x,int L,int R,int m){
		if(L<=l&&r<=R){ seg[x]=min(seg[x],m);return;}
		int mid=(l+r)>>1;
		if(mid>=L) modify(l,mid,x<<1,L,R,m);
		if(mid<R) modify(mid+1,r,x<<1^1,L,R,m);
	}
	void slv(int l,int r,int x,int mn){
		if(l==r){ h[l]=mn;return;}
		int mid=(l+r)>>1;
		slv(l,mid,x<<1,min(mn,seg[x<<1])),slv(mid+1,r,x<<1^1,min(mn,seg[x<<1^1]));
	}
}sg1;
struct Segment2{
	int seg[N<<2],mul[N<<2];
	void pushdown(int x){
		seg[x<<1]=1ll*seg[x<<1]*mul[x]%P,seg[x<<1^1]=1ll*seg[x<<1^1]*mul[x]%P;
		mul[x<<1]=1ll*mul[x<<1]*mul[x]%P,mul[x<<1^1]=1ll*mul[x<<1^1]*mul[x]%P;
		mul[x]=1;
	}
	void modify(int l,int r,int x,int pos,int val){//单点改
		if(l==r){ seg[x]=val;return;}
		pushdown(x);
		int mid=(l+r)>>1;
		if(mid>=pos) modify(l,mid,x<<1,pos,val);
		else modify(mid+1,r,x<<1^1,pos,val);
		seg[x]=(seg[x<<1]+seg[x<<1^1])%P;
	}
	void MODIFY(int l,int r,int x,int L,int R,int mu){
		if(L<=l&&r<=R){
			seg[x]=1ll*seg[x]*mu%P,mul[x]=1ll*mul[x]*mu%P;
			return;
		}
		pushdown(x);
		int mid=(l+r)>>1;
		if(mid>=L) MODIFY(l,mid,x<<1,L,R,mu);
		if(mid<R) MODIFY(mid+1,r,x<<1^1,L,R,mu);
		seg[x]=(seg[x<<1]+seg[x<<1^1])%P;
	}
	int query(int l,int r,int x,int L,int R){
		if(L<=l&&r<=R) return seg[x];
		pushdown(x);
		int mid=(l+r)>>1;
        if(mid>=L&&mid<R) return (query(l,mid,x<<1,L,R)+query(mid+1,r,x<<1^1,L,R))%P;
        if(mid>=L) return query(l,mid,x<<1,L,R);
		return query(mid+1,r,x<<1^1,L,R);
	}
}sg2;
inline int read(){
	int FF=0,RR=1;
	char ch=getchar();
	for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
	for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
	return FF*RR;
}
int calc(int x){
    sort(lim[x].begin(),lim[x].end());
	int tru=lst[x],nn=vec[x].size()-1;
	sg2.MODIFY(0,nn,1,0,nn,0);
    sg2.modify(0,nn,1,0,1);
	for(int i=1,j=0;i<=nn;i++){
		int mx=0;
		while(j<lim[x].size()&&lim[x][j].first==i) mx=max(mx,lim[x][j].second),j++;
		//reach the bound
		sg2.modify(0,nn,1,i,sg2.query(0,nn,1,0,i-1));
		//do not reach the bound
		if(0<mx) sg2.MODIFY(0,nn,1,0,mx-1,0);
        if(mx<=i-1) sg2.MODIFY(0,nn,1,mx,i-1,tru);
	}
	return sg2.query(0,nn,1,0,nn);
}
int main(){
	n=read(),A=read(),q=read();
	for(int i=1;i<=q;i++) l[i]=read(),r[i]=read(),m[i]=read(),lst[i]=m[i];
	sort(lst+1,lst+q+1);
	tot=unique(lst+1,lst+q+1)-lst-1;
    memset(sg1.seg,0x3f,sizeof(sg1.seg));
	for(int i=1;i<=q;i++){
		m[i]=lower_bound(lst+1,lst+tot+1,m[i])-lst;
    	sg1.modify(1,n,1,l[i],r[i],m[i]);
	}
    sg1.slv(1,n,1,sg1.seg[1]);
	for(int i=1;i<=tot;i++) vec[i].push_back(0);
	int ans=1;
	for(int i=1;i<=n;i++){
        if(h[i]==inf) ans=1ll*ans*(A+1)%P;
        else vec[h[i]].push_back(i);
    }
    for(int i=1;i<=q;i++){
        l[i]=lower_bound(vec[m[i]].begin(),vec[m[i]].end(),l[i])-vec[m[i]].begin();
        r[i]=upper_bound(vec[m[i]].begin(),vec[m[i]].end(),r[i])-vec[m[i]].begin()-1;
		lim[m[i]].push_back(make_pair(r[i],l[i]));
        if(l[i]>r[i]){ puts("0");exit(0);}
    }
    for(int i=1;i<=n<<2;i++) sg2.mul[i]=1;
    for(int i=1;i<=tot;i++) ans=1ll*ans*calc(i)%P;
    printf("%d\n",ans);
	return 0;
}

你可能感兴趣的:(Atcoder,算法)