Educational Codeforces Round 49 (Rated for Div. 2) E - Inverse Coloring

题意:
有一个n*n的方格需要染成黑白颜色
定义方格为beautiful的当且仅当每对相邻行的对应格子都相同或都不同,对列同理。
定义方格为suitable的当且仅当不存在大小>=k的同色子矩阵
求有多少种染色方法使得这个n*n的方格beautiful并且suitable

首先考虑单独一行
可以通过dp算出长度为n的序列中最大连续同色长度为i的总方案

那么再考虑列
因为方格是n*n的 所以行跟列考虑其实是一样的
选出了一行方案的再选列的方案就能确定涂完一个方格的方案

可以当作先选出第一行的选色方案
然后再选列方案 根据第一行的选择涂满整个方格

假设第一行最长的连续同色长度为x,列的最长连续长度为y
那么需要满足x*y

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
#define pb push_back
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
#define sf(n) scanf("%lf",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
#define ansn() printf("%d\n",ans)
#define lansn() printf("%lld\n",ans)
#define r0(i,n) for(int i=0;i<(n);++i)
#define r1(i,e) for(int i=1;i<=e;++i)
#define rn(i,e) for(int i=e;i>=1;--i)
#define mst(abc,bca) memset(abc,bca,sizeof abc)
#define lowbit(a) (a&(-a))
#define all(a) a.begin(),a.end()
#define pii pair
#define pll pair
#define mp(aa,bb) make_pair(aa,bb)
#define lrt rt<<1
#define rrt rt<<1|1
#define X first
#define Y second
#define PI (acos(-1.0))
double pi = acos(-1.0);
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
//const ll mod = 1000000007;
const double eps=1e-12;
const int inf=0x3f3f3f3f;
//const ll infl = 100000000000000000;//1e17
const int maxn=  2e6+20;
const int maxm = 5e3+20;
//muv[i]=(p-(p/i))*muv[p%i]%p;
int in(int &ret) {
    char c;
    int sgn ;
    if(c=getchar(),c==EOF)return -1;
    while(c!='-'&&(c<'0'||c>'9'))c=getchar();
    sgn = (c=='-')?-1:1;
    ret = (c=='-')?0:(c-'0');
    while(c=getchar(),c>='0'&&c<='9')ret = ret*10+(c-'0');
    ret *=sgn;
    return 1;
}

ll dp[2][555][555];
ll mod = 998244353;
void add(ll &a,ll b) {
    a += b;
    if(a>=mod)a-= mod;
}
ll cnt[maxn];
ll pre[maxn];
int main() {
#ifdef LOCAL
    freopen("input.txt","r",stdin);
//    freopen("output.txt","w",stdout);
#endif // LOCAL

    int n,K;
    sdd(n,K);
    dp[0][0][0] = 1;
    for(int ii=1; ii<=n; ++ii) {
        int i = ii&1;
        int lt = i^1;
        mst(dp[i],0);
        for(int k=0; k<=n; ++k) {
            for(int j=0; j<=k; ++j) {
                add(dp[i][j+1][max(k,j+1)],dp[lt][j][k]);
                add(dp[i][1][max(k,1)],dp[lt][j][k]);
            }
        }
    }
    for(int i=1; i<=n; ++i) {
        int p = n&1;
        for(int j=1; j<=n; ++j)
            add(cnt[i],dp[p][j][i]);
    }
    for(int i=1; i<=n; ++i) {
        pre[i] = pre[i-1];
        add(pre[i],cnt[i]);
    }
    ll ans = 0;
    for(int i=1; i<=n; ++i) {
        int can = (K-1)/i ;
        can = min(can,n);
        ll ad = cnt[i]*pre[can]%mod;
        add(ans,ad);
    }
    ans = ans * ((mod+1)>>1)%mod;
    lansn();
    return 0;
}

你可能感兴趣的:(codeforces,补题,dp)