LOJ#2331. 「清华集训 2017」某位歌姬的故事

将序列离散化后,可以给每个点确定一个取值的上界 wi w i

对于限制 (lj,rj,cj) ( l j , r j , c j ) ,只有 [lj,rj] [ l j , r j ] wi=cj w i = c j 的点能贡献
对于一个 cj c j ,将所有 wi=cj w i = c j 的点拿出来,令 f[i][j] f [ i ] [ j ] 表示满足了前i个区间,最后一个权值取到了 wi w i 的点 i i j j 的方案数,转移用前缀和优化到O(1)

不同的 cj c j 互相独立,分别dp后把贡献乘起来

code:

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define ll long long
using namespace std;

const int maxn = 2005;
const int mod  = 998244353;
inline void add(int &a,const int &b){a+=b;if(a>=mod)a-=mod;}

int pw(int x,int k)
{
    int re=1;
    for(;k;k>>=1,x=(ll)x*x%mod) if(k&1)
        re=(ll)re*x%mod;
    return re;
}
int inv(int x){ return pw(x,mod-2); }

int n,Q,A;

struct Op
{
    int l,r,c;
    friend inline bool operator <(const Op x,const Op y)
    {
        if(x.c==y.c) return x.l==y.l?x.relse return x.cint t[maxn],tp;
struct Point
{
    int l,r;
}p[maxn],np[maxn]; int cnt,N;
map<int,int>mp;
void Trans()
{
    mp.clear(); tp=0;
    for(int i=1;i<=Q;i++) t[++tp]=op[i].l,t[++tp]=op[i].r;
    sort(t+1,t+tp+1); t[0]=0; cnt=0;
    for(int i=1;i<=tp;i++)
    {
        if(t[i]-1>t[i-1]) p[++cnt]=(Point){t[i-1]+1,t[i]-1};
        if(i==1||t[i]!=t[i-1]) p[++cnt]=(Point){t[i],t[i]},mp[t[i]]=cnt;
    }
    if(p[cnt].r!=n) p[cnt+1]=(Point){p[cnt].r+1,n},++cnt;
    for(int i=1;i<=Q;i++) op[i].l=mp[op[i].l],op[i].r=mp[op[i].r];
}

struct node
{
    int pos,c;
    friend inline bool operator <(const node x,const node y){return x.posint an;
multiset<int>S;
multiset<int>::iterator it;
int U[maxn];
void Checku()
{
    an=0;
    for(int i=1;i<=Q;i++) a[++an]=(node){op[i].l,op[i].c},a[++an]=(node){op[i].r+1,-op[i].c};
    sort(a+1,a+an+1);
    int nowa=1; S.clear();
    for(int i=1;i<=cnt;i++)
    {
        while(nowa<=an&&a[nowa].pos==i)
        {
            if(a[nowa].c<0) it=S.find(-a[nowa].c),S.erase(it);
            else S.insert(a[nowa].c);
            nowa++;
        }
        if(S.empty()) U[i]=A;
        else it=S.begin(),U[i]=(*it);
    }
}

int ok0[maxn],ok1[maxn],ok2[maxn];
int f[maxn],v[maxn];
int dp()
{
    memset(v,0,sizeof v);
    int ans=1;

    sort(op+1,op+Q+1);
    for(int i=1;i<=Q;)
    {
        int j; for(j=i+1;j<=Q&&op[j].c==op[i].c;j++);j--;
        int c=op[i].c;

        N=0; for(int k=1;k<=cnt;k++) if(U[k]==c) np[++N]=p[k],v[k]=1,f[N]=0;
        for(int k=1;k<=N;k++)
        {
            int siz=np[k].r-np[k].l+1;
            ok0[k]=pw(c-1,siz);
            ok2[k]=pw(c,siz);
            ok1[k]=(ok2[k]-ok0[k]+mod)%mod;
        }

        f[0]=1;
        for(;i<=j;i++)
        {
            int s=0,k;
            for(k=0;np[k].r0;
            for(;np[k].r<=p[op[i].r].r&&k<=N;k++)
            {
                add(f[k],(ll)s*ok1[k]%mod);
                s=(ll)s*ok0[k]%mod;
            }
            for(;k<=N;k++) f[k]=0;
        }
        int s=0;
        for(int k=N,temp=1;k>=0;k--)
        {
            add(s,(ll)f[k]*temp%mod);
            temp=(ll)temp*ok2[k]%mod;
        }
        ans=(ll)ans*s%mod;
    }

    for(int i=1;i<=cnt;i++) if(!v[i]) ans=(ll)ans*pw(A,p[i].r-p[i].l+1)%mod;
    return ans;
}

int main()
{
    int tcase; scanf("%d",&tcase);
    while(tcase--)
    {
        scanf("%d%d%d",&n,&Q,&A);
        for(int i=1;i<=Q;i++) scanf("%d%d%d",&op[i].l,&op[i].r,&op[i].c);
        Trans();
        Checku();

        printf("%d\n",dp());
    }

    return 0;
}

你可能感兴趣的:(DP,LOJ)